aboutsummaryrefslogtreecommitdiff
path: root/plugin/cache
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/cache')
-rw-r--r--plugin/cache/cache_test.go12
-rw-r--r--plugin/cache/handler.go27
-rw-r--r--plugin/cache/item.go15
3 files changed, 29 insertions, 25 deletions
diff --git a/plugin/cache/cache_test.go b/plugin/cache/cache_test.go
index 9417a59f6..d839ea1a3 100644
--- a/plugin/cache/cache_test.go
+++ b/plugin/cache/cache_test.go
@@ -191,7 +191,7 @@ func TestCache(t *testing.T) {
c, crr := newTestCache(maxTTL)
- for _, tc := range cacheTestCases {
+ for n, tc := range cacheTestCases {
m := tc.in.Msg()
m = cacheMsg(m, tc)
@@ -204,11 +204,15 @@ func TestCache(t *testing.T) {
crr.set(m, k, mt, c.pttl)
}
- i, _ := c.get(time.Now().UTC(), state, "dns://:53")
+ i := c.getIgnoreTTL(time.Now().UTC(), state, "dns://:53")
ok := i != nil
- if ok != tc.shouldCache {
- t.Errorf("Cached message that should not have been cached: %s", state.Name())
+ if !tc.shouldCache && ok {
+ t.Errorf("Test %d: Cached message that should not have been cached: %s", n, state.Name())
+ continue
+ }
+ if tc.shouldCache && !ok {
+ t.Errorf("Test %d: Did not cache message that should have been cached: %s", n, state.Name())
continue
}
diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go
index b7adc3a9e..2b4c89350 100644
--- a/plugin/cache/handler.go
+++ b/plugin/cache/handler.go
@@ -89,38 +89,23 @@ func (c *Cache) shouldPrefetch(i *item, now time.Time) bool {
// Name implements the Handler interface.
func (c *Cache) Name() string { return "cache" }
-func (c *Cache) get(now time.Time, state request.Request, server string) (*item, bool) {
- k := hash(state.Name(), state.QType())
- cacheRequests.WithLabelValues(server, c.zonesMetricLabel).Inc()
-
- if i, ok := c.ncache.Get(k); ok && i.(*item).ttl(now) > 0 {
- cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel).Inc()
- return i.(*item), true
- }
-
- if i, ok := c.pcache.Get(k); ok && i.(*item).ttl(now) > 0 {
- cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel).Inc()
- return i.(*item), true
- }
- cacheMisses.WithLabelValues(server, c.zonesMetricLabel).Inc()
- return nil, false
-}
-
// getIgnoreTTL unconditionally returns an item if it exists in the cache.
func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item {
k := hash(state.Name(), state.QType())
cacheRequests.WithLabelValues(server, c.zonesMetricLabel).Inc()
if i, ok := c.ncache.Get(k); ok {
- ttl := i.(*item).ttl(now)
- if ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds())) {
+ itm := i.(*item)
+ ttl := itm.ttl(now)
+ if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) {
cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel).Inc()
return i.(*item)
}
}
if i, ok := c.pcache.Get(k); ok {
- ttl := i.(*item).ttl(now)
- if ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds())) {
+ itm := i.(*item)
+ ttl := itm.ttl(now)
+ if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) {
cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel).Inc()
return i.(*item)
}
diff --git a/plugin/cache/item.go b/plugin/cache/item.go
index 3b47a3b6b..56d188b36 100644
--- a/plugin/cache/item.go
+++ b/plugin/cache/item.go
@@ -1,14 +1,18 @@
package cache
import (
+ "strings"
"time"
"github.com/coredns/coredns/plugin/cache/freq"
+ "github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
type item struct {
+ Name string
+ QType uint16
Rcode int
AuthenticatedData bool
RecursionAvailable bool
@@ -24,6 +28,10 @@ type item struct {
func newItem(m *dns.Msg, now time.Time, d time.Duration) *item {
i := new(item)
+ if len(m.Question) != 0 {
+ i.Name = m.Question[0].Name
+ i.QType = m.Question[0].Qtype
+ }
i.Rcode = m.Rcode
i.AuthenticatedData = m.AuthenticatedData
i.RecursionAvailable = m.RecursionAvailable
@@ -87,3 +95,10 @@ func (i *item) ttl(now time.Time) int {
ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
return ttl
}
+
+func (i *item) matches(state request.Request) bool {
+ if state.QType() == i.QType && strings.EqualFold(state.QName(), i.Name) {
+ return true
+ }
+ return false
+}