diff options
Diffstat (limited to 'plugin/rewrite')
-rw-r--r-- | plugin/rewrite/README.md | 91 | ||||
-rw-r--r-- | plugin/rewrite/class.go | 35 | ||||
-rw-r--r-- | plugin/rewrite/condition.go | 132 | ||||
-rw-r--r-- | plugin/rewrite/condition_test.go | 102 | ||||
-rw-r--r-- | plugin/rewrite/edns0.go | 425 | ||||
-rw-r--r-- | plugin/rewrite/name.go | 24 | ||||
-rw-r--r-- | plugin/rewrite/reverter.go | 39 | ||||
-rw-r--r-- | plugin/rewrite/rewrite.go | 86 | ||||
-rw-r--r-- | plugin/rewrite/rewrite_test.go | 532 | ||||
-rw-r--r-- | plugin/rewrite/setup.go | 42 | ||||
-rw-r--r-- | plugin/rewrite/setup_test.go | 25 | ||||
-rw-r--r-- | plugin/rewrite/testdata/testdir/empty | 0 | ||||
-rw-r--r-- | plugin/rewrite/testdata/testfile | 1 | ||||
-rw-r--r-- | plugin/rewrite/type.go | 37 |
14 files changed, 1571 insertions, 0 deletions
diff --git a/plugin/rewrite/README.md b/plugin/rewrite/README.md new file mode 100644 index 000000000..63334d09c --- /dev/null +++ b/plugin/rewrite/README.md @@ -0,0 +1,91 @@ +# rewrite + +*rewrite* performs internal message rewriting. + +Rewrites are invisible to the client. There are simple rewrites (fast) and complex rewrites +(slower), but they're powerful enough to accommodate most dynamic back-end applications. + +## Syntax + +~~~ +rewrite FIELD FROM TO +~~~ + +* **FIELD** is (`type`, `class`, `name`, ...) +* **FROM** is the exact name of type to match +* **TO** is the destination name or type to rewrite to + +When the FIELD is `type` and FROM is (`A`, `MX`, etc.), the type of the message will be rewritten; +e.g., to rewrite ANY queries to HINFO, use `rewrite type ANY HINFO`. + +When the FIELD is `class` and FROM is (`IN`, `CH`, or `HS`) the class of the message will be +rewritten; e.g., to rewrite CH queries to IN use `rewrite class CH IN`. + +When the FIELD is `name` the query name in the message is rewritten; this +needs to be a full match of the name, e.g., `rewrite name miek.nl example.org`. + +When the FIELD is `edns0` an EDNS0 option can be appended to the request as described below. + +If you specify multiple rules and an incoming query matches on multiple (simple) rules, only +the first rewrite is applied. + +## EDNS0 Options + +Using FIELD edns0, you can set, append, or replace specific EDNS0 options on the request. + +* `replace` will modify any matching (what that means may vary based on EDNS0 type) option with the specified option +* `append` will add the option regardless of what options already exist +* `set` will modify a matching option or add one if none is found + +Currently supported are `EDNS0_LOCAL`, `EDNS0_NSID` and `EDNS0_SUBNET`. + +### `EDNS0_LOCAL` + +This has two fields, code and data. A match is defined as having the same code. Data may be a string or a variable. + +* A string data can be treated as hex if it starts with `0x`. Example: + +~~~ +rewrite edns0 local set 0xffee 0x61626364 +~~~ + +rewrites the first local option with code 0xffee, setting the data to "abcd". Equivalent: + +~~~ +rewrite edns0 local set 0xffee abcd +~~~ + +* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables: + * {qname} + * {qtype} + * {client_ip} + * {client_port} + * {protocol} + * {server_ip} + * {server_port} + +Example: + +~~~ +rewrite edns0 local set 0xffee {client_ip} +~~~ + +### `EDNS0_NSID` + +This has no fields; it will add an NSID option with an empty string for the NSID. If the option already exists +and the action is `replace` or `set`, then the NSID in the option will be set to the empty string. + +### `EDNS0_SUBNET` + +This has two fields, IPv4 bitmask length and IPv6 bitmask length. The bitmask +length is used to extract the client subnet from the source IP address in the query. + +Example: + +~~~ + rewrite edns0 subnet set 24 56 +~~~ + +* If the query has source IP as IPv4, the first 24 bits in the IP will be the network subnet. +* If the query has source IP as IPv6, the first 56 bits in the IP will be the network subnet. + diff --git a/plugin/rewrite/class.go b/plugin/rewrite/class.go new file mode 100644 index 000000000..8cc7d26b7 --- /dev/null +++ b/plugin/rewrite/class.go @@ -0,0 +1,35 @@ +package rewrite + +import ( + "fmt" + "strings" + + "github.com/miekg/dns" +) + +type classRule struct { + fromClass, toClass uint16 +} + +func newClassRule(fromS, toS string) (Rule, error) { + var from, to uint16 + var ok bool + if from, ok = dns.StringToClass[strings.ToUpper(fromS)]; !ok { + return nil, fmt.Errorf("invalid class %q", strings.ToUpper(fromS)) + } + if to, ok = dns.StringToClass[strings.ToUpper(toS)]; !ok { + return nil, fmt.Errorf("invalid class %q", strings.ToUpper(toS)) + } + return &classRule{fromClass: from, toClass: to}, nil +} + +// Rewrite rewrites the the current request. +func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + if rule.fromClass > 0 && rule.toClass > 0 { + if r.Question[0].Qclass == rule.fromClass { + r.Question[0].Qclass = rule.toClass + return RewriteDone + } + } + return RewriteIgnored +} diff --git a/plugin/rewrite/condition.go b/plugin/rewrite/condition.go new file mode 100644 index 000000000..2f20d71aa --- /dev/null +++ b/plugin/rewrite/condition.go @@ -0,0 +1,132 @@ +package rewrite + +import ( + "fmt" + "regexp" + "strings" + + "github.com/coredns/coredns/plugin/pkg/replacer" + + "github.com/miekg/dns" +) + +// Operators +const ( + Is = "is" + Not = "not" + Has = "has" + NotHas = "not_has" + StartsWith = "starts_with" + EndsWith = "ends_with" + Match = "match" + NotMatch = "not_match" +) + +func operatorError(operator string) error { + return fmt.Errorf("invalid operator %v", operator) +} + +func newReplacer(r *dns.Msg) replacer.Replacer { + return replacer.New(r, nil, "") +} + +// condition is a rewrite condition. +type condition func(string, string) bool + +var conditions = map[string]condition{ + Is: isFunc, + Not: notFunc, + Has: hasFunc, + NotHas: notHasFunc, + StartsWith: startsWithFunc, + EndsWith: endsWithFunc, + Match: matchFunc, + NotMatch: notMatchFunc, +} + +// isFunc is condition for Is operator. +// It checks for equality. +func isFunc(a, b string) bool { + return a == b +} + +// notFunc is condition for Not operator. +// It checks for inequality. +func notFunc(a, b string) bool { + return a != b +} + +// hasFunc is condition for Has operator. +// It checks if b is a substring of a. +func hasFunc(a, b string) bool { + return strings.Contains(a, b) +} + +// notHasFunc is condition for NotHas operator. +// It checks if b is not a substring of a. +func notHasFunc(a, b string) bool { + return !strings.Contains(a, b) +} + +// startsWithFunc is condition for StartsWith operator. +// It checks if b is a prefix of a. +func startsWithFunc(a, b string) bool { + return strings.HasPrefix(a, b) +} + +// endsWithFunc is condition for EndsWith operator. +// It checks if b is a suffix of a. +func endsWithFunc(a, b string) bool { + // TODO(miek): IsSubDomain + return strings.HasSuffix(a, b) +} + +// matchFunc is condition for Match operator. +// It does regexp matching of a against pattern in b +// and returns if they match. +func matchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return matched +} + +// notMatchFunc is condition for NotMatch operator. +// It does regexp matching of a against pattern in b +// and returns if they do not match. +func notMatchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return !matched +} + +// If is statement for a rewrite condition. +type If struct { + A string + Operator string + B string +} + +// True returns true if the condition is true and false otherwise. +// If r is not nil, it replaces placeholders before comparison. +func (i If) True(r *dns.Msg) bool { + if c, ok := conditions[i.Operator]; ok { + a, b := i.A, i.B + if r != nil { + replacer := newReplacer(r) + a = replacer.Replace(i.A) + b = replacer.Replace(i.B) + } + return c(a, b) + } + return false +} + +// NewIf creates a new If condition. +func NewIf(a, operator, b string) (If, error) { + if _, ok := conditions[operator]; !ok { + return If{}, operatorError(operator) + } + return If{ + A: a, + Operator: operator, + B: b, + }, nil +} diff --git a/plugin/rewrite/condition_test.go b/plugin/rewrite/condition_test.go new file mode 100644 index 000000000..91004f9d7 --- /dev/null +++ b/plugin/rewrite/condition_test.go @@ -0,0 +1,102 @@ +package rewrite + +/* +func TestConditions(t *testing.T) { + tests := []struct { + condition string + isTrue bool + }{ + {"a is b", false}, + {"a is a", true}, + {"a not b", true}, + {"a not a", false}, + {"a has a", true}, + {"a has b", false}, + {"ba has b", true}, + {"bab has b", true}, + {"bab has bb", false}, + {"a not_has a", false}, + {"a not_has b", true}, + {"ba not_has b", false}, + {"bab not_has b", false}, + {"bab not_has bb", true}, + {"bab starts_with bb", false}, + {"bab starts_with ba", true}, + {"bab starts_with bab", true}, + {"bab ends_with bb", false}, + {"bab ends_with bab", true}, + {"bab ends_with ab", true}, + {"a match *", false}, + {"a match a", true}, + {"a match .*", true}, + {"a match a.*", true}, + {"a match b.*", false}, + {"ba match b.*", true}, + {"ba match b[a-z]", true}, + {"b0 match b[a-z]", false}, + {"b0a match b[a-z]", false}, + {"b0a match b[a-z]+", false}, + {"b0a match b[a-z0-9]+", true}, + {"a not_match *", true}, + {"a not_match a", false}, + {"a not_match .*", false}, + {"a not_match a.*", false}, + {"a not_match b.*", true}, + {"ba not_match b.*", false}, + {"ba not_match b[a-z]", false}, + {"b0 not_match b[a-z]", true}, + {"b0a not_match b[a-z]", true}, + {"b0a not_match b[a-z]+", true}, + {"b0a not_match b[a-z0-9]+", false}, + } + + for i, test := range tests { + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(nil) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } + + invalidOperators := []string{"ss", "and", "if"} + for _, op := range invalidOperators { + _, err := NewIf("a", op, "b") + if err == nil { + t.Errorf("Invalid operator %v used, expected error.", op) + } + } + + replaceTests := []struct { + url string + condition string + isTrue bool + }{ + {"/home", "{uri} match /home", true}, + {"/hom", "{uri} match /home", false}, + {"/hom", "{uri} starts_with /home", false}, + {"/hom", "{uri} starts_with /h", true}, + {"/home/.hiddenfile", `{uri} match \/\.(.*)`, true}, + {"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true}, + } + + for i, test := range replaceTests { + r, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Error(err) + } + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(r) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } +} +*/ diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go new file mode 100644 index 000000000..d8b6f4128 --- /dev/null +++ b/plugin/rewrite/edns0.go @@ -0,0 +1,425 @@ +// Package rewrite is plugin for rewriting requests internally to something different. +package rewrite + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "net" + "strconv" + "strings" + + "github.com/coredns/coredns/request" + "github.com/miekg/dns" +) + +// edns0LocalRule is a rewrite rule for EDNS0_LOCAL options +type edns0LocalRule struct { + action string + code uint16 + data []byte +} + +// edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable +type edns0VariableRule struct { + action string + code uint16 + variable string +} + +// ends0NsidRule is a rewrite rule for EDNS0_NSID options +type edns0NsidRule struct { + action string +} + +// setupEdns0Opt will retrieve the EDNS0 OPT or create it if it does not exist +func setupEdns0Opt(r *dns.Msg) *dns.OPT { + o := r.IsEdns0() + if o == nil { + r.SetEdns0(4096, true) + o = r.IsEdns0() + } + return o +} + +// Rewrite will alter the request EDNS0 NSID option +func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + result := RewriteIgnored + o := setupEdns0Opt(r) + found := false +Option: + for _, s := range o.Option { + switch e := s.(type) { + case *dns.EDNS0_NSID: + if rule.action == Replace || rule.action == Set { + e.Nsid = "" // make sure it is empty for request + result = RewriteDone + } + found = true + break Option + } + } + + // add option if not found + if !found && (rule.action == Append || rule.action == Set) { + o.SetDo() + o.Option = append(o.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}) + result = RewriteDone + } + + return result +} + +// Rewrite will alter the request EDNS0 local options +func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + result := RewriteIgnored + o := setupEdns0Opt(r) + found := false + for _, s := range o.Option { + switch e := s.(type) { + case *dns.EDNS0_LOCAL: + if rule.code == e.Code { + if rule.action == Replace || rule.action == Set { + e.Data = rule.data + result = RewriteDone + } + found = true + break + } + } + } + + // add option if not found + if !found && (rule.action == Append || rule.action == Set) { + o.SetDo() + var opt dns.EDNS0_LOCAL + opt.Code = rule.code + opt.Data = rule.data + o.Option = append(o.Option, &opt) + result = RewriteDone + } + + return result +} + +// newEdns0Rule creates an EDNS0 rule of the appropriate type based on the args +func newEdns0Rule(args ...string) (Rule, error) { + if len(args) < 2 { + return nil, fmt.Errorf("too few arguments for an EDNS0 rule") + } + + ruleType := strings.ToLower(args[0]) + action := strings.ToLower(args[1]) + switch action { + case Append: + case Replace: + case Set: + default: + return nil, fmt.Errorf("invalid action: %q", action) + } + + switch ruleType { + case "local": + if len(args) != 4 { + return nil, fmt.Errorf("EDNS0 local rules require exactly three args") + } + //Check for variable option + if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { + return newEdns0VariableRule(action, args[2], args[3]) + } + return newEdns0LocalRule(action, args[2], args[3]) + case "nsid": + if len(args) != 2 { + return nil, fmt.Errorf("EDNS0 NSID rules do not accept args") + } + return &edns0NsidRule{action: action}, nil + case "subnet": + if len(args) != 4 { + return nil, fmt.Errorf("EDNS0 subnet rules require exactly three args") + } + return newEdns0SubnetRule(action, args[2], args[3]) + default: + return nil, fmt.Errorf("invalid rule type %q", ruleType) + } +} + +func newEdns0LocalRule(action, code, data string) (*edns0LocalRule, error) { + c, err := strconv.ParseUint(code, 0, 16) + if err != nil { + return nil, err + } + + decoded := []byte(data) + if strings.HasPrefix(data, "0x") { + decoded, err = hex.DecodeString(data[2:]) + if err != nil { + return nil, err + } + } + return &edns0LocalRule{action: action, code: uint16(c), data: decoded}, nil +} + +// newEdns0VariableRule creates an EDNS0 rule that handles variable substitution +func newEdns0VariableRule(action, code, variable string) (*edns0VariableRule, error) { + c, err := strconv.ParseUint(code, 0, 16) + if err != nil { + return nil, err + } + //Validate + if !isValidVariable(variable) { + return nil, fmt.Errorf("unsupported variable name %q", variable) + } + return &edns0VariableRule{action: action, code: uint16(c), variable: variable}, nil +} + +// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. +func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) { + + switch family { + case 1: + return net.ParseIP(ipAddr).To4(), nil + case 2: + return net.ParseIP(ipAddr).To16(), nil + } + return nil, fmt.Errorf("Invalid IP address family (i.e. version) %d", family) +} + +// uint16ToWire writes unit16 to wire/binary format +func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(data)) + return buf +} + +// portToWire writes port to wire/binary format, 2 bytes +func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) { + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + return rule.uint16ToWire(uint16(port)), nil +} + +// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. +func (rule *edns0VariableRule) family(ip net.Addr) int { + var a net.IP + if i, ok := ip.(*net.UDPAddr); ok { + a = i.IP + } + if i, ok := ip.(*net.TCPAddr); ok { + a = i.IP + } + if a.To4() != nil { + return 1 + } + return 2 +} + +// ruleData returns the data specified by the variable +func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { + + req := request.Request{W: w, Req: r} + switch rule.variable { + case queryName: + //Query name is written as ascii string + return []byte(req.QName()), nil + + case queryType: + return rule.uint16ToWire(req.QType()), nil + + case clientIP: + return rule.ipToWire(req.Family(), req.IP()) + + case clientPort: + return rule.portToWire(req.Port()) + + case protocol: + // Proto is written as ascii string + return []byte(req.Proto()), nil + + case serverIP: + ip, _, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + ip = w.RemoteAddr().String() + } + return rule.ipToWire(rule.family(w.RemoteAddr()), ip) + + case serverPort: + _, port, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + port = "0" + } + return rule.portToWire(port) + } + + return nil, fmt.Errorf("Unable to extract data for variable %s", rule.variable) +} + +// Rewrite will alter the request EDNS0 local options with specified variables +func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + result := RewriteIgnored + + data, err := rule.ruleData(w, r) + if err != nil || data == nil { + return result + } + + o := setupEdns0Opt(r) + found := false + for _, s := range o.Option { + switch e := s.(type) { + case *dns.EDNS0_LOCAL: + if rule.code == e.Code { + if rule.action == Replace || rule.action == Set { + e.Data = data + result = RewriteDone + } + found = true + break + } + } + } + + // add option if not found + if !found && (rule.action == Append || rule.action == Set) { + o.SetDo() + var opt dns.EDNS0_LOCAL + opt.Code = rule.code + opt.Data = data + o.Option = append(o.Option, &opt) + result = RewriteDone + } + + return result +} + +func isValidVariable(variable string) bool { + switch variable { + case + queryName, + queryType, + clientIP, + clientPort, + protocol, + serverIP, + serverPort: + return true + } + return false +} + +// ends0SubnetRule is a rewrite rule for EDNS0 subnet options +type edns0SubnetRule struct { + v4BitMaskLen uint8 + v6BitMaskLen uint8 + action string +} + +func newEdns0SubnetRule(action, v4BitMaskLen, v6BitMaskLen string) (*edns0SubnetRule, error) { + v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16) + if err != nil { + return nil, err + } + // Validate V4 length + if v4Len > maxV4BitMaskLen { + return nil, fmt.Errorf("invalid IPv4 bit mask length %d", v4Len) + } + + v6Len, err := strconv.ParseUint(v6BitMaskLen, 0, 16) + if err != nil { + return nil, err + } + //Validate V6 length + if v6Len > maxV6BitMaskLen { + return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6Len) + } + + return &edns0SubnetRule{action: action, + v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len)}, nil +} + +// fillEcsData sets the subnet data into the ecs option +func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, + ecs *dns.EDNS0_SUBNET) error { + + req := request.Request{W: w, Req: r} + family := req.Family() + if (family != 1) && (family != 2) { + return fmt.Errorf("unable to fill data for EDNS0 subnet due to invalid IP family") + } + + ecs.DraftOption = false + ecs.Family = uint16(family) + ecs.SourceScope = 0 + + ipAddr := req.IP() + switch family { + case 1: + ipv4Mask := net.CIDRMask(int(rule.v4BitMaskLen), 32) + ipv4Addr := net.ParseIP(ipAddr) + ecs.SourceNetmask = rule.v4BitMaskLen + ecs.Address = ipv4Addr.Mask(ipv4Mask).To4() + case 2: + ipv6Mask := net.CIDRMask(int(rule.v6BitMaskLen), 128) + ipv6Addr := net.ParseIP(ipAddr) + ecs.SourceNetmask = rule.v6BitMaskLen + ecs.Address = ipv6Addr.Mask(ipv6Mask).To16() + } + return nil +} + +// Rewrite will alter the request EDNS0 subnet option +func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + result := RewriteIgnored + o := setupEdns0Opt(r) + found := false + for _, s := range o.Option { + switch e := s.(type) { + case *dns.EDNS0_SUBNET: + if rule.action == Replace || rule.action == Set { + if rule.fillEcsData(w, r, e) == nil { + result = RewriteDone + } + } + found = true + break + } + } + + // add option if not found + if !found && (rule.action == Append || rule.action == Set) { + o.SetDo() + opt := dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET} + if rule.fillEcsData(w, r, &opt) == nil { + o.Option = append(o.Option, &opt) + result = RewriteDone + } + } + + return result +} + +// These are all defined actions. +const ( + Replace = "replace" + Set = "set" + Append = "append" +) + +// Supported local EDNS0 variables +const ( + queryName = "{qname}" + queryType = "{qtype}" + clientIP = "{client_ip}" + clientPort = "{client_port}" + protocol = "{protocol}" + serverIP = "{server_ip}" + serverPort = "{server_port}" +) + +// Subnet maximum bit mask length +const ( + maxV4BitMaskLen = 32 + maxV6BitMaskLen = 128 +) diff --git a/plugin/rewrite/name.go b/plugin/rewrite/name.go new file mode 100644 index 000000000..189133542 --- /dev/null +++ b/plugin/rewrite/name.go @@ -0,0 +1,24 @@ +package rewrite + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +type nameRule struct { + From, To string +} + +func newNameRule(from, to string) (Rule, error) { + return &nameRule{plugin.Name(from).Normalize(), plugin.Name(to).Normalize()}, nil +} + +// Rewrite rewrites the the current request. +func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + if rule.From == r.Question[0].Name { + r.Question[0].Name = rule.To + return RewriteDone + } + return RewriteIgnored +} diff --git a/plugin/rewrite/reverter.go b/plugin/rewrite/reverter.go new file mode 100644 index 000000000..400fb5fff --- /dev/null +++ b/plugin/rewrite/reverter.go @@ -0,0 +1,39 @@ +package rewrite + +import "github.com/miekg/dns" + +// ResponseReverter reverses the operations done on the question section of a packet. +// This is need because the client will otherwise disregards the response, i.e. +// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN' +type ResponseReverter struct { + dns.ResponseWriter + original dns.Question +} + +// NewResponseReverter returns a pointer to a new ResponseReverter. +func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter { + return &ResponseReverter{ + ResponseWriter: w, + original: r.Question[0], + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *ResponseReverter) WriteMsg(res *dns.Msg) error { + res.Question[0] = r.original + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *ResponseReverter) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + return n, err +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *ResponseReverter) Hijack() { + r.ResponseWriter.Hijack() + return +} diff --git a/plugin/rewrite/rewrite.go b/plugin/rewrite/rewrite.go new file mode 100644 index 000000000..d4931445c --- /dev/null +++ b/plugin/rewrite/rewrite.go @@ -0,0 +1,86 @@ +package rewrite + +import ( + "fmt" + "strings" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" + + "golang.org/x/net/context" +) + +// Result is the result of a rewrite +type Result int + +const ( + // RewriteIgnored is returned when rewrite is not done on request. + RewriteIgnored Result = iota + // RewriteDone is returned when rewrite is done on request. + RewriteDone + // RewriteStatus is returned when rewrite is not needed and status code should be set + // for the request. + RewriteStatus +) + +// Rewrite is plugin to rewrite requests internally before being handled. +type Rewrite struct { + Next plugin.Handler + Rules []Rule + noRevert bool +} + +// ServeDNS implements the plugin.Handler interface. +func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + wr := NewResponseReverter(w, r) + for _, rule := range rw.Rules { + switch result := rule.Rewrite(w, r); result { + case RewriteDone: + if rw.noRevert { + return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) + } + return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r) + case RewriteIgnored: + break + case RewriteStatus: + // only valid for complex rules. + // if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 { + // return cRule.Status, nil + // } + } + } + return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) +} + +// Name implements the Handler interface. +func (rw Rewrite) Name() string { return "rewrite" } + +// Rule describes a rewrite rule. +type Rule interface { + // Rewrite rewrites the current request. + Rewrite(dns.ResponseWriter, *dns.Msg) Result +} + +func newRule(args ...string) (Rule, error) { + if len(args) == 0 { + return nil, fmt.Errorf("no rule type specified for rewrite") + } + + ruleType := strings.ToLower(args[0]) + if ruleType != "edns0" && len(args) != 3 { + return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType) + } + switch ruleType { + case "name": + return newNameRule(args[1], args[2]) + case "class": + return newClassRule(args[1], args[2]) + case "type": + return newTypeRule(args[1], args[2]) + case "edns0": + return newEdns0Rule(args[1:]...) + default: + return nil, fmt.Errorf("invalid rule type %q", args[0]) + } +} diff --git a/plugin/rewrite/rewrite_test.go b/plugin/rewrite/rewrite_test.go new file mode 100644 index 000000000..74a8594df --- /dev/null +++ b/plugin/rewrite/rewrite_test.go @@ -0,0 +1,532 @@ +package rewrite + +import ( + "bytes" + "reflect" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnsrecorder" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +func msgPrinter(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + w.WriteMsg(r) + return 0, nil +} + +func TestNewRule(t *testing.T) { + tests := []struct { + args []string + shouldError bool + expType reflect.Type + }{ + {[]string{}, true, nil}, + {[]string{"foo"}, true, nil}, + {[]string{"name"}, true, nil}, + {[]string{"name", "a.com"}, true, nil}, + {[]string{"name", "a.com", "b.com", "c.com"}, true, nil}, + {[]string{"name", "a.com", "b.com"}, false, reflect.TypeOf(&nameRule{})}, + {[]string{"type"}, true, nil}, + {[]string{"type", "a"}, true, nil}, + {[]string{"type", "any", "a", "a"}, true, nil}, + {[]string{"type", "any", "a"}, false, reflect.TypeOf(&typeRule{})}, + {[]string{"type", "XY", "WV"}, true, nil}, + {[]string{"type", "ANY", "WV"}, true, nil}, + {[]string{"class"}, true, nil}, + {[]string{"class", "IN"}, true, nil}, + {[]string{"class", "ch", "in", "in"}, true, nil}, + {[]string{"class", "ch", "in"}, false, reflect.TypeOf(&classRule{})}, + {[]string{"class", "XY", "WV"}, true, nil}, + {[]string{"class", "IN", "WV"}, true, nil}, + {[]string{"edns0"}, true, nil}, + {[]string{"edns0", "local"}, true, nil}, + {[]string{"edns0", "local", "set"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee"}, true, nil}, + {[]string{"edns0", "local", "set", "65518", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "foo", "0xffee", "abcdefg"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "0xabcdefg"}, true, nil}, + {[]string{"edns0", "nsid", "set", "junk"}, true, nil}, + {[]string{"edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "foo"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "subnet", "set", "-1", "56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "-56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "33", "56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "129"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + } + + for i, tc := range tests { + r, err := newRule(tc.args...) + if err == nil && tc.shouldError { + t.Errorf("Test %d: expected error but got success", i) + } else if err != nil && !tc.shouldError { + t.Errorf("Test %d: expected success but got error: %s", i, err) + } + + if !tc.shouldError && reflect.TypeOf(r) != tc.expType { + t.Errorf("Test %d: expected %q but got %q", i, tc.expType, r) + } + } +} + +func TestRewrite(t *testing.T) { + rules := []Rule{} + r, _ := newNameRule("from.nl.", "to.nl.") + rules = append(rules, r) + r, _ = newClassRule("CH", "IN") + rules = append(rules, r) + r, _ = newTypeRule("ANY", "HINFO") + rules = append(rules, r) + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + noRevert: true, + } + + tests := []struct { + from string + fromT uint16 + fromC uint16 + to string + toT uint16 + toC uint16 + }{ + {"from.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"a.nl.", dns.TypeA, dns.ClassINET, "a.nl.", dns.TypeA, dns.ClassINET}, + {"a.nl.", dns.TypeA, dns.ClassCHAOS, "a.nl.", dns.TypeA, dns.ClassINET}, + {"a.nl.", dns.TypeANY, dns.ClassINET, "a.nl.", dns.TypeHINFO, dns.ClassINET}, + // name is rewritten, type is not. + {"from.nl.", dns.TypeANY, dns.ClassINET, "to.nl.", dns.TypeANY, dns.ClassINET}, + // name is not, type is, but class is, because class is the 2nd rule. + {"a.nl.", dns.TypeANY, dns.ClassCHAOS, "a.nl.", dns.TypeANY, dns.ClassINET}, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion(tc.from, tc.fromT) + m.Question[0].Qclass = tc.fromC + + rec := dnsrecorder.New(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + if resp.Question[0].Name != tc.to { + t.Errorf("Test %d: Expected Name to be %q but was %q", i, tc.to, resp.Question[0].Name) + } + if resp.Question[0].Qtype != tc.toT { + t.Errorf("Test %d: Expected Type to be '%d' but was '%d'", i, tc.toT, resp.Question[0].Qtype) + } + if resp.Question[0].Qclass != tc.toC { + t.Errorf("Test %d: Expected Class to be '%d' but was '%d'", i, tc.toC, resp.Question[0].Qclass) + } + } +} + +func TestRewriteEDNS0Local(t *testing.T) { + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + noRevert: true, + } + + tests := []struct { + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + }{ + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "0xabcdef"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0xab, 0xcd, 0xef}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "append", "0xffee", "abcdefghijklmnop"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("abcdefghijklmnop")}}, + }, + { + []dns.EDNS0{}, + []string{"local", "replace", "0xffee", "abcdefghijklmnop"}, + []dns.EDNS0{}, + }, + { + []dns.EDNS0{}, + []string{"nsid", "set"}, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + }, + { + []dns.EDNS0{}, + []string{"nsid", "append"}, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + }, + { + []dns.EDNS0{}, + []string{"nsid", "replace"}, + []dns.EDNS0{}, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + r, err := newEdns0Rule(tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + + rec := dnsrecorder.New(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} + +func TestEdns0LocalMultiRule(t *testing.T) { + rules := []Rule{} + r, _ := newEdns0Rule("local", "replace", "0xffee", "abcdef") + rules = append(rules, r) + r, _ = newEdns0Rule("local", "set", "0xffee", "fedcba") + rules = append(rules, r) + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + noRevert: true, + } + + tests := []struct { + fromOpts []dns.EDNS0 + toOpts []dns.EDNS0 + }{ + { + nil, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("fedcba")}}, + }, + { + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("abcdef")}}, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + if tc.fromOpts != nil { + o := m.IsEdns0() + if o == nil { + m.SetEdns0(4096, true) + o = m.IsEdns0() + } + o.Option = append(o.Option, tc.fromOpts...) + } + rec := dnsrecorder.New(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} + +func optsEqual(a, b []dns.EDNS0) bool { + if len(a) != len(b) { + return false + } + for i := range a { + switch aa := a[i].(type) { + case *dns.EDNS0_LOCAL: + if bb, ok := b[i].(*dns.EDNS0_LOCAL); ok { + if aa.Code != bb.Code { + return false + } + if !bytes.Equal(aa.Data, bb.Data) { + return false + } + } else { + return false + } + case *dns.EDNS0_NSID: + if bb, ok := b[i].(*dns.EDNS0_NSID); ok { + if aa.Nsid != bb.Nsid { + return false + } + } else { + return false + } + case *dns.EDNS0_SUBNET: + if bb, ok := b[i].(*dns.EDNS0_SUBNET); ok { + if aa.Code != bb.Code { + return false + } + if aa.Family != bb.Family { + return false + } + if aa.SourceNetmask != bb.SourceNetmask { + return false + } + if aa.SourceScope != bb.SourceScope { + return false + } + if !bytes.Equal(aa.Address, bb.Address) { + return false + } + if aa.DraftOption != bb.DraftOption { + return false + } + } else { + return false + } + + default: + return false + } + } + return true +} + +func TestRewriteEDNS0LocalVariable(t *testing.T) { + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + noRevert: true, + } + + // test.ResponseWriter has the following values: + // The remote will always be 10.240.0.1 and port 40212. + // The local address is always 127.0.0.1 and port 53. + + tests := []struct { + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + }{ + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{qname}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("example.com.")}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{qtype}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x01}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{client_ip}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x0A, 0xF0, 0x00, 0x01}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{client_port}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x9D, 0x14}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{protocol}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("udp")}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{server_ip}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x7F, 0x00, 0x00, 0x01}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{server_port}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x35}}}, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + r, err := newEdns0Rule(tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + + rec := dnsrecorder.New(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} + +func TestRewriteEDNS0Subnet(t *testing.T) { + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + noRevert: true, + } + + tests := []struct { + writer dns.ResponseWriter + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + }{ + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x18, + SourceScope: 0x0, + Address: []byte{0x0A, 0xF0, 0x00, 0x00}, + DraftOption: false}}, + }, + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "32", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x20, + SourceScope: 0x0, + Address: []byte{0x0A, 0xF0, 0x00, 0x01}, + DraftOption: false}}, + }, + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "0", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00}, + DraftOption: false}}, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x38, + SourceScope: 0x0, + Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + DraftOption: false}}, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "128"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x80, + SourceScope: 0x0, + Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x42, 0x00, 0xff, 0xfe, 0xca, 0x4c, 0x65}, + DraftOption: false}}, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "0"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + DraftOption: false}}, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + r, err := newEdns0Rule(tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + rec := dnsrecorder.New(tc.writer) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} diff --git a/plugin/rewrite/setup.go b/plugin/rewrite/setup.go new file mode 100644 index 000000000..5954a3300 --- /dev/null +++ b/plugin/rewrite/setup.go @@ -0,0 +1,42 @@ +package rewrite + +import ( + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("rewrite", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + rewrites, err := rewriteParse(c) + if err != nil { + return plugin.Error("rewrite", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Rewrite{Next: next, Rules: rewrites} + }) + + return nil +} + +func rewriteParse(c *caddy.Controller) ([]Rule, error) { + var rules []Rule + + for c.Next() { + args := c.RemainingArgs() + rule, err := newRule(args...) + if err != nil { + return nil, err + } + rules = append(rules, rule) + } + return rules, nil +} diff --git a/plugin/rewrite/setup_test.go b/plugin/rewrite/setup_test.go new file mode 100644 index 000000000..67ef88e18 --- /dev/null +++ b/plugin/rewrite/setup_test.go @@ -0,0 +1,25 @@ +package rewrite + +import ( + "testing" + + "github.com/mholt/caddy" +) + +func TestParse(t *testing.T) { + c := caddy.NewTestController("dns", `rewrite`) + _, err := rewriteParse(c) + if err == nil { + t.Errorf("Expected error but found nil for `rewrite`") + } + c = caddy.NewTestController("dns", `rewrite name`) + _, err = rewriteParse(c) + if err == nil { + t.Errorf("Expected error but found nil for `rewrite name`") + } + c = caddy.NewTestController("dns", `rewrite name a.com b.com`) + _, err = rewriteParse(c) + if err != nil { + t.Errorf("Expected success but found %s for `rewrite name a.com b.com`", err) + } +} diff --git a/plugin/rewrite/testdata/testdir/empty b/plugin/rewrite/testdata/testdir/empty new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/plugin/rewrite/testdata/testdir/empty diff --git a/plugin/rewrite/testdata/testfile b/plugin/rewrite/testdata/testfile new file mode 100644 index 000000000..7b4d68d70 --- /dev/null +++ b/plugin/rewrite/testdata/testfile @@ -0,0 +1 @@ +empty
\ No newline at end of file diff --git a/plugin/rewrite/type.go b/plugin/rewrite/type.go new file mode 100644 index 000000000..ae3efcc5a --- /dev/null +++ b/plugin/rewrite/type.go @@ -0,0 +1,37 @@ +// Package rewrite is plugin for rewriting requests internally to something different. +package rewrite + +import ( + "fmt" + "strings" + + "github.com/miekg/dns" +) + +// typeRule is a type rewrite rule. +type typeRule struct { + fromType, toType uint16 +} + +func newTypeRule(fromS, toS string) (Rule, error) { + var from, to uint16 + var ok bool + if from, ok = dns.StringToType[strings.ToUpper(fromS)]; !ok { + return nil, fmt.Errorf("invalid type %q", strings.ToUpper(fromS)) + } + if to, ok = dns.StringToType[strings.ToUpper(toS)]; !ok { + return nil, fmt.Errorf("invalid type %q", strings.ToUpper(toS)) + } + return &typeRule{fromType: from, toType: to}, nil +} + +// Rewrite rewrites the the current request. +func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + if rule.fromType > 0 && rule.toType > 0 { + if r.Question[0].Qtype == rule.fromType { + r.Question[0].Qtype = rule.toType + return RewriteDone + } + } + return RewriteIgnored +} |