aboutsummaryrefslogtreecommitdiff
path: root/plugin/rewrite/reverter.go
blob: 570b7d39e8770195efcfdc6f05eeb6c9f24c6882 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
package rewrite

import (
	"github.com/miekg/dns"
	"regexp"
	"strconv"
	"strings"
)

// ResponseRule contains a rule to rewrite a response with.
//
// ResponseReverter.WriteMsg applies the rules to every record of the answer
// section, in slice order, when ResponseRewrite is enabled. Fields:
//   - Type selects the rule kind. "name" (also used when Type is empty)
//     rewrites the owner name of an answer record. "ttl" overwrites the
//     record's TTL. Any other value is ignored.
//   - Pattern is matched against the record's owner name for "name" rules.
//   - Replacement is the new owner name for "name" rules. The placeholders
//     {0}, {1}, ... are substituted with the whole match and the capture
//     groups of Pattern.
//   - Ttl is the TTL value written by "ttl" rules.
//   - Active is not read by ResponseReverter.
//     NOTE(review): presumably set by the rule parser; confirm its meaning there.
type ResponseRule struct {
	Active      bool
	Type        string
	Pattern     *regexp.Regexp
	Replacement string
	Ttl         uint32
}

// ResponseReverter reverses the operations done on the question section of a packet.
// This is needed because the client will otherwise disregard the response, e.g.
// dig will complain with ';; Question section mismatch: got example.org/HINFO/IN'
//
// It wraps a dns.ResponseWriter. originalQuestion holds the first question of
// the request it was created for (see NewResponseReverter). WriteMsg restores
// that question as the response's first question. When ResponseRewrite is
// true, it also applies ResponseRules to the answer records before delegating
// to the wrapped writer.
type ResponseReverter struct {
	dns.ResponseWriter
	originalQuestion dns.Question
	ResponseRewrite  bool
	ResponseRules    []ResponseRule
}

// NewResponseReverter wraps w in a ResponseReverter that remembers the question
// of the request r, so it can be put back into the reply later.
//
// Only r.Question[0] is saved. r must carry at least one question, otherwise
// this function panics. Response rewriting is disabled on the returned value
// until the caller sets ResponseRewrite and ResponseRules.
func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter {
	saved := r.Question[0]
	rev := &ResponseReverter{originalQuestion: saved}
	rev.ResponseWriter = w
	return rev
}

// WriteMsg restores the original question in res and passes res to the wrapped
// ResponseWriter's WriteMsg, returning that call's error. When ResponseRewrite
// is set, it first rewrites the answer records according to ResponseRules.
//
// Rules are applied in order to every answer record, so each "name" rule sees
// the owner name produced by the rules before it:
//   - "name" (or an empty Type): if Pattern matches the owner name, the name
//     is replaced by Replacement. Placeholders {0}, {1}, ... expand to the
//     whole match and the capture groups.
//   - "ttl": the record's TTL is set to Ttl.
//   - any other Type is ignored.
//
// res is modified in place.
func (r *ResponseReverter) WriteMsg(res *dns.Msg) error {
	// A response may have an empty question section; indexing it
	// unconditionally would panic.
	if len(res.Question) > 0 {
		res.Question[0] = r.originalQuestion
	}
	if !r.ResponseRewrite {
		return r.ResponseWriter.WriteMsg(res)
	}
	for _, rr := range res.Answer {
		hdr := rr.Header()
		for _, rule := range r.ResponseRules {
			switch rule.Type {
			case "", "name":
				// A name rule without a compiled pattern cannot match anything.
				if rule.Pattern == nil {
					continue
				}
				groups := rule.Pattern.FindStringSubmatch(hdr.Name)
				if len(groups) == 0 {
					continue
				}
				// Expand every {N} placeholder in a single pass. This way text
				// copied from the matched name is never re-scanned for
				// placeholders.
				pairs := make([]string, 0, 2*len(groups))
				for i, g := range groups {
					pairs = append(pairs, "{"+strconv.Itoa(i)+"}", g)
				}
				hdr.Name = strings.NewReplacer(pairs...).Replace(rule.Replacement)
			case "ttl":
				hdr.Ttl = rule.Ttl
			}
		}
	}
	return r.ResponseWriter.WriteMsg(res)
}

// Write forwards buf unchanged to the wrapped ResponseWriter and returns its
// byte count and error. Unlike WriteMsg, raw writes get no question
// restoration and no answer rewriting.
func (r *ResponseReverter) Write(buf []byte) (int, error) {
	return r.ResponseWriter.Write(buf)
}