aboutsummaryrefslogtreecommitdiff
path: root/plugin/tsig/tsig.go
blob: 6441c8a6b75ff730a83038b08139b6953e0d90b6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package tsig

import (
	"context"
	"encoding/binary"
	"encoding/hex"
	"time"

	"github.com/coredns/coredns/plugin"
	"github.com/coredns/coredns/plugin/pkg/log"
	"github.com/coredns/coredns/request"

	"github.com/miekg/dns"
)

// TSIGServer is a plugin handler that verifies the TSIG status of requests
// and adds a TSIG RR to the responses it writes.
type TSIGServer struct {
	// Zones lists the zones this plugin acts on. ServeDNS passes requests
	// whose name matches none of them straight to Next.
	Zones   []string
	secrets map[string]string // [key-name]secret
	// types is the set of query types for which a TSIG is required
	// (consulted by tsigRequired).
	types   qTypes
	// all, when true, makes a TSIG required for every query type.
	all     bool
	// Next is the next handler in the plugin chain.
	Next    plugin.Handler
}

// qTypes is a set of DNS query types, keyed by the numeric type code.
type qTypes map[uint16]struct{}

// Name reports the name of this plugin. It implements plugin.Handler.
func (t TSIGServer) Name() string {
	return pluginName
}

// ServeDNS implements plugin.Handler.
//
// Requests are handled as follows:
//   - The query name matches none of t.Zones: passed to the next plugin
//     untouched.
//   - No TSIG RR and the query type does not require one: passed to the next
//     plugin untouched.
//   - No TSIG RR but the query type requires one: answered with REFUSED.
//   - TSIG RR present but the response writer reports a TSIG error: answered
//     with NOTAUTH, with the matching error code recorded on the TSIG RR.
//   - Otherwise the TSIG RR is stripped from the request and the rest of the
//     plugin chain is run.
//
// Once this plugin has taken charge of a request, every response goes through
// a restoreTsigWriter and the method returns dns.RcodeSuccess with a nil
// error, because the response has already been written here.
func (t *TSIGServer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := request.Request{Req: r, W: w}
	if z := plugin.Zones(t.Zones).Matches(state.Name()); z == "" {
		return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
	}

	tsigRR := r.IsTsig()
	if tsigRR == nil && !t.tsigRequired(state.QType()) {
		return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
	}

	rcode := dns.RcodeSuccess
	if tsigRR == nil {
		log.Debugf("rejecting '%s' request without TSIG\n", dns.TypeToString[state.QType()])
		rcode = dns.RcodeRefused
	}

	// wrap the response writer so the response will be TSIG signed.
	w = &restoreTsigWriter{w, r, tsigRR}

	// Only act on a TSIG error when there is a TSIG RR to record it on; this
	// avoids dereferencing a nil tsigRR should a writer report a status for
	// an unsigned request.
	if tsigStatus := w.TsigStatus(); tsigStatus != nil && tsigRR != nil {
		log.Debugf("TSIG validation failed: %v %v", dns.TypeToString[state.QType()], tsigStatus)
		switch tsigStatus {
		case dns.ErrSecret:
			tsigRR.Error = dns.RcodeBadKey
		case dns.ErrTime:
			tsigRR.Error = dns.RcodeBadTime
		default:
			tsigRR.Error = dns.RcodeBadSig
		}
		resp := new(dns.Msg).SetRcode(r, dns.RcodeNotAuth)
		if werr := w.WriteMsg(resp); werr != nil {
			log.Debugf("failed to write TSIG error response: %v", werr)
		}
		return dns.RcodeSuccess, nil
	}

	// strip the TSIG RR. Next, and subsequent plugins will not see the TSIG RRs.
	// This violates forwarding cases (RFC 8945 5.5). See README.md Bugs
	//
	// Only strip when a TSIG RR is actually present (IsTsig only reports one
	// when it is the last record of the additional section). Stripping
	// unconditionally would drop an unrelated record, typically the EDNS OPT
	// RR, from unsigned requests, and the REFUSED response would then lack
	// its OPT RR.
	if tsigRR != nil {
		r.Extra = r.Extra[:len(r.Extra)-1]
	}

	if rcode == dns.RcodeSuccess {
		var err error
		rcode, err = plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
		if err != nil {
			log.Errorf("request handler returned an error: %v\n", err)
		}
	}
	// If the rcode indicates that nothing was written to the client yet,
	// write a response carrying that rcode (with the TSIG restored by w).
	if !plugin.ClientWrite(rcode) {
		resp := new(dns.Msg).SetRcode(r, rcode)
		if werr := w.WriteMsg(resp); werr != nil {
			log.Debugf("failed to write response: %v", werr)
		}
	}
	return dns.RcodeSuccess, nil
}

// tsigRequired reports whether a request of the given query type must carry a
// TSIG: either every type requires one (t.all) or qtype is listed in t.types.
func (t *TSIGServer) tsigRequired(qtype uint16) bool {
	_, listed := t.types[qtype]
	return t.all || listed
}

// restoreTsigWriter wraps a dns.ResponseWriter and, in WriteMsg, adds a TSIG RR
// to the response when the original request carried one.
type restoreTsigWriter struct {
	dns.ResponseWriter
	req     *dns.Msg  // original request excluding TSIG if it has one
	reqTSIG *dns.TSIG // original TSIG; nil when the request had none
}

// WriteMsg appends a TSIG RR to m when the original request carried one and m
// does not already have one, then hands m to the wrapped ResponseWriter.
func (r *restoreTsigWriter) WriteMsg(m *dns.Msg) error {
	// The EDNS OPT RR must be in place before the TSIG RR is appended: were
	// ScrubWriter to add it *after* TSIG, the DNS message would be non-compliant.
	origReq := request.Request{Req: r.req, W: r.ResponseWriter}
	origReq.SizeAndDo(m)

	// Nothing to restore: unsigned request, or the response is already signed.
	if r.reqTSIG == nil || m.IsTsig() != nil {
		return r.ResponseWriter.WriteMsg(m)
	}

	sig := &dns.TSIG{
		Hdr:       dns.RR_Header{Name: r.reqTSIG.Hdr.Name, Rrtype: dns.TypeTSIG, Class: dns.ClassANY},
		Algorithm: r.reqTSIG.Algorithm,
		OrigId:    m.MsgHdr.Id,
		Error:     r.reqTSIG.Error,
		MAC:       r.reqTSIG.MAC,
		MACSize:   r.reqTSIG.MACSize,
	}

	if sig.Error == dns.RcodeBadTime {
		// RFC 8945 5.2.3: TimeSigned echoes the client's time, while the
		// server's time is carried in OtherData with OtherLen = 6.
		sig.TimeSigned = r.reqTSIG.TimeSigned

		// Encode the current Unix time in network byte order and keep only
		// the 48 least significant bits, i.e. the 6 rightmost bytes.
		var now [8]byte
		binary.BigEndian.PutUint64(now[:], uint64(time.Now().Unix()))
		sig.OtherData = hex.EncodeToString(now[2:])
		sig.OtherLen = 6
	}

	m.Extra = append(m.Extra, sig)
	return r.ResponseWriter.WriteMsg(m)
}

// pluginName is the name this plugin reports through TSIGServer.Name.
const pluginName = "tsig"