diff options
-rw-r--r-- | plugin/cache/cache.go | 36 | ||||
-rw-r--r-- | plugin/cache/handler.go | 12 |
2 files changed, 40 insertions, 8 deletions
diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go index 4e64fa733..ed39fee86 100644 --- a/plugin/cache/cache.go +++ b/plugin/cache/cache.go @@ -4,6 +4,7 @@ package cache import ( "encoding/binary" "hash/fnv" + "net" "time" "github.com/coredns/coredns/plugin" @@ -105,7 +106,40 @@ type ResponseWriter struct { state request.Request server string // Server handling the request. - prefetch bool // When true write nothing back to the client. + prefetch bool // When true write nothing back to the client. + remoteAddr net.Addr +} + +// newPrefetchResponseWriter returns a Cache ResponseWriter to be used in +// prefetch requests. It ensures RemoteAddr() can be called even after the +// original connetion has already been closed. +func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter { + // Resolve the address now, the connection might be already closed when the + // actual prefetch request is made. + addr := state.W.RemoteAddr() + // The protocol of the client triggering a cache prefetch doesn't matter. + // The address type is used by request.Proto to determine the response size, + // and using TCP ensures the message isn't unnecessarily truncated. + if u, ok := addr.(*net.UDPAddr); ok { + addr = &net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone} + } + + return &ResponseWriter{ + ResponseWriter: state.W, + Cache: c, + state: state, + server: server, + prefetch: true, + remoteAddr: addr, + } +} + +// RemoteAddr implements the dns.ResponseWriter interface. +func (w *ResponseWriter) RemoteAddr() net.Addr { + if w.remoteAddr != nil { + return w.remoteAddr + } + return w.ResponseWriter.RemoteAddr() } // WriteMsg implements the dns.ResponseWriter interface. diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go index 598640568..bb5898934 100644 --- a/plugin/cache/handler.go +++ b/plugin/cache/handler.go @@ -40,20 +40,18 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) threshold := int(math.Ceil(float64(c.percentage) / 100 * float64(i.origTTL))) if i.Freq.Hits() >= c.prefetch && ttl <= threshold { - go func() { + cw := newPrefetchResponseWriter(server, state, c) + go func(w dns.ResponseWriter) { cachePrefetches.WithLabelValues(server).Inc() + plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r) + // When prefetching we loose the item i, and with it the frequency // that we've gathered sofar. See we copy the frequencies info back // into the new item that was stored in the cache. - prr := &ResponseWriter{ResponseWriter: w, Cache: c, - prefetch: true, state: state, - server: server} - plugin.NextOrFailure(c.Name(), c.Next, ctx, prr, r) - if i1 := c.exists(state); i1 != nil { i1.Freq.Reset(now, i.Freq.Hits()) } - }() + }(cw) } } return dns.RcodeSuccess, nil |