aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Miek Gieben <miek@miek.nl> 2017-06-26 07:44:25 -0700
committerGravatar GitHub <noreply@github.com> 2017-06-26 07:44:25 -0700
commitda5880a273311c90b4abfd16f07ac2d2e4b4a5c7 (patch)
treede204b2349238189330741c96fa6546b0d9c4ebe
parentea90702bfc8a8589a15213bd5dcf58b2a9af758b (diff)
downloadcoredns-da5880a273311c90b4abfd16f07ac2d2e4b4a5c7.tar.gz
coredns-da5880a273311c90b4abfd16f07ac2d2e4b4a5c7.tar.zst
coredns-da5880a273311c90b4abfd16f07ac2d2e4b4a5c7.zip
middleware/cache: fix race (#757)
While adding a parallel performance benchmark I stumbled on a race condition (another reason to add performance benchmarks!), so this PR makes sure the msg is created in a race free manor and adds the parallel benchmark.
-rw-r--r--middleware/cache/cache.go1
-rw-r--r--middleware/cache/cache_test.go44
-rw-r--r--middleware/cache/handler.go1
-rw-r--r--middleware/cache/item.go44
4 files changed, 65 insertions, 25 deletions
diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go
index 30775c598..434efa296 100644
--- a/middleware/cache/cache.go
+++ b/middleware/cache/cache.go
@@ -113,7 +113,6 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
cacheSize.WithLabelValues(Denial).Set(float64(w.ncache.Len()))
}
- setMsgTTL(res, uint32(duration.Seconds()))
if w.prefetch {
return nil
}
diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go
index adac7d67b..f364e69f1 100644
--- a/middleware/cache/cache_test.go
+++ b/middleware/cache/cache_test.go
@@ -6,6 +6,8 @@ import (
"testing"
"time"
+ "golang.org/x/net/context"
+
"github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/pkg/response"
@@ -205,3 +207,45 @@ func TestCache(t *testing.T) {
}
}
}
+
+func BenchmarkCacheResponse(b *testing.B) {
+ c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxTTL}
+ c.pcache = cache.New(c.pcap)
+ c.ncache = cache.New(c.ncap)
+ c.prefetch = 1
+ c.duration = 1 * time.Second
+ c.Next = BackendHandler()
+
+ ctx := context.TODO()
+
+ reqs := make([]*dns.Msg, 5)
+ for i, q := range []string{"example1", "example2", "a", "b", "ddd"} {
+ reqs[i] = new(dns.Msg)
+ reqs[i].SetQuestion(q+".example.org.", dns.TypeA)
+ }
+
+ b.RunParallel(func(pb *testing.PB) {
+ i := 0
+ for pb.Next() {
+ req := reqs[i]
+ c.ServeDNS(ctx, &test.ResponseWriter{}, req)
+ i++
+ i = i % 5
+ }
+ })
+}
+
+func BackendHandler() middleware.Handler {
+ return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ m := new(dns.Msg)
+ m.SetReply(r)
+ m.Response = true
+ m.RecursionAvailable = true
+
+ owner := m.Question[0].Name
+ m.Answer = []dns.RR{test.A(owner + " 303 IN A 127.0.0.53")}
+
+ w.WriteMsg(m)
+ return dns.RcodeSuccess, nil
+ })
+}
diff --git a/middleware/cache/handler.go b/middleware/cache/handler.go
index 520b23767..ce3df2f75 100644
--- a/middleware/cache/handler.go
+++ b/middleware/cache/handler.go
@@ -29,6 +29,7 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
i, ttl := c.get(now, qname, qtype, do)
if i != nil && ttl > 0 {
resp := i.toMsg(r)
+
state.SizeAndDo(resp)
resp, _ = state.Scrub(resp)
w.WriteMsg(resp)
diff --git a/middleware/cache/item.go b/middleware/cache/item.go
index 5084bcf1c..02571ac5c 100644
--- a/middleware/cache/item.go
+++ b/middleware/cache/item.go
@@ -63,39 +63,35 @@ func (i *item) toMsg(m *dns.Msg) *dns.Msg {
m1.Rcode = i.Rcode
m1.Compress = true
- m1.Answer = i.Answer
- m1.Ns = i.Ns
- m1.Extra = i.Extra
+ m1.Answer = make([]dns.RR, len(i.Answer))
+ m1.Ns = make([]dns.RR, len(i.Ns))
+ m1.Extra = make([]dns.RR, len(i.Extra))
- ttl := int(i.origTTL) - int(time.Now().UTC().Sub(i.stored).Seconds())
- setMsgTTL(m1, uint32(ttl))
- return m1
-}
-
-func (i *item) ttl(now time.Time) int {
- ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
- return ttl
-}
-
-// setMsgTTL sets the ttl on all RRs in all sections. If ttl is smaller than minTTL
-// that value is used.
-func setMsgTTL(m *dns.Msg, ttl uint32) {
+ ttl := uint32(i.ttl(time.Now()))
if ttl < minTTL {
ttl = minTTL
}
- for _, r := range m.Answer {
- r.Header().Ttl = ttl
+ for j, r := range i.Answer {
+ m1.Answer[j] = dns.Copy(r)
+ m1.Answer[j].Header().Ttl = ttl
}
- for _, r := range m.Ns {
- r.Header().Ttl = ttl
+ for j, r := range i.Ns {
+ m1.Ns[j] = dns.Copy(r)
+ m1.Ns[j].Header().Ttl = ttl
}
- for _, r := range m.Extra {
- if r.Header().Rrtype == dns.TypeOPT {
- continue
+ for j, r := range i.Extra {
+ m1.Extra[j] = dns.Copy(r)
+ if m1.Extra[j].Header().Rrtype != dns.TypeOPT {
+ m1.Extra[j].Header().Ttl = ttl
}
- r.Header().Ttl = ttl
}
+ return m1
+}
+
+func (i *item) ttl(now time.Time) int {
+ ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
+ return ttl
}
func minMsgTTL(m *dns.Msg, mt response.Type) time.Duration {