1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
|
package rewrite
import (
"github.com/miekg/dns"
)
// RevertPolicy controls the overall reverting process.
// It exposes two independent switches that callers query before touching a
// response.
type RevertPolicy interface {
// DoRevert reports whether reverting is enabled under this policy.
DoRevert() bool
// DoQuestionRestore reports whether the original question should be written
// back into the response; ResponseReverter.WriteMsg consults it for that
// purpose.
DoQuestionRestore() bool
}
// revertPolicy is the package's RevertPolicy implementation. Both flags are
// stored negated, so the zero value enables reverting and question restoring.
//
// The field order (noRevert, noRestore) must be preserved: positional
// composite literals of this type depend on it.
type revertPolicy struct {
	noRevert  bool
	noRestore bool
}

// DoRevert reports whether reverting is enabled, i.e. noRevert is unset.
func (rp revertPolicy) DoRevert() bool {
	disabled := rp.noRevert
	return !disabled
}

// DoQuestionRestore reports whether question restoring is enabled, i.e.
// noRestore is unset.
func (rp revertPolicy) DoQuestionRestore() bool {
	disabled := rp.noRestore
	return !disabled
}
// NoRevertPolicy returns a policy with reverting switched off: its DoRevert
// reports false while DoQuestionRestore still reports true. It is meant to
// disable all response rewrite rules.
func NoRevertPolicy() RevertPolicy {
	return revertPolicy{noRevert: true}
}
// NoRestorePolicy returns a policy that disables the question restoration
// during the response rewrite: its DoQuestionRestore reports false while
// DoRevert still reports true.
func NoRestorePolicy() RevertPolicy {
	return revertPolicy{noRestore: true}
}
// NewRevertPolicy creates a new reverter policy by dynamically specifying all
// options. Both arguments are negative switches: passing true for noRevert
// makes DoRevert report false, and passing true for noRestore makes
// DoQuestionRestore report false.
func NewRevertPolicy(noRevert, noRestore bool) RevertPolicy {
	var policy revertPolicy
	policy.noRevert = noRevert
	policy.noRestore = noRestore
	return policy
}
// ResponseRule contains a rule to rewrite a response with.
type ResponseRule interface {
// RewriteResponse is called with the response message res and one of its
// resource records rr. ResponseReverter invokes it once per rule for every
// record in the authority, answer and additional sections.
RewriteResponse(res *dns.Msg, rr dns.RR)
}
// ResponseRules describes an ordered list of response rules to apply
// after a name rewrite. It is a type alias, so it is interchangeable with
// []ResponseRule; rules are applied in slice order.
type ResponseRules = []ResponseRule
// ResponseReverter reverses the operations done on the question section of a packet.
// This is needed because the client would otherwise disregard the response, i.e.
// dig will complain with ';; Question section mismatch: got example.org/HINFO/IN'
type ResponseReverter struct {
// ResponseWriter is the wrapped writer that receives the reverted message.
dns.ResponseWriter
// originalQuestion is the first question of the request, captured by
// NewResponseReverter.
originalQuestion dns.Question
// ResponseRules are applied, in order, to every resource record of the
// response. NewResponseReverter leaves this unset.
ResponseRules ResponseRules
// revertPolicy decides whether WriteMsg restores the original question.
revertPolicy RevertPolicy
}
// NewResponseReverter returns a pointer to a new ResponseReverter that wraps
// w, remembers the first question of r and uses policy for its revert
// decisions. ResponseRules is left empty for the caller to fill in.
//
// r must contain at least one question: the first one is read without a
// length check.
func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg, policy RevertPolicy) *ResponseReverter {
	reverter := new(ResponseReverter)
	reverter.ResponseWriter = w
	reverter.originalQuestion = r.Question[0]
	reverter.revertPolicy = policy
	return reverter
}
// WriteMsg reverts a copy of res1 and hands it to the underlying
// ResponseWriter's WriteMsg method, returning that method's error.
//
// res1 itself is never modified. On the copy, the first question is replaced
// by the original one when the revert policy asks for question restoration,
// and, if any ResponseRules are configured, every rule is applied to each
// record of the authority, answer and additional sections (in that order).
func (r *ResponseReverter) WriteMsg(res1 *dns.Msg) error {
	// Deep copy 'res' so as not to (e.g.) rewrite a message that's also stored in the cache.
	res := res1.Copy()

	// A response may carry an empty question section; indexing it
	// unconditionally would panic, so only restore when there is a slot.
	if r.revertPolicy.DoQuestionRestore() && len(res.Question) > 0 {
		res.Question[0] = r.originalQuestion
	}

	if len(r.ResponseRules) > 0 {
		for _, rr := range res.Ns {
			r.rewriteResourceRecord(res, rr)
		}
		for _, rr := range res.Answer {
			r.rewriteResourceRecord(res, rr)
		}
		for _, rr := range res.Extra {
			r.rewriteResourceRecord(res, rr)
		}
	}

	return r.ResponseWriter.WriteMsg(res)
}
// rewriteResourceRecord runs every configured response rule, in order, against
// the single resource record rr of the response res.
func (r *ResponseReverter) rewriteResourceRecord(res *dns.Msg, rr dns.RR) {
	rules := r.ResponseRules
	for i := range rules {
		rules[i].RewriteResponse(res, rr)
	}
}
// Write is a thin wrapper that forwards buf unchanged to the underlying
// ResponseWriter and returns its byte count and error as-is.
func (r *ResponseReverter) Write(buf []byte) (int, error) {
	return r.ResponseWriter.Write(buf)
}
// getRecordValueForRewrite returns the name-valued field of rr that is subject
// to rewriting, selected by the record type in rr's header:
//
//	SRV, CNAME, DNAME -> Target
//	MX                -> Mx
//	NS, SOA           -> Ns
//	NAPTR             -> Replacement
//	PTR               -> Ptr
//
// Any other record type yields the empty string. The concrete type of rr must
// match its header type; the type assertion is unchecked and panics otherwise.
func getRecordValueForRewrite(rr dns.RR) (name string) {
	switch rrtype := rr.Header().Rrtype; rrtype {
	case dns.TypeCNAME:
		name = rr.(*dns.CNAME).Target
	case dns.TypeDNAME:
		name = rr.(*dns.DNAME).Target
	case dns.TypeMX:
		name = rr.(*dns.MX).Mx
	case dns.TypeNAPTR:
		name = rr.(*dns.NAPTR).Replacement
	case dns.TypeNS:
		name = rr.(*dns.NS).Ns
	case dns.TypePTR:
		name = rr.(*dns.PTR).Ptr
	case dns.TypeSOA:
		name = rr.(*dns.SOA).Ns
	case dns.TypeSRV:
		name = rr.(*dns.SRV).Target
	}
	// Unsupported types fall through with name still empty.
	return name
}
// setRewrittenRecordValue stores value into the name-valued field of rr that
// getRecordValueForRewrite reads, selected by the record type in rr's header:
//
//	SRV, CNAME, DNAME -> Target
//	MX                -> Mx
//	NS, SOA           -> Ns
//	NAPTR             -> Replacement
//	PTR               -> Ptr
//
// Records of any other type are left untouched. The record is modified in
// place. The concrete type of rr must match its header type; the type
// assertion is unchecked and panics otherwise.
func setRewrittenRecordValue(rr dns.RR, value string) {
	rrtype := rr.Header().Rrtype
	switch rrtype {
	case dns.TypeCNAME:
		rec := rr.(*dns.CNAME)
		rec.Target = value
	case dns.TypeDNAME:
		rec := rr.(*dns.DNAME)
		rec.Target = value
	case dns.TypeMX:
		rec := rr.(*dns.MX)
		rec.Mx = value
	case dns.TypeNAPTR:
		rec := rr.(*dns.NAPTR)
		rec.Replacement = value
	case dns.TypeNS:
		rec := rr.(*dns.NS)
		rec.Ns = value
	case dns.TypePTR:
		rec := rr.(*dns.PTR)
		rec.Ptr = value
	case dns.TypeSOA:
		rec := rr.(*dns.SOA)
		rec.Ns = value
	case dns.TypeSRV:
		rec := rr.(*dns.SRV)
		rec.Target = value
	}
}
|