diff options
Diffstat (limited to 'plugin/rewrite/edns0.go')
-rw-r--r-- | plugin/rewrite/edns0.go | 134 |
1 files changed, 19 insertions, 115 deletions
diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go index 2fd42cb67..f8b65d468 100644 --- a/plugin/rewrite/edns0.go +++ b/plugin/rewrite/edns0.go @@ -2,13 +2,15 @@ package rewrite import ( - "encoding/binary" + "context" "encoding/hex" "fmt" "net" "strconv" "strings" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/variables" "github.com/coredns/coredns/request" "github.com/miekg/dns" ) @@ -46,7 +48,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT { } // Rewrite will alter the request EDNS0 NSID option -func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -83,7 +85,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule { } // Rewrite will alter the request EDNS0 local options -func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -146,7 +148,9 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) { } //Check for variable option if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { - return newEdns0VariableRule(mode, action, args[2], args[3]) + // Remove first and last runes + variable := args[3][1 : len(args[3])-1] + return newEdns0VariableRule(mode, action, args[2], variable) } return newEdns0LocalRule(mode, action, args[2], args[3]) case "nsid": @@ -186,102 +190,28 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu if err != nil { return nil, err } - //Validate - if !isValidVariable(variable) { - return nil, fmt.Errorf("unsupported variable name %q", variable) - } return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil } -// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. -func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) { - - switch family { - case 1: - return net.ParseIP(ipAddr).To4(), nil - case 2: - return net.ParseIP(ipAddr).To16(), nil - } - return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family) -} - -// uint16ToWire writes unit16 to wire/binary format -func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, uint16(data)) - return buf -} - -// portToWire writes port to wire/binary format, 2 bytes -func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) { - - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, err - } - return rule.uint16ToWire(uint16(port)), nil -} - -// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. -func (rule *edns0VariableRule) family(ip net.Addr) int { - var a net.IP - if i, ok := ip.(*net.UDPAddr); ok { - a = i.IP - } - if i, ok := ip.(*net.TCPAddr); ok { - a = i.IP - } - if a.To4() != nil { - return 1 - } - return 2 -} - // ruleData returns the data specified by the variable -func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { - - req := request.Request{W: w, Req: r} - switch rule.variable { - case queryName: - //Query name is written as ascii string - return []byte(req.QName()), nil - - case queryType: - return rule.uint16ToWire(req.QType()), nil - - case clientIP: - return rule.ipToWire(req.Family(), req.IP()) - - case clientPort: - return rule.portToWire(req.Port()) - - case protocol: - // Proto is written as ascii string - return []byte(req.Proto()), nil - - case serverIP: - ip, _, err := net.SplitHostPort(w.LocalAddr().String()) - if err != nil { - ip = w.RemoteAddr().String() - } - return rule.ipToWire(rule.family(w.RemoteAddr()), ip) - - case serverPort: - _, port, err := net.SplitHostPort(w.LocalAddr().String()) - if err != nil { - port = "0" +func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { + if md, ok := metadata.FromContext(ctx); ok { + if value, ok := md.Value(rule.variable); ok { + if v, ok := value.([]byte); ok { + return v, nil + } } - return rule.portToWire(port) + } else { // No metadata available means metadata plugin is disabled. Try to get the value directly. + return variables.GetValue(rule.variable, w, r) } - return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable) } // Rewrite will alter the request EDNS0 local options with specified variables -func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored - data, err := rule.ruleData(w, r) + data, err := rule.ruleData(ctx, w, r) if err != nil || data == nil { return result } @@ -324,21 +254,6 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule { return ResponseRule{} } -func isValidVariable(variable string) bool { - switch variable { - case - queryName, - queryType, - clientIP, - clientPort, - protocol, - serverIP, - serverPort: - return true - } - return false -} - // ends0SubnetRule is a rewrite rule for EDNS0 subnet options type edns0SubnetRule struct { mode string @@ -400,7 +315,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, } // Rewrite will alter the request EDNS0 subnet option -func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -446,17 +361,6 @@ const ( Append = "append" ) -// Supported local EDNS0 variables -const ( - queryName = "{qname}" - queryType = "{qtype}" - clientIP = "{client_ip}" - clientPort = "{client_port}" - protocol = "{protocol}" - serverIP = "{server_ip}" - serverPort = "{server_port}" -) - // Subnet maximum bit mask length const ( maxV4BitMaskLen = 32 |