aboutsummaryrefslogtreecommitdiff
path: root/plugin
diff options
context:
space:
mode:
Diffstat (limited to 'plugin')
-rw-r--r--plugin/backend.go5
-rw-r--r--plugin/etcd/xfr.go7
-rw-r--r--plugin/federation/kubernetes_api_test.go1
-rw-r--r--plugin/kubernetes/controller.go48
-rw-r--r--plugin/kubernetes/handler.go3
-rw-r--r--plugin/kubernetes/handler_test.go11
-rw-r--r--plugin/kubernetes/kubernetes.go1
-rw-r--r--plugin/kubernetes/kubernetes_test.go1
-rw-r--r--plugin/kubernetes/ns_test.go1
-rw-r--r--plugin/kubernetes/reverse_test.go1
-rw-r--r--plugin/kubernetes/setup.go10
-rw-r--r--plugin/kubernetes/xfr.go201
-rw-r--r--plugin/kubernetes/xfr_test.go111
13 files changed, 383 insertions, 18 deletions
diff --git a/plugin/backend.go b/plugin/backend.go
index fad61d418..9abb277f7 100644
--- a/plugin/backend.go
+++ b/plugin/backend.go
@@ -3,6 +3,7 @@ package plugin
import (
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/request"
+ "golang.org/x/net/context"
"github.com/miekg/dns"
)
@@ -37,6 +38,10 @@ type Transferer interface {
// MinTTL returns the minimum TTL to be used in the SOA record.
MinTTL(state request.Request) uint32
+
+ // Transfer handles a zone transfer it writes to the client just
+ // like any other handler.
+ Transfer(ctx context.Context, state request.Request) (int, error)
}
// Options are extra options that can be specified for a lookup.
diff --git a/plugin/etcd/xfr.go b/plugin/etcd/xfr.go
index 43a734cf9..fcb8dda07 100644
--- a/plugin/etcd/xfr.go
+++ b/plugin/etcd/xfr.go
@@ -4,6 +4,8 @@ import (
"time"
"github.com/coredns/coredns/request"
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
)
// Serial implements the Transferer interface.
@@ -15,3 +17,8 @@ func (e *Etcd) Serial(state request.Request) uint32 {
func (e *Etcd) MinTTL(state request.Request) uint32 {
return 30
}
+
+// Transfer implements the Transferer interface.
+func (e *Etcd) Transfer(ctx context.Context, state request.Request) (int, error) {
+ return dns.RcodeServerFailure, nil
+}
diff --git a/plugin/federation/kubernetes_api_test.go b/plugin/federation/kubernetes_api_test.go
index b101246f5..ee4757d22 100644
--- a/plugin/federation/kubernetes_api_test.go
+++ b/plugin/federation/kubernetes_api_test.go
@@ -14,6 +14,7 @@ func (APIConnFederationTest) Run() { return }
func (APIConnFederationTest) Stop() error { return nil }
func (APIConnFederationTest) SvcIndexReverse(string) []*api.Service { return nil }
func (APIConnFederationTest) EpIndexReverse(string) []*api.Endpoints { return nil }
+func (APIConnFederationTest) Modified() int64 { return 0 }
func (APIConnFederationTest) PodIndex(string) []*api.Pod {
a := []*api.Pod{{
diff --git a/plugin/kubernetes/controller.go b/plugin/kubernetes/controller.go
index 00b0c11ac..3262f299e 100644
--- a/plugin/kubernetes/controller.go
+++ b/plugin/kubernetes/controller.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"sync"
+ "sync/atomic"
"time"
api "k8s.io/api/core/v1"
@@ -16,15 +17,15 @@ import (
"k8s.io/apimachinery/pkg/watch"
)
-var (
- namespace = api.NamespaceAll
-)
+var namespace = api.NamespaceAll
-const podIPIndex = "PodIP"
-const svcNameNamespaceIndex = "NameNamespace"
-const svcIPIndex = "ServiceIP"
-const epNameNamespaceIndex = "EndpointNameNamespace"
-const epIPIndex = "EndpointsIP"
+const (
+ podIPIndex = "PodIP"
+ svcNameNamespaceIndex = "NameNamespace"
+ svcIPIndex = "ServiceIP"
+ epNameNamespaceIndex = "EndpointNameNamespace"
+ epIPIndex = "EndpointsIP"
+)
type dnsController interface {
ServiceList() []*api.Service
@@ -41,9 +42,17 @@ type dnsController interface {
Run()
HasSynced() bool
Stop() error
+
+ // Modified returns the timestamp of the most recent changes
+ Modified() int64
}
type dnsControl struct {
+ // Modified tracks timestamp of the most recent changes
+ // It needs to be first because it is guarnteed to be 8-byte
+ // aligned ( we use sync.LoadAtomic with this )
+ modified int64
+
client *kubernetes.Clientset
selector labels.Selector
@@ -86,7 +95,7 @@ func newdnsController(kubeClient *kubernetes.Clientset, opts dnsControlOpts) *dn
},
&api.Service{},
opts.resyncPeriod,
- cache.ResourceEventHandlerFuncs{},
+ cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{svcNameNamespaceIndex: svcNameNamespaceIndexFunc, svcIPIndex: svcIPIndexFunc})
if opts.initPodCache {
@@ -95,9 +104,9 @@ func newdnsController(kubeClient *kubernetes.Clientset, opts dnsControlOpts) *dn
ListFunc: podListFunc(dns.client, namespace, dns.selector),
WatchFunc: podWatchFunc(dns.client, namespace, dns.selector),
},
- &api.Pod{}, // TODO replace with a lighter-weight custom struct
+ &api.Pod{},
opts.resyncPeriod,
- cache.ResourceEventHandlerFuncs{},
+ cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{podIPIndex: podIPIndexFunc})
}
dns.epLister, dns.epController = cache.NewIndexerInformer(
@@ -107,7 +116,7 @@ func newdnsController(kubeClient *kubernetes.Clientset, opts dnsControlOpts) *dn
},
&api.Endpoints{},
opts.resyncPeriod,
- cache.ResourceEventHandlerFuncs{},
+ cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{epNameNamespaceIndex: epNameNamespaceIndexFunc, epIPIndex: epIPIndexFunc})
return &dns
@@ -410,3 +419,18 @@ func (dns *dnsControl) GetNamespaceByName(name string) (*api.Namespace, error) {
}
return v1ns, nil
}
+
+func (dns *dnsControl) Modified() int64 {
+ unix := atomic.LoadInt64(&dns.modified)
+ return unix
+}
+
+// updateModified set dns.modified to the current time.
+func (dns *dnsControl) updateModifed() {
+ unix := time.Now().Unix()
+ atomic.StoreInt64(&dns.modified, unix)
+}
+
+func (dns *dnsControl) Add(obj interface{}) { dns.updateModifed() }
+func (dns *dnsControl) Delete(obj interface{}) { dns.updateModifed() }
+func (dns *dnsControl) Update(objOld, newObj interface{}) { dns.updateModifed() }
diff --git a/plugin/kubernetes/handler.go b/plugin/kubernetes/handler.go
index e02608a6b..5c9ccba34 100644
--- a/plugin/kubernetes/handler.go
+++ b/plugin/kubernetes/handler.go
@@ -53,6 +53,8 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M
break
}
fallthrough
+ case dns.TypeAXFR, dns.TypeIXFR:
+ k.Transfer(ctx, state)
default:
// Do a fake A lookup, so we can distinguish between NODATA and NXDOMAIN
_, err = plugin.A(&k, zone, state, nil, plugin.Options{})
@@ -76,6 +78,7 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M
m.Extra = append(m.Extra, extra...)
m = dnsutil.Dedup(m)
+
state.SizeAndDo(m)
m, _ = state.Scrub(m)
w.WriteMsg(m)
diff --git a/plugin/kubernetes/handler_test.go b/plugin/kubernetes/handler_test.go
index 63e691cf2..b23953b32 100644
--- a/plugin/kubernetes/handler_test.go
+++ b/plugin/kubernetes/handler_test.go
@@ -2,6 +2,7 @@ package kubernetes
import (
"testing"
+ "time"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
@@ -35,6 +36,12 @@ var dnsTestCases = []test.Case{
Answer: []dns.RR{test.SRV("svc1.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc1.testns.svc.cluster.local.")},
Extra: []dns.RR{test.A("svc1.testns.svc.cluster.local. 5 IN A 10.0.0.1")},
},
+ {
+ Qname: "svc6.testns.svc.cluster.local.", Qtype: dns.TypeSRV,
+ Rcode: dns.RcodeSuccess,
+ Answer: []dns.RR{test.SRV("svc6.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc6.testns.svc.cluster.local.")},
+ Extra: []dns.RR{test.AAAA("svc6.testns.svc.cluster.local. 5 IN AAAA 1234:abcd::1")},
+ },
// SRV Service (wildcard)
{
Qname: "svc1.*.svc.cluster.local.", Qtype: dns.TypeSRV,
@@ -266,6 +273,7 @@ func (APIConnServeTest) Run() { return }
func (APIConnServeTest) Stop() error { return nil }
func (APIConnServeTest) EpIndexReverse(string) []*api.Endpoints { return nil }
func (APIConnServeTest) SvcIndexReverse(string) []*api.Service { return nil }
+func (APIConnServeTest) Modified() int64 { return time.Now().Unix() }
func (APIConnServeTest) PodIndex(string) []*api.Pod {
a := []*api.Pod{{
@@ -286,6 +294,7 @@ var svcIndex = map[string][]*api.Service{
Namespace: "testns",
},
Spec: api.ServiceSpec{
+ Type: api.ServiceTypeClusterIP,
ClusterIP: "10.0.0.1",
Ports: []api.ServicePort{{
Name: "http",
@@ -300,6 +309,7 @@ var svcIndex = map[string][]*api.Service{
Namespace: "testns",
},
Spec: api.ServiceSpec{
+ Type: api.ServiceTypeClusterIP,
ClusterIP: "1234:abcd::1",
Ports: []api.ServicePort{{
Name: "http",
@@ -314,6 +324,7 @@ var svcIndex = map[string][]*api.Service{
Namespace: "testns",
},
Spec: api.ServiceSpec{
+ Type: api.ServiceTypeClusterIP,
ClusterIP: api.ClusterIPNone,
},
}},
diff --git a/plugin/kubernetes/kubernetes.go b/plugin/kubernetes/kubernetes.go
index 3be10a88e..6afb1d83f 100644
--- a/plugin/kubernetes/kubernetes.go
+++ b/plugin/kubernetes/kubernetes.go
@@ -47,6 +47,7 @@ type Kubernetes struct {
primaryZoneIndex int
interfaceAddrsFunc func() net.IP
autoPathSearch []string // Local search path from /etc/resolv.conf. Needed for autopath.
+ TransferTo []string
}
// New returns a initialized Kubernetes. It default interfaceAddrFunc to return 127.0.0.1. All other
diff --git a/plugin/kubernetes/kubernetes_test.go b/plugin/kubernetes/kubernetes_test.go
index 2eb7330c0..e10fe894b 100644
--- a/plugin/kubernetes/kubernetes_test.go
+++ b/plugin/kubernetes/kubernetes_test.go
@@ -63,6 +63,7 @@ func (APIConnServiceTest) Stop() error { return nil }
func (APIConnServiceTest) PodIndex(string) []*api.Pod { return nil }
func (APIConnServiceTest) SvcIndexReverse(string) []*api.Service { return nil }
func (APIConnServiceTest) EpIndexReverse(string) []*api.Endpoints { return nil }
+func (APIConnServiceTest) Modified() int64 { return 0 }
func (APIConnServiceTest) SvcIndex(string) []*api.Service {
svcs := []*api.Service{
diff --git a/plugin/kubernetes/ns_test.go b/plugin/kubernetes/ns_test.go
index f328bf06c..7dcc83eeb 100644
--- a/plugin/kubernetes/ns_test.go
+++ b/plugin/kubernetes/ns_test.go
@@ -17,6 +17,7 @@ func (APIConnTest) SvcIndex(string) []*api.Service { return nil }
func (APIConnTest) SvcIndexReverse(string) []*api.Service { return nil }
func (APIConnTest) EpIndex(string) []*api.Endpoints { return nil }
func (APIConnTest) EndpointsList() []*api.Endpoints { return nil }
+func (APIConnTest) Modified() int64 { return 0 }
func (APIConnTest) ServiceList() []*api.Service {
svcs := []*api.Service{
diff --git a/plugin/kubernetes/reverse_test.go b/plugin/kubernetes/reverse_test.go
index bfb9e4aa9..22a6219f2 100644
--- a/plugin/kubernetes/reverse_test.go
+++ b/plugin/kubernetes/reverse_test.go
@@ -22,6 +22,7 @@ func (APIConnReverseTest) SvcIndex(string) []*api.Service { return nil }
func (APIConnReverseTest) EpIndex(string) []*api.Endpoints { return nil }
func (APIConnReverseTest) EndpointsList() []*api.Endpoints { return nil }
func (APIConnReverseTest) ServiceList() []*api.Service { return nil }
+func (APIConnReverseTest) Modified() int64 { return 0 }
func (APIConnReverseTest) SvcIndexReverse(ip string) []*api.Service {
if ip != "192.168.1.100" {
diff --git a/plugin/kubernetes/setup.go b/plugin/kubernetes/setup.go
index e8fc484bf..f79724dee 100644
--- a/plugin/kubernetes/setup.go
+++ b/plugin/kubernetes/setup.go
@@ -10,6 +10,7 @@ import (
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
+ "github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/proxy"
"github.com/mholt/caddy"
@@ -197,6 +198,15 @@ func kubernetesParse(c *caddy.Controller) (*Kubernetes, dnsControlOpts, error) {
return nil, opts, c.Errf("ttl must be in range [5, 3600]: %d", t)
}
k8s.ttl = uint32(t)
+ case "transfer":
+ tos, froms, err := parse.Transfer(c, false)
+ if err != nil {
+ return nil, opts, err
+ }
+ if len(froms) != 0 {
+ return nil, opts, c.Errf("transfer from is not supported with this plugin")
+ }
+ k8s.TransferTo = tos
default:
return nil, opts, c.Errf("unknown property '%s'", c.Val())
}
diff --git a/plugin/kubernetes/xfr.go b/plugin/kubernetes/xfr.go
index 7197a1fd5..44d9af70b 100644
--- a/plugin/kubernetes/xfr.go
+++ b/plugin/kubernetes/xfr.go
@@ -1,17 +1,206 @@
package kubernetes
import (
- "time"
+ "log"
+ "math"
+ "net"
+ "strings"
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/request"
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+ api "k8s.io/api/core/v1"
)
+const transferLength = 2000
+
// Serial implements the Transferer interface.
-func (e *Kubernetes) Serial(state request.Request) uint32 {
- return uint32(time.Now().Unix())
-}
+func (k *Kubernetes) Serial(state request.Request) uint32 { return uint32(k.APIConn.Modified()) }
// MinTTL implements the Transferer interface.
-func (e *Kubernetes) MinTTL(state request.Request) uint32 {
- return 30
+func (k *Kubernetes) MinTTL(state request.Request) uint32 { return 30 }
+
+// Transfer implements the Transferer interface.
+func (k *Kubernetes) Transfer(ctx context.Context, state request.Request) (int, error) {
+
+ // Get all services.
+ rrs := make(chan dns.RR)
+ go k.transfer(rrs, state.Zone)
+
+ records := []dns.RR{}
+ for r := range rrs {
+ records = append(records, r)
+ }
+
+ if len(records) == 0 {
+ return dns.RcodeServerFailure, nil
+ }
+
+ ch := make(chan *dns.Envelope)
+ tr := new(dns.Transfer)
+
+ soa, err := plugin.SOA(k, state.Zone, state, plugin.Options{})
+ if err != nil {
+ return dns.RcodeServerFailure, nil
+ }
+
+ records = append(soa, records...)
+ records = append(records, soa...)
+ go func(ch chan *dns.Envelope) {
+ j, l := 0, 0
+ log.Printf("[INFO] Outgoing transfer of %d records of zone %s to %s started", len(records), state.Zone, state.IP())
+ for i, r := range records {
+ l += dns.Len(r)
+ if l > transferLength {
+ ch <- &dns.Envelope{RR: records[j:i]}
+ l = 0
+ j = i
+ }
+ }
+ if j < len(records) {
+ ch <- &dns.Envelope{RR: records[j:]}
+ }
+ close(ch)
+ }(ch)
+
+ tr.Out(state.W, state.Req, ch)
+ // Defer closing to the client
+ state.W.Hijack()
+ return dns.RcodeSuccess, nil
+}
+
+func (k *Kubernetes) transfer(c chan dns.RR, zone string) {
+
+ defer close(c)
+
+ zonePath := msg.Path(zone, "coredns")
+ serviceList := k.APIConn.ServiceList()
+ for _, svc := range serviceList {
+ svcBase := []string{zonePath, Svc, svc.Namespace, svc.Name}
+ switch svc.Spec.Type {
+ case api.ServiceTypeClusterIP, api.ServiceTypeNodePort, api.ServiceTypeLoadBalancer:
+ clusterIP := net.ParseIP(svc.Spec.ClusterIP)
+ if clusterIP != nil {
+ for _, p := range svc.Spec.Ports {
+
+ s := msg.Service{Host: svc.Spec.ClusterIP, Port: int(p.Port), TTL: k.ttl}
+ s.Key = strings.Join(svcBase, "/")
+
+ // Change host from IP to Name for SRV records
+ host := emitAddressRecord(c, s)
+ s.Host = host
+
+ // Need to generate this to handle use cases for peer-finder
+ // ref: https://github.com/coredns/coredns/pull/823
+ c <- s.NewSRV(msg.Domain(s.Key), 100)
+
+ // As per spec unnamed ports do not have a srv record
+ // https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records
+ if p.Name == "" {
+ continue
+ }
+
+ s.Key = strings.Join(append(svcBase, strings.ToLower("_"+string(p.Protocol)), strings.ToLower("_"+string(p.Name))), "/")
+
+ c <- s.NewSRV(msg.Domain(s.Key), 100)
+ }
+
+ // Skip endpoint discovery if clusterIP is defined
+ continue
+ }
+
+ endpointsList := k.APIConn.EpIndex(svc.Name + "." + svc.Namespace)
+
+ for _, ep := range endpointsList {
+ if ep.ObjectMeta.Name != svc.Name || ep.ObjectMeta.Namespace != svc.Namespace {
+ continue
+ }
+
+ for _, eps := range ep.Subsets {
+ srvWeight := calcSRVWeight(len(eps.Addresses))
+ for _, addr := range eps.Addresses {
+ s := msg.Service{Host: addr.IP, TTL: k.ttl}
+ s.Key = strings.Join(svcBase, "/")
+ // We don't need to change the msg.Service host from IP to Name yet
+ // so disregard the return value here
+ emitAddressRecord(c, s)
+
+ s.Key = strings.Join(append(svcBase, endpointHostname(addr, k.endpointNameMode)), "/")
+ // Change host from IP to Name for SRV records
+ host := emitAddressRecord(c, s)
+ s.Host = host
+
+ for _, p := range eps.Ports {
+ // As per spec unnamed ports do not have a srv record
+ // https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records
+ if p.Name == "" {
+ continue
+ }
+
+ s.Port = int(p.Port)
+
+ s.Key = strings.Join(append(svcBase, strings.ToLower("_"+string(p.Protocol)), strings.ToLower("_"+string(p.Name))), "/")
+ c <- s.NewSRV(msg.Domain(s.Key), srvWeight)
+ }
+ }
+ }
+ }
+
+ case api.ServiceTypeExternalName:
+
+ s := msg.Service{Key: strings.Join(svcBase, "/"), Host: svc.Spec.ExternalName, TTL: k.ttl}
+ if t, _ := s.HostType(); t == dns.TypeCNAME {
+ c <- s.NewCNAME(msg.Domain(s.Key), s.Host)
+ }
+ }
+ }
+ return
+}
+
+// emitAddressRecord generates a new A or AAAA record based on the msg.Service and writes it to
+// a channel.
+// emitAddressRecord returns the host name from the generated record.
+func emitAddressRecord(c chan dns.RR, message msg.Service) string {
+ ip := net.ParseIP(message.Host)
+ var host string
+ dnsType, _ := message.HostType()
+ switch dnsType {
+ case dns.TypeA:
+ arec := message.NewA(msg.Domain(message.Key), ip)
+ host = arec.Hdr.Name
+ c <- arec
+ case dns.TypeAAAA:
+ arec := message.NewAAAA(msg.Domain(message.Key), ip)
+ host = arec.Hdr.Name
+ c <- arec
+ }
+
+ return host
+}
+
+// calcSrvWeight borrows the logic implemented in plugin.SRV for dynamically
+// calculating the srv weight and priority
+func calcSRVWeight(numservices int) uint16 {
+ var services []msg.Service
+
+ for i := 0; i < numservices; i++ {
+ services = append(services, msg.Service{})
+ }
+
+ w := make(map[int]int)
+ for _, serv := range services {
+ weight := 100
+ if serv.Weight != 0 {
+ weight = serv.Weight
+ }
+ if _, ok := w[serv.Priority]; !ok {
+ w[serv.Priority] = weight
+ continue
+ }
+ w[serv.Priority] += weight
+ }
+
+ return uint16(math.Floor((100.0 / float64(w[0])) * 100))
}
diff --git a/plugin/kubernetes/xfr_test.go b/plugin/kubernetes/xfr_test.go
new file mode 100644
index 000000000..81be775dc
--- /dev/null
+++ b/plugin/kubernetes/xfr_test.go
@@ -0,0 +1,111 @@
+package kubernetes
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/test"
+ "golang.org/x/net/context"
+
+ "github.com/miekg/dns"
+)
+
+func TestKubernetesXFR(t *testing.T) {
+ k := New([]string{"cluster.local."})
+ k.APIConn = &APIConnServeTest{}
+ k.TransferTo = []string{"127.0.0.1"}
+
+ ctx := context.TODO()
+ w := dnstest.NewMultiRecorder(&test.ResponseWriter{})
+ dnsmsg := &dns.Msg{}
+ dnsmsg.SetAxfr(k.Zones[0])
+
+ _, err := k.ServeDNS(ctx, w, dnsmsg)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if len(w.Msgs) == 0 {
+ t.Logf("%+v\n", w)
+ t.Error("Did not get back a zone response")
+ }
+
+ // Ensure xfr starts with SOA
+ if w.Msgs[0].Answer[0].Header().Rrtype != dns.TypeSOA {
+ t.Error("Invalid XFR, does not start with SOA record")
+ }
+
+ // Ensure xfr starts with SOA
+ // Last message is empty, so we need to go back one further
+ if w.Msgs[len(w.Msgs)-2].Answer[len(w.Msgs[len(w.Msgs)-2].Answer)-1].Header().Rrtype != dns.TypeSOA {
+ t.Error("Invalid XFR, does not end with SOA record")
+ }
+
+ testRRs := []dns.RR{}
+ for _, tc := range dnsTestCases {
+ if tc.Rcode != dns.RcodeSuccess {
+ continue
+ }
+
+ for _, ans := range tc.Answer {
+ // Exclude wildcard searches
+ if strings.Contains(ans.Header().Name, "*") {
+ continue
+ }
+
+ // Exclude TXT records
+ if ans.Header().Rrtype == dns.TypeTXT {
+ continue
+ }
+ testRRs = append(testRRs, ans)
+ }
+ }
+
+ gotRRs := []dns.RR{}
+ for _, resp := range w.Msgs {
+ for _, ans := range resp.Answer {
+ // Skip SOA records since these
+ // test cases do not exist
+ if ans.Header().Rrtype == dns.TypeSOA {
+ continue
+ }
+
+ gotRRs = append(gotRRs, ans)
+ }
+
+ }
+
+ diff := difference(testRRs, gotRRs)
+ if len(diff) != 0 {
+ t.Errorf("Got back %d records that do not exist in test cases, should be 0:", len(diff))
+ for _, rec := range diff {
+ t.Errorf("%+v", rec)
+ }
+ }
+
+ diff = difference(gotRRs, testRRs)
+ if len(diff) != 0 {
+ t.Errorf("Found %d records we're missing tham test cases, should be 0:", len(diff))
+ for _, rec := range diff {
+ t.Errorf("%+v", rec)
+ }
+ }
+
+}
+
+// difference shows what we're missing when comparing two RR slices
+func difference(testRRs []dns.RR, gotRRs []dns.RR) []dns.RR {
+ expectedRRs := map[string]bool{}
+ for _, rr := range testRRs {
+ expectedRRs[rr.String()] = true
+ }
+
+ foundRRs := []dns.RR{}
+ for _, rr := range gotRRs {
+ if _, ok := expectedRRs[rr.String()]; !ok {
+ foundRRs = append(foundRRs, rr)
+ }
+ }
+ return foundRRs
+}