diff options
Diffstat (limited to 'plugin/rewrite/reverter.go')
-rw-r--r-- | plugin/rewrite/reverter.go | 135 |
1 files changed, 59 insertions, 76 deletions
diff --git a/plugin/rewrite/reverter.go b/plugin/rewrite/reverter.go index 49222ddfc..7d83e557b 100644 --- a/plugin/rewrite/reverter.go +++ b/plugin/rewrite/reverter.go @@ -1,37 +1,69 @@ package rewrite import ( - "regexp" - "strconv" - "strings" - "github.com/miekg/dns" ) +// RevertPolicy controls the overall reverting process +type RevertPolicy interface { + DoRevert() bool + DoQuestionRestore() bool +} + +type revertPolicy struct { + noRevert bool + noRestore bool +} + +func (p revertPolicy) DoRevert() bool { + return !p.noRevert +} + +func (p revertPolicy) DoQuestionRestore() bool { + return !p.noRestore +} + +// NoRevertPolicy disables all response rewrite rules +func NoRevertPolicy() RevertPolicy { + return revertPolicy{true, false} +} + +// NoRestorePolicy disables the question restoration during the response rewrite +func NoRestorePolicy() RevertPolicy { + return revertPolicy{false, true} +} + +// NewRevertPolicy creates a new reverter policy by dynamically specifying all +// options. +func NewRevertPolicy(noRevert, noRestore bool) RevertPolicy { + return revertPolicy{noRestore: noRestore, noRevert: noRevert} +} + // ResponseRule contains a rule to rewrite a response with. -type ResponseRule struct { - Active bool - Type string - Pattern *regexp.Regexp - Replacement string - TTL uint32 +type ResponseRule interface { + RewriteResponse(rr dns.RR) } +// ResponseRules describes an ordered list of response rules to apply +// after a name rewrite +type ResponseRules = []ResponseRule + // ResponseReverter reverses the operations done on the question section of a packet. // This is need because the client will otherwise disregards the response, i.e. // dig will complain with ';; Question section mismatch: got example.org/HINFO/IN' type ResponseReverter struct { dns.ResponseWriter originalQuestion dns.Question - ResponseRewrite bool - ResponseRules []ResponseRule + ResponseRules ResponseRules + revertPolicy RevertPolicy } // NewResponseReverter returns a pointer to a new ResponseReverter. -func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter { +func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg, policy RevertPolicy) *ResponseReverter { return &ResponseReverter{ ResponseWriter: w, originalQuestion: r.Question[0], + revertPolicy: policy, } } @@ -40,61 +72,33 @@ func (r *ResponseReverter) WriteMsg(res1 *dns.Msg) error { // Deep copy 'res' as to not (e.g). rewrite a message that's also stored in the cache. res := res1.Copy() - res.Question[0] = r.originalQuestion - if r.ResponseRewrite { + if r.revertPolicy.DoQuestionRestore() { + res.Question[0] = r.originalQuestion + } + if len(r.ResponseRules) > 0 { for _, rr := range res.Ns { - rewriteResourceRecord(res, rr, r) + r.rewriteResourceRecord(res, rr) } - for _, rr := range res.Answer { - rewriteResourceRecord(res, rr, r) + r.rewriteResourceRecord(res, rr) } - for _, rr := range res.Extra { - rewriteResourceRecord(res, rr, r) + r.rewriteResourceRecord(res, rr) } - } return r.ResponseWriter.WriteMsg(res) } -func rewriteResourceRecord(res *dns.Msg, rr dns.RR, r *ResponseReverter) { - var ( - isNameRewritten bool - isTTLRewritten bool - isValueRewritten bool - name = rr.Header().Name - ttl = rr.Header().Ttl - value string - ) - +func (r *ResponseReverter) rewriteResourceRecord(res *dns.Msg, rr dns.RR) { for _, rule := range r.ResponseRules { - if rule.Type == "" { - rule.Type = "name" - } - switch rule.Type { - case "name": - rewriteString(rule, &name, &isNameRewritten) - case "value": - value = getRecordValueForRewrite(rr) - if value != "" { - rewriteString(rule, &value, &isValueRewritten) - } - case "ttl": - ttl = rule.TTL - isTTLRewritten = true - } + rule.RewriteResponse(rr) } +} - if isNameRewritten { - rr.Header().Name = name - } - if isTTLRewritten { - rr.Header().Ttl = ttl - } - if isValueRewritten { - setRewrittenRecordValue(rr, value) - } +// Write is a wrapper that records the size of the message that gets written. +func (r *ResponseReverter) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + return n, err } func getRecordValueForRewrite(rr dns.RR) (name string) { @@ -136,24 +140,3 @@ func setRewrittenRecordValue(rr dns.RR, value string) { rr.(*dns.SOA).Ns = value } } - -func rewriteString(rule ResponseRule, str *string, isStringRewritten *bool) { - regexGroups := rule.Pattern.FindStringSubmatch(*str) - if len(regexGroups) == 0 { - return - } - s := rule.Replacement - for groupIndex, groupValue := range regexGroups { - groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}" - s = strings.Replace(s, groupIndexStr, groupValue, -1) - } - - *isStringRewritten = true - *str = s -} - -// Write is a wrapper that records the size of the message that gets written. -func (r *ResponseReverter) Write(buf []byte) (int, error) { - n, err := r.ResponseWriter.Write(buf) - return n, err -} |