aboutsummaryrefslogtreecommitdiff
path: root/middleware/cache
diff options
context:
space:
mode:
Diffstat (limited to 'middleware/cache')
-rw-r--r--middleware/cache/README.md29
-rw-r--r--middleware/cache/cache.go196
-rw-r--r--middleware/cache/cache_test.go112
-rw-r--r--middleware/cache/handler.go44
-rw-r--r--middleware/cache/item.go98
-rw-r--r--middleware/cache/item_test.go25
6 files changed, 504 insertions, 0 deletions
diff --git a/middleware/cache/README.md b/middleware/cache/README.md
new file mode 100644
index 000000000..aade84694
--- /dev/null
+++ b/middleware/cache/README.md
@@ -0,0 +1,29 @@
+# cache
+
+`cache` enables a frontend cache.
+
+## Syntax
+
+~~~
+cache [ttl] [zones...]
+~~~
+
+* `ttl` max TTL in seconds, if not specified the TTL of the reply (SOA minimum or minimum TTL in the
+ answer section) will be used.
+* `zones` zones it should should cache for. If empty the zones from the configuration block are used.
+
+
+Each element in the cache is cached according to its TTL, for the negative cache the SOA's MinTTL
+value is used.
+
+A cache mostly makes sense with a middleware that is potentially slow, i.e. a proxy that retrieves
+answer, or to minimize backend queries for middleware like etcd. Using a cache with the file
+middleware essentially doubles the memory load with no concealable increase of query speed.
+
+## Examples
+
+~~~
+cache
+~~~
+
+Enable caching for all zones.
diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go
new file mode 100644
index 000000000..1ec71b047
--- /dev/null
+++ b/middleware/cache/cache.go
@@ -0,0 +1,196 @@
+package cache
+
+/*
+The idea behind this implementation is as follows. We have a cache that is index
+by a couple different keys, which allows use to have:
+
+- negative cache: qname only for NXDOMAIN responses
+- negative cache: qname + qtype for NODATA responses
+- positive cache: qname + qtype for succesful responses.
+
+We track DNSSEC responses separately, i.e. under a different cache key.
+Each Item stored contains the message split up in the different sections
+and a few bits of the msg header.
+
+For instance an NXDOMAIN for blaat.miek.nl will create the
+following negative cache entry (do signal state of DO (do off, DO on)).
+
+ ncache: do <blaat.miek.nl>
+ Item:
+ Ns: <miek.nl> SOA RR
+
+If found a return packet is assembled and returned to the client. Taking size and EDNS0
+constraints into account.
+
+We also need to track if the answer received was an authoritative answer, ad bit and other
+setting, for this we also store a few header bits.
+
+For the positive cache we use the same idea. Truncated responses are never stored.
+*/
+
+import (
+ "log"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+
+ "github.com/miekg/dns"
+ gcache "github.com/patrickmn/go-cache"
+)
+
+// Cache is middleware that looks up responses in a cache and caches replies.
+type Cache struct {
+ Next middleware.Handler
+ Zones []string
+ cache *gcache.Cache
+ cap time.Duration
+}
+
+func NewCache(ttl int, zones []string, next middleware.Handler) Cache {
+ return Cache{Next: next, Zones: zones, cache: gcache.New(defaultDuration, purgeDuration), cap: time.Duration(ttl) * time.Second}
+}
+
+type messageType int
+
+const (
+ success messageType = iota
+ nameError // NXDOMAIN in header, SOA in auth.
+ noData // NOERROR in header, SOA in auth.
+ otherError // Don't cache these.
+)
+
+// classify classifies a message, it returns the MessageType.
+func classify(m *dns.Msg) (messageType, *dns.OPT) {
+ opt := m.IsEdns0()
+ soa := false
+ if m.Rcode == dns.RcodeSuccess {
+ return success, opt
+ }
+ for _, r := range m.Ns {
+ if r.Header().Rrtype == dns.TypeSOA {
+ soa = true
+ break
+ }
+ }
+
+ // Check length of different section, and drop stuff that is just to large.
+ if soa && m.Rcode == dns.RcodeSuccess {
+ return noData, opt
+ }
+ if soa && m.Rcode == dns.RcodeNameError {
+ return nameError, opt
+ }
+
+ return otherError, opt
+}
+
+func cacheKey(m *dns.Msg, t messageType, do bool) string {
+ if m.Truncated {
+ return ""
+ }
+
+ qtype := m.Question[0].Qtype
+ qname := middleware.Name(m.Question[0].Name).Normalize()
+ switch t {
+ case success:
+ return successKey(qname, qtype, do)
+ case nameError:
+ return nameErrorKey(qname, do)
+ case noData:
+ return noDataKey(qname, qtype, do)
+ case otherError:
+ return ""
+ }
+ return ""
+}
+
+type CachingResponseWriter struct {
+ dns.ResponseWriter
+ cache *gcache.Cache
+ cap time.Duration
+}
+
+func NewCachingResponseWriter(w dns.ResponseWriter, cache *gcache.Cache, cap time.Duration) *CachingResponseWriter {
+ return &CachingResponseWriter{w, cache, cap}
+}
+
+func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error {
+ do := false
+ mt, opt := classify(res)
+ if opt != nil {
+ do = opt.Do()
+ }
+
+ key := cacheKey(res, mt, do)
+ c.Set(res, key, mt)
+
+ if c.cap != 0 {
+ setCap(res, uint32(c.cap.Seconds()))
+ }
+
+ return c.ResponseWriter.WriteMsg(res)
+}
+
+func (c *CachingResponseWriter) Set(m *dns.Msg, key string, mt messageType) {
+ if key == "" {
+ // logger the log? TODO(miek)
+ return
+ }
+
+ duration := c.cap
+ switch mt {
+ case success:
+ if c.cap == 0 {
+ duration = minTtl(m.Answer, mt)
+ }
+ i := newItem(m, duration)
+
+ c.cache.Set(key, i, duration)
+ case nameError, noData:
+ if c.cap == 0 {
+ duration = minTtl(m.Ns, mt)
+ }
+ i := newItem(m, duration)
+
+ c.cache.Set(key, i, duration)
+ }
+}
+
+func (c *CachingResponseWriter) Write(buf []byte) (int, error) {
+ log.Printf("[WARNING] Caching called with Write: not caching reply")
+ n, err := c.ResponseWriter.Write(buf)
+ return n, err
+}
+
+func (c *CachingResponseWriter) Hijack() {
+ c.ResponseWriter.Hijack()
+ return
+}
+
+func minTtl(rrs []dns.RR, mt messageType) time.Duration {
+ if mt != success && mt != nameError && mt != noData {
+ return 0
+ }
+
+ minTtl := maxTtl
+ for _, r := range rrs {
+ switch mt {
+ case nameError, noData:
+ if r.Header().Rrtype == dns.TypeSOA {
+ return time.Duration(r.(*dns.SOA).Minttl) * time.Second
+ }
+ case success:
+ if r.Header().Ttl < minTtl {
+ minTtl = r.Header().Ttl
+ }
+ }
+ }
+ return time.Duration(minTtl) * time.Second
+}
+
+const (
+ purgeDuration = 1 * time.Minute
+ defaultDuration = 20 * time.Minute
+ baseTtl = 5 // minimum ttl that we will allow
+ maxTtl uint32 = 2 * 3600
+)
diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go
new file mode 100644
index 000000000..310a1164e
--- /dev/null
+++ b/middleware/cache/cache_test.go
@@ -0,0 +1,112 @@
+package cache
+
+import (
+ "testing"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/test"
+
+ "github.com/miekg/dns"
+)
+
+type cacheTestCase struct {
+ test.Case
+ in test.Case
+ AuthenticatedData bool
+ Authoritative bool
+ RecursionAvailable bool
+ Truncated bool
+}
+
+var cacheTestCases = []cacheTestCase{
+ {
+ RecursionAvailable: true, AuthenticatedData: true, Authoritative: true,
+ Case: test.Case{
+ Qname: "miek.nl.", Qtype: dns.TypeMX,
+ Answer: []dns.RR{
+ test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com."),
+ test.MX("miek.nl. 1800 IN MX 10 aspmx2.googlemail.com."),
+ test.MX("miek.nl. 1800 IN MX 10 aspmx3.googlemail.com."),
+ test.MX("miek.nl. 1800 IN MX 5 alt1.aspmx.l.google.com."),
+ test.MX("miek.nl. 1800 IN MX 5 alt2.aspmx.l.google.com."),
+ },
+ },
+ in: test.Case{
+ Qname: "miek.nl.", Qtype: dns.TypeMX,
+ Answer: []dns.RR{
+ test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com."),
+ test.MX("miek.nl. 1800 IN MX 10 aspmx2.googlemail.com."),
+ test.MX("miek.nl. 1800 IN MX 10 aspmx3.googlemail.com."),
+ test.MX("miek.nl. 1800 IN MX 5 alt1.aspmx.l.google.com."),
+ test.MX("miek.nl. 1800 IN MX 5 alt2.aspmx.l.google.com."),
+ },
+ },
+ },
+ {
+ Truncated: true,
+ Case: test.Case{
+ Qname: "miek.nl.", Qtype: dns.TypeMX,
+ Answer: []dns.RR{test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com.")},
+ },
+ in: test.Case{},
+ },
+}
+
+func cacheMsg(m *dns.Msg, tc cacheTestCase) *dns.Msg {
+ m.RecursionAvailable = tc.RecursionAvailable
+ m.AuthenticatedData = tc.AuthenticatedData
+ m.Authoritative = tc.Authoritative
+ m.Truncated = tc.Truncated
+ m.Answer = tc.in.Answer
+ m.Ns = tc.in.Ns
+ // m.Extra = tc.in.Extra , not the OPT record!
+ return m
+}
+
+func newTestCache() (Cache, *CachingResponseWriter) {
+ c := NewCache(0, []string{"."}, nil)
+ crr := NewCachingResponseWriter(nil, c.cache, time.Duration(0))
+ return c, crr
+}
+
+func TestCache(t *testing.T) {
+ c, crr := newTestCache()
+
+ for _, tc := range cacheTestCases {
+ m := tc.in.Msg()
+ m = cacheMsg(m, tc)
+ do := tc.in.Do
+
+ mt, _ := classify(m)
+ key := cacheKey(m, mt, do)
+ crr.Set(m, key, mt)
+
+ name := middleware.Name(m.Question[0].Name).Normalize()
+ qtype := m.Question[0].Qtype
+ i, ok := c.Get(name, qtype, do)
+ if !ok && !m.Truncated {
+ t.Errorf("Truncated message should not have been cached")
+ }
+
+ if ok {
+ resp := i.toMsg(m)
+
+ if !test.Header(t, tc.Case, resp) {
+ t.Logf("%v\n", resp)
+ continue
+ }
+
+ if !test.Section(t, tc.Case, test.Answer, resp.Answer) {
+ t.Logf("%v\n", resp)
+ }
+ if !test.Section(t, tc.Case, test.Ns, resp.Ns) {
+ t.Logf("%v\n", resp)
+
+ }
+ if !test.Section(t, tc.Case, test.Extra, resp.Extra) {
+ t.Logf("%v\n", resp)
+ }
+ }
+ }
+}
diff --git a/middleware/cache/handler.go b/middleware/cache/handler.go
new file mode 100644
index 000000000..51e3731bd
--- /dev/null
+++ b/middleware/cache/handler.go
@@ -0,0 +1,44 @@
+package cache
+
+import (
+ "github.com/miekg/coredns/middleware"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+// ServeDNS implements the middleware.Handler interface.
+func (c Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ state := middleware.State{W: w, Req: r}
+
+ qname := state.Name()
+ qtype := state.QType()
+ zone := middleware.Zones(c.Zones).Matches(qname)
+ if zone == "" {
+ return c.Next.ServeDNS(ctx, w, r)
+ }
+
+ do := state.Do() // might need more from OPT record?
+
+ if i, ok := c.Get(qname, qtype, do); ok {
+ resp := i.toMsg(r)
+ state.SizeAndDo(resp)
+ w.WriteMsg(resp)
+ return dns.RcodeSuccess, nil
+ }
+ crr := NewCachingResponseWriter(w, c.cache, c.cap)
+ return c.Next.ServeDNS(ctx, crr, r)
+}
+
+func (c Cache) Get(qname string, qtype uint16, do bool) (*item, bool) {
+ nxdomain := nameErrorKey(qname, do)
+ if i, ok := c.cache.Get(nxdomain); ok {
+ return i.(*item), true
+ }
+
+ successOrNoData := successKey(qname, qtype, do)
+ if i, ok := c.cache.Get(successOrNoData); ok {
+ return i.(*item), true
+ }
+ return nil, false
+}
diff --git a/middleware/cache/item.go b/middleware/cache/item.go
new file mode 100644
index 000000000..6f0190c52
--- /dev/null
+++ b/middleware/cache/item.go
@@ -0,0 +1,98 @@
+package cache
+
+import (
+ "strconv"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+type item struct {
+ Authoritative bool
+ AuthenticatedData bool
+ RecursionAvailable bool
+ Answer []dns.RR
+ Ns []dns.RR
+ Extra []dns.RR
+
+ origTtl uint32
+ stored time.Time
+}
+
+func newItem(m *dns.Msg, d time.Duration) *item {
+ i := new(item)
+ i.Authoritative = m.Authoritative
+ i.AuthenticatedData = m.AuthenticatedData
+ i.RecursionAvailable = m.RecursionAvailable
+ i.Answer = m.Answer
+ i.Ns = m.Ns
+ i.Extra = make([]dns.RR, len(m.Extra))
+ // Don't copy OPT record as these are hop-by-hop.
+ j := 0
+ for _, e := range m.Extra {
+ if e.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ i.Extra[j] = e
+ j++
+ }
+ i.Extra = i.Extra[:j]
+
+ i.origTtl = uint32(d.Seconds())
+ i.stored = time.Now().UTC()
+
+ return i
+}
+
+// toMsg turns i into a message, it tailers to reply to m.
+func (i *item) toMsg(m *dns.Msg) *dns.Msg {
+ m1 := new(dns.Msg)
+ m1.SetReply(m)
+ m1.Authoritative = i.Authoritative
+ m1.AuthenticatedData = i.AuthenticatedData
+ m1.RecursionAvailable = i.RecursionAvailable
+ m1.Compress = true
+
+ m1.Answer = i.Answer
+ m1.Ns = i.Ns
+ m1.Extra = i.Extra
+
+ ttl := int(i.origTtl) - int(time.Now().UTC().Sub(i.stored).Seconds())
+ if ttl < baseTtl {
+ ttl = baseTtl
+ }
+ setCap(m1, uint32(ttl))
+ return m1
+}
+
+// setCap sets the ttl on all RRs in all sections.
+func setCap(m *dns.Msg, ttl uint32) {
+ for _, r := range m.Answer {
+ r.Header().Ttl = uint32(ttl)
+ }
+ for _, r := range m.Ns {
+ r.Header().Ttl = uint32(ttl)
+ }
+ for _, r := range m.Extra {
+ r.Header().Ttl = uint32(ttl)
+ }
+}
+
+// nodataKey returns a caching key for NODATA responses.
+func noDataKey(qname string, qtype uint16, do bool) string {
+ if do {
+ return "1" + qname + ".." + strconv.Itoa(int(qtype))
+ }
+ return "0" + qname + ".." + strconv.Itoa(int(qtype))
+}
+
+// nameErrorKey returns a caching key for NXDOMAIN responses.
+func nameErrorKey(qname string, do bool) string {
+ if do {
+ return "1" + qname
+ }
+ return "0" + qname
+}
+
+// successKey returns a caching key for successfull answers.
+func successKey(qname string, qtype uint16, do bool) string { return noDataKey(qname, qtype, do) }
diff --git a/middleware/cache/item_test.go b/middleware/cache/item_test.go
new file mode 100644
index 000000000..5989b0099
--- /dev/null
+++ b/middleware/cache/item_test.go
@@ -0,0 +1,25 @@
+package cache
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestKey(t *testing.T) {
+ if noDataKey("miek.nl.", dns.TypeMX, false) != "0miek.nl...15" {
+ t.Errorf("failed to create correct key")
+ }
+ if noDataKey("miek.nl.", dns.TypeMX, true) != "1miek.nl...15" {
+ t.Errorf("failed to create correct key")
+ }
+ if nameErrorKey("miek.nl.", false) != "0miek.nl." {
+ t.Errorf("failed to create correct key")
+ }
+ if nameErrorKey("miek.nl.", true) != "1miek.nl." {
+ t.Errorf("failed to create correct key")
+ }
+ if noDataKey("miek.nl.", dns.TypeMX, false) != successKey("miek.nl.", dns.TypeMX, false) {
+ t.Errorf("nameErrorKey and successKey should be the same")
+ }
+}