aboutsummaryrefslogtreecommitdiff
path: root/plugin/tsig/tsig_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/tsig/tsig_test.go')
-rw-r--r--plugin/tsig/tsig_test.go255
1 files changed, 255 insertions, 0 deletions
diff --git a/plugin/tsig/tsig_test.go b/plugin/tsig/tsig_test.go
new file mode 100644
index 000000000..f7ec1fdf1
--- /dev/null
+++ b/plugin/tsig/tsig_test.go
@@ -0,0 +1,255 @@
+package tsig
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/test"
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+func TestServeDNS(t *testing.T) {
+ cases := []struct {
+ zones []string
+ reqTypes qTypes
+ qType uint16
+ qTsig, all bool
+ expectRcode int
+ expectTsig bool
+ statusError bool
+ }{
+ {
+ zones: []string{"."},
+ all: true,
+ qType: dns.TypeA,
+ qTsig: true,
+ expectRcode: dns.RcodeSuccess,
+ expectTsig: true,
+ },
+ {
+ zones: []string{"."},
+ all: true,
+ qType: dns.TypeA,
+ qTsig: false,
+ expectRcode: dns.RcodeRefused,
+ expectTsig: false,
+ },
+ {
+ zones: []string{"another.domain."},
+ all: true,
+ qType: dns.TypeA,
+ qTsig: false,
+ expectRcode: dns.RcodeSuccess,
+ expectTsig: false,
+ },
+ {
+ zones: []string{"another.domain."},
+ all: true,
+ qType: dns.TypeA,
+ qTsig: true,
+ expectRcode: dns.RcodeSuccess,
+ expectTsig: false,
+ },
+ {
+ zones: []string{"."},
+ reqTypes: qTypes{dns.TypeAXFR: {}},
+ qType: dns.TypeAXFR,
+ qTsig: true,
+ expectRcode: dns.RcodeSuccess,
+ expectTsig: true,
+ },
+ {
+ zones: []string{"."},
+ reqTypes: qTypes{},
+ qType: dns.TypeA,
+ qTsig: false,
+ expectRcode: dns.RcodeSuccess,
+ expectTsig: false,
+ },
+ {
+ zones: []string{"."},
+ reqTypes: qTypes{},
+ qType: dns.TypeA,
+ qTsig: true,
+ expectRcode: dns.RcodeSuccess,
+ expectTsig: true,
+ },
+ {
+ zones: []string{"."},
+ all: true,
+ qType: dns.TypeA,
+ qTsig: true,
+ expectRcode: dns.RcodeNotAuth,
+ expectTsig: true,
+ statusError: true,
+ },
+ }
+
+ for i, tc := range cases {
+ t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+ tsig := TSIGServer{
+ Zones: tc.zones,
+ all: tc.all,
+ types: tc.reqTypes,
+ Next: testHandler(),
+ }
+
+ ctx := context.TODO()
+
+ var w *dnstest.Recorder
+ if tc.statusError {
+ w = dnstest.NewRecorder(&ErrWriter{err: dns.ErrSig})
+ } else {
+ w = dnstest.NewRecorder(&test.ResponseWriter{})
+ }
+ r := new(dns.Msg)
+ r.SetQuestion("test.example.", tc.qType)
+ if tc.qTsig {
+ r.SetTsig("test.key.", dns.HmacSHA256, 300, time.Now().Unix())
+ }
+
+ _, err := tsig.ServeDNS(ctx, w, r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if w.Msg.Rcode != tc.expectRcode {
+ t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode)
+ }
+
+ if ts := w.Msg.IsTsig(); ts == nil && tc.expectTsig {
+ t.Fatal("expected TSIG in response")
+ }
+ if ts := w.Msg.IsTsig(); ts != nil && !tc.expectTsig {
+ t.Fatal("expected no TSIG in response")
+ }
+ })
+ }
+}
+
+func TestServeDNSTsigErrors(t *testing.T) {
+ clientNow := time.Now().Unix()
+
+ cases := []struct {
+ desc string
+ tsigErr error
+ expectRcode int
+ expectError int
+ expectOtherLength int
+ expectTimeSigned int64
+ }{
+ {
+ desc: "Unknown Key",
+ tsigErr: dns.ErrSecret,
+ expectRcode: dns.RcodeNotAuth,
+ expectError: dns.RcodeBadKey,
+ expectOtherLength: 0,
+ expectTimeSigned: 0,
+ },
+ {
+ desc: "Bad Signature",
+ tsigErr: dns.ErrSig,
+ expectRcode: dns.RcodeNotAuth,
+ expectError: dns.RcodeBadSig,
+ expectOtherLength: 0,
+ expectTimeSigned: 0,
+ },
+ {
+ desc: "Bad Time",
+ tsigErr: dns.ErrTime,
+ expectRcode: dns.RcodeNotAuth,
+ expectError: dns.RcodeBadTime,
+ expectOtherLength: 6,
+ expectTimeSigned: clientNow,
+ },
+ }
+
+ tsig := TSIGServer{
+ Zones: []string{"."},
+ all: true,
+ Next: testHandler(),
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.desc, func(t *testing.T) {
+ ctx := context.TODO()
+
+ var w *dnstest.Recorder
+
+ w = dnstest.NewRecorder(&ErrWriter{err: tc.tsigErr})
+
+ r := new(dns.Msg)
+ r.SetQuestion("test.example.", dns.TypeA)
+ r.SetTsig("test.key.", dns.HmacSHA256, 300, clientNow)
+
+ // set a fake MAC and Size in request
+ rtsig := r.IsTsig()
+ rtsig.MAC = "0123456789012345678901234567890101234567890123456789012345678901"
+ rtsig.MACSize = 32
+
+ _, err := tsig.ServeDNS(ctx, w, r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if w.Msg.Rcode != tc.expectRcode {
+ t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode)
+ }
+
+ ts := w.Msg.IsTsig()
+
+ if ts == nil {
+ t.Fatal("expected TSIG in response")
+ }
+
+ if int(ts.Error) != tc.expectError {
+ t.Errorf("expected TSIG error code %v, got %v", tc.expectError, ts.Error)
+ }
+
+ if len(ts.OtherData)/2 != tc.expectOtherLength {
+ t.Errorf("expected Other of length %v, got %v", tc.expectOtherLength, len(ts.OtherData))
+ }
+
+ if int(ts.OtherLen) != tc.expectOtherLength {
+ t.Errorf("expected OtherLen %v, got %v", tc.expectOtherLength, ts.OtherLen)
+ }
+
+ if ts.TimeSigned != uint64(tc.expectTimeSigned) {
+ t.Errorf("expected TimeSigned to be %v, got %v", tc.expectTimeSigned, ts.TimeSigned)
+ }
+ })
+ }
+}
+
+func testHandler() test.HandlerFunc {
+ return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ state := request.Request{W: w, Req: r}
+ qname := state.Name()
+ m := new(dns.Msg)
+ rcode := dns.RcodeServerFailure
+ if qname == "test.example." {
+ m.SetReply(r)
+ rr := test.A("test.example. 300 IN A 1.2.3.48")
+ m.Answer = []dns.RR{rr}
+ m.Authoritative = true
+ rcode = dns.RcodeSuccess
+ }
+ m.SetRcode(r, rcode)
+ w.WriteMsg(m)
+ return rcode, nil
+ }
+}
+
+// a test.ResponseWriter that always returns err as the TSIG status error
+type ErrWriter struct {
+ err error
+ test.ResponseWriter
+}
+
+// TsigStatus always returns an error.
+func (t *ErrWriter) TsigStatus() error { return t.err }