aboutsummaryrefslogtreecommitdiff
path: root/middleware/cache/cache.go
diff options
context:
space:
mode:
Diffstat (limited to 'middleware/cache/cache.go')
-rw-r--r--middleware/cache/cache.go196
1 files changed, 196 insertions, 0 deletions
diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go
new file mode 100644
index 000000000..1ec71b047
--- /dev/null
+++ b/middleware/cache/cache.go
@@ -0,0 +1,196 @@
+package cache
+
+/*
+The idea behind this implementation is as follows. We have a cache that is index
+by a couple different keys, which allows use to have:
+
+- negative cache: qname only for NXDOMAIN responses
+- negative cache: qname + qtype for NODATA responses
+- positive cache: qname + qtype for succesful responses.
+
+We track DNSSEC responses separately, i.e. under a different cache key.
+Each Item stored contains the message split up in the different sections
+and a few bits of the msg header.
+
+For instance an NXDOMAIN for blaat.miek.nl will create the
+following negative cache entry (do signal state of DO (do off, DO on)).
+
+ ncache: do <blaat.miek.nl>
+ Item:
+ Ns: <miek.nl> SOA RR
+
+If found a return packet is assembled and returned to the client. Taking size and EDNS0
+constraints into account.
+
+We also need to track if the answer received was an authoritative answer, ad bit and other
+setting, for this we also store a few header bits.
+
+For the positive cache we use the same idea. Truncated responses are never stored.
+*/
+
+import (
+ "log"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+
+ "github.com/miekg/dns"
+ gcache "github.com/patrickmn/go-cache"
+)
+
+// Cache is middleware that looks up responses in a cache and caches replies.
+type Cache struct {
+ Next middleware.Handler
+ Zones []string
+ cache *gcache.Cache
+ cap time.Duration
+}
+
+func NewCache(ttl int, zones []string, next middleware.Handler) Cache {
+ return Cache{Next: next, Zones: zones, cache: gcache.New(defaultDuration, purgeDuration), cap: time.Duration(ttl) * time.Second}
+}
+
+type messageType int
+
+const (
+ success messageType = iota
+ nameError // NXDOMAIN in header, SOA in auth.
+ noData // NOERROR in header, SOA in auth.
+ otherError // Don't cache these.
+)
+
+// classify classifies a message, it returns the MessageType.
+func classify(m *dns.Msg) (messageType, *dns.OPT) {
+ opt := m.IsEdns0()
+ soa := false
+ if m.Rcode == dns.RcodeSuccess {
+ return success, opt
+ }
+ for _, r := range m.Ns {
+ if r.Header().Rrtype == dns.TypeSOA {
+ soa = true
+ break
+ }
+ }
+
+ // Check length of different section, and drop stuff that is just to large.
+ if soa && m.Rcode == dns.RcodeSuccess {
+ return noData, opt
+ }
+ if soa && m.Rcode == dns.RcodeNameError {
+ return nameError, opt
+ }
+
+ return otherError, opt
+}
+
+func cacheKey(m *dns.Msg, t messageType, do bool) string {
+ if m.Truncated {
+ return ""
+ }
+
+ qtype := m.Question[0].Qtype
+ qname := middleware.Name(m.Question[0].Name).Normalize()
+ switch t {
+ case success:
+ return successKey(qname, qtype, do)
+ case nameError:
+ return nameErrorKey(qname, do)
+ case noData:
+ return noDataKey(qname, qtype, do)
+ case otherError:
+ return ""
+ }
+ return ""
+}
+
+type CachingResponseWriter struct {
+ dns.ResponseWriter
+ cache *gcache.Cache
+ cap time.Duration
+}
+
+func NewCachingResponseWriter(w dns.ResponseWriter, cache *gcache.Cache, cap time.Duration) *CachingResponseWriter {
+ return &CachingResponseWriter{w, cache, cap}
+}
+
+func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error {
+ do := false
+ mt, opt := classify(res)
+ if opt != nil {
+ do = opt.Do()
+ }
+
+ key := cacheKey(res, mt, do)
+ c.Set(res, key, mt)
+
+ if c.cap != 0 {
+ setCap(res, uint32(c.cap.Seconds()))
+ }
+
+ return c.ResponseWriter.WriteMsg(res)
+}
+
+func (c *CachingResponseWriter) Set(m *dns.Msg, key string, mt messageType) {
+ if key == "" {
+ // logger the log? TODO(miek)
+ return
+ }
+
+ duration := c.cap
+ switch mt {
+ case success:
+ if c.cap == 0 {
+ duration = minTtl(m.Answer, mt)
+ }
+ i := newItem(m, duration)
+
+ c.cache.Set(key, i, duration)
+ case nameError, noData:
+ if c.cap == 0 {
+ duration = minTtl(m.Ns, mt)
+ }
+ i := newItem(m, duration)
+
+ c.cache.Set(key, i, duration)
+ }
+}
+
+func (c *CachingResponseWriter) Write(buf []byte) (int, error) {
+ log.Printf("[WARNING] Caching called with Write: not caching reply")
+ n, err := c.ResponseWriter.Write(buf)
+ return n, err
+}
+
+func (c *CachingResponseWriter) Hijack() {
+ c.ResponseWriter.Hijack()
+ return
+}
+
+func minTtl(rrs []dns.RR, mt messageType) time.Duration {
+ if mt != success && mt != nameError && mt != noData {
+ return 0
+ }
+
+ minTtl := maxTtl
+ for _, r := range rrs {
+ switch mt {
+ case nameError, noData:
+ if r.Header().Rrtype == dns.TypeSOA {
+ return time.Duration(r.(*dns.SOA).Minttl) * time.Second
+ }
+ case success:
+ if r.Header().Ttl < minTtl {
+ minTtl = r.Header().Ttl
+ }
+ }
+ }
+ return time.Duration(minTtl) * time.Second
+}
+
+const (
+ purgeDuration = 1 * time.Minute
+ defaultDuration = 20 * time.Minute
+ baseTtl = 5 // minimum ttl that we will allow
+ maxTtl uint32 = 2 * 3600
+)