diff options
Diffstat (limited to 'plugin')
-rw-r--r-- | plugin/metadata/metadata.go | 23 | ||||
-rw-r--r-- | plugin/metadata/metadata_test.go | 52 | ||||
-rw-r--r-- | plugin/metadata/provider.go | 102 | ||||
-rw-r--r-- | plugin/metadata/provider_test.go | 48 | ||||
-rw-r--r-- | plugin/metadata/setup.go | 10 | ||||
-rw-r--r-- | plugin/pkg/variables/variables.go | 104 | ||||
-rw-r--r-- | plugin/pkg/variables/variables_test.go | 83 | ||||
-rw-r--r-- | plugin/rewrite/README.md | 10 | ||||
-rw-r--r-- | plugin/rewrite/class.go | 3 | ||||
-rw-r--r-- | plugin/rewrite/edns0.go | 135 | ||||
-rw-r--r-- | plugin/rewrite/name.go | 11 | ||||
-rw-r--r-- | plugin/rewrite/rewrite.go | 4 | ||||
-rw-r--r-- | plugin/rewrite/rewrite_test.go | 6 | ||||
-rw-r--r-- | plugin/rewrite/type.go | 3 |
14 files changed, 227 insertions, 367 deletions
diff --git a/plugin/metadata/metadata.go b/plugin/metadata/metadata.go index e7560d403..4abe57ddf 100644 --- a/plugin/metadata/metadata.go +++ b/plugin/metadata/metadata.go @@ -4,7 +4,6 @@ import ( "context" "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/pkg/variables" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -24,18 +23,13 @@ func (m *Metadata) Name() string { return "metadata" } // ServeDNS implements the plugin.Handler interface. func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - ctx = context.WithValue(ctx, metadataKey{}, M{}) - md, _ := FromContext(ctx) + ctx = context.WithValue(ctx, key{}, md{}) state := request.Request{W: w, Req: r} if plugin.Zones(m.Zones).Matches(state.Name()) != "" { // Go through all Providers and collect metadata. - for _, provider := range m.Providers { - for _, varName := range provider.MetadataVarNames() { - if val, ok := provider.Metadata(ctx, state, varName); ok { - md.SetValue(varName, val) - } - } + for _, p := range m.Providers { + ctx = p.Metadata(ctx, state) } } @@ -43,14 +37,3 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms return rcode, err } - -// MetadataVarNames implements the plugin.Provider interface. -func (m *Metadata) MetadataVarNames() []string { return variables.All } - -// Metadata implements the plugin.Provider interface. -func (m *Metadata) Metadata(ctx context.Context, state request.Request, varName string) (interface{}, bool) { - if val, err := variables.GetValue(state, varName); err == nil { - return val, true - } - return nil, false -} diff --git a/plugin/metadata/metadata_test.go b/plugin/metadata/metadata_test.go index 8bbff4c34..7ded05c03 100644 --- a/plugin/metadata/metadata_test.go +++ b/plugin/metadata/metadata_test.go @@ -10,26 +10,18 @@ import ( "github.com/miekg/dns" ) -// testProvider implements fake Providers. Plugins which inmplement Provider interface -type testProvider map[string]interface{} +type testProvider map[string]Func -func (m testProvider) MetadataVarNames() []string { - keys := []string{} - for k := range m { - keys = append(keys, k) +func (tp testProvider) Metadata(ctx context.Context, state request.Request) context.Context { + for k, v := range tp { + SetValueFunc(ctx, k, v) } - return keys + return ctx } -func (m testProvider) Metadata(ctx context.Context, state request.Request, key string) (val interface{}, ok bool) { - value, ok := m[key] - return value, ok -} - -// testHandler implements plugin.Handler. type testHandler struct{ ctx context.Context } -func (m *testHandler) Name() string { return "testHandler" } +func (m *testHandler) Name() string { return "test" } func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { m.ctx = ctx @@ -38,8 +30,8 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns func TestMetadataServeDNS(t *testing.T) { expectedMetadata := []testProvider{ - testProvider{"testkey1": "testvalue1"}, - testProvider{"testkey2": 2, "testkey3": "testvalue3"}, + testProvider{"test/key1": func() string { return "testvalue1" }}, + testProvider{"test/key2": func() string { return "two" }, "test/key3": func() string { return "testvalue3" }}, } // Create fake Providers based on expectedMetadata providers := []Provider{} @@ -48,32 +40,22 @@ func TestMetadataServeDNS(t *testing.T) { } next := &testHandler{} // fake handler which stores the resulting context - metadata := Metadata{ + m := Metadata{ Zones: []string{"."}, Providers: providers, Next: next, } - metadata.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg)) - // Verify that next plugin can find metadata in context from all Providers + ctx := context.TODO() + m.ServeDNS(ctx, &test.ResponseWriter{}, new(dns.Msg)) + nctx := next.ctx + for _, expected := range expectedMetadata { - md, ok := FromContext(next.ctx) - if !ok { - t.Fatalf("Metadata is expected but not present inside the context") - } - for expKey, expVal := range expected { - metadataVal, valOk := md.Value(expKey) - if !valOk { - t.Fatalf("Value by key %v can't be retrieved", expKey) - } - if metadataVal != expVal { - t.Errorf("Expected value %v, but got %v", expVal, metadataVal) + for label, expVal := range expected { + val := ValueFunc(nctx, label) + if val() != expVal() { + t.Errorf("Expected value %s for %s, but got %s", expVal(), label, val()) } } - wrongKey := "wrong_key" - metadataVal, ok := md.Value(wrongKey) - if ok { - t.Fatalf("Value by key %v is not expected to be recieved, but got: %v", wrongKey, metadataVal) - } } } diff --git a/plugin/metadata/provider.go b/plugin/metadata/provider.go index e13f9c896..eb7bb9755 100644 --- a/plugin/metadata/provider.go +++ b/plugin/metadata/provider.go @@ -1,3 +1,33 @@ +// Package metadata provides an API that allows plugins to add metadata to the context. +// Each metadata is stored under a label that has the form <plugin>/<name>. Each metadata +// is returned as a Func. When Func is called the metadata is returned. If Func is expensive to +// execute it is its responsibility to provide some form of caching. During the handling of a +// query it is expected the metadata stays constant. +// +// Basic example: +// +// Implement the Provder interface for a plugin: +// +// func (p P) Metadata(ctx context.Context, state request.Request) context.Context { +// cached := "" +// f := func() string { +// if cached != "" { +// return cached +// } +// cached = expensiveFunc() +// return cached +// } +// metadata.SetValueFunc(ctx, "test/something", f) +// return ctx +// } +// +// Check the metadata from another plugin: +// +// // ... +// valueFunc := metadata.ValueFunc(ctx, "test/something") +// value := valueFunc() +// // use 'value' +// package metadata import ( @@ -8,40 +38,62 @@ import ( // Provider interface needs to be implemented by each plugin willing to provide // metadata information for other plugins. -// Note: this method should work quickly, because it is called for every request -// from the metadata plugin. type Provider interface { - // List of variables which are provided by current Provider. Must remain constant. - MetadataVarNames() []string - // Metadata is expected to return a value with metadata information by the key - // from 4th argument. Value can be later retrieved from context by any other plugin. - // If value is not available by some reason returned boolean value should be false. - Metadata(ctx context.Context, state request.Request, variable string) (interface{}, bool) + // Metadata adds metadata to the context and returns a (potentially) new context. + // Note: this method should work quickly, because it is called for every request + // from the metadata plugin. + Metadata(ctx context.Context, state request.Request) context.Context } -// M is metadata information storage. -type M map[string]interface{} +// Func is the type of function in the metadata, when called they return the value of the label. +type Func func() string -// FromContext retrieves the metadata from the context. -func FromContext(ctx context.Context) (M, bool) { - if metadata := ctx.Value(metadataKey{}); metadata != nil { - if m, ok := metadata.(M); ok { - return m, true +// Labels returns all metadata keys stored in the context. These label names should be named +// as: plugin/NAME, where NAME is something descriptive. +func Labels(ctx context.Context) []string { + if metadata := ctx.Value(key{}); metadata != nil { + if m, ok := metadata.(md); ok { + return keys(m) } } - return M{}, false + return nil } -// Value returns metadata value by key. -func (m M) Value(key string) (value interface{}, ok bool) { - value, ok = m[key] - return value, ok +// ValueFunc returns the value function of label. If none can be found nil is returned. Calling the +// function returns the value of the label. +func ValueFunc(ctx context.Context, label string) Func { + if metadata := ctx.Value(key{}); metadata != nil { + if m, ok := metadata.(md); ok { + return m[label] + } + } + return nil } -// SetValue sets the metadata value under key. -func (m M) SetValue(key string, val interface{}) { - m[key] = val +// SetValueFunc set the metadata label to the value function. If no metadata can be found this is a noop and +// false is returned. Any existing value is overwritten. +func SetValueFunc(ctx context.Context, label string, f Func) bool { + if metadata := ctx.Value(key{}); metadata != nil { + if m, ok := metadata.(md); ok { + m[label] = f + return true + } + } + return false } -// metadataKey defines the type of key that is used to save metadata into the context. -type metadataKey struct{} +// md is metadata information storage. +type md map[string]Func + +// key defines the type of key that is used to save metadata into the context. +type key struct{} + +func keys(m map[string]Func) []string { + s := make([]string, len(m)) + i := 0 + for k := range m { + s[i] = k + i++ + } + return s +} diff --git a/plugin/metadata/provider_test.go b/plugin/metadata/provider_test.go deleted file mode 100644 index 1a074aeaa..000000000 --- a/plugin/metadata/provider_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package metadata - -import ( - "context" - "reflect" - "testing" -) - -func TestMD(t *testing.T) { - tests := []struct { - addValues map[string]interface{} - expectedValues map[string]interface{} - }{ - { - // Add initial metadata key/vals - map[string]interface{}{"key1": "val1", "key2": 2}, - map[string]interface{}{"key1": "val1", "key2": 2}, - }, - { - // Add additional key/vals. - map[string]interface{}{"key3": 3, "key4": 4.5}, - map[string]interface{}{"key1": "val1", "key2": 2, "key3": 3, "key4": 4.5}, - }, - } - - // Using one same md and ctx for all test cases - ctx := context.TODO() - ctx = context.WithValue(ctx, metadataKey{}, M{}) - m, _ := FromContext(ctx) - - for i, tc := range tests { - for k, v := range tc.addValues { - m.SetValue(k, v) - } - if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(m)) { - t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, m) - } - - // Make sure that md is recieved from context successfullly - mFromContext, ok := FromContext(ctx) - if !ok { - t.Errorf("Test %d: md is not recieved from the context", i) - } - if !reflect.DeepEqual(m, mFromContext) { - t.Errorf("Test %d: md recieved from context differs from initial. Initial: %v, from context: %v", i, m, mFromContext) - } - } -} diff --git a/plugin/metadata/setup.go b/plugin/metadata/setup.go index 33a153a2c..282bcf7d9 100644 --- a/plugin/metadata/setup.go +++ b/plugin/metadata/setup.go @@ -1,8 +1,6 @@ package metadata import ( - "fmt" - "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" @@ -28,16 +26,8 @@ func setup(c *caddy.Controller) error { c.OnStartup(func() error { plugins := dnsserver.GetConfig(c).Handlers() - // Collect all plugins which implement Provider interface - metadataVariables := map[string]bool{} for _, p := range plugins { if met, ok := p.(Provider); ok { - for _, varName := range met.MetadataVarNames() { - if _, ok := metadataVariables[varName]; ok { - return fmt.Errorf("Metadata variable '%v' has duplicates", varName) - } - metadataVariables[varName] = true - } m.Providers = append(m.Providers, met) } } diff --git a/plugin/pkg/variables/variables.go b/plugin/pkg/variables/variables.go deleted file mode 100644 index 8e1cdbe77..000000000 --- a/plugin/pkg/variables/variables.go +++ /dev/null @@ -1,104 +0,0 @@ -package variables - -import ( - "encoding/binary" - "fmt" - "net" - "strconv" - - "github.com/coredns/coredns/request" -) - -const ( - queryName = "qname" - queryType = "qtype" - clientIP = "client_ip" - clientPort = "client_port" - protocol = "protocol" - serverIP = "server_ip" - serverPort = "server_port" -) - -// All is a list of available variables provided by GetMetadataValue -var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort} - -// GetValue calculates and returns the data specified by the variable name. -// Supported varNames are listed in allProvidedVars. -func GetValue(state request.Request, varName string) ([]byte, error) { - switch varName { - case queryName: - return []byte(state.QName()), nil - - case queryType: - return uint16ToWire(state.QType()), nil - - case clientIP: - return ipToWire(state.Family(), state.IP()) - - case clientPort: - return portToWire(state.Port()) - - case protocol: - return []byte(state.Proto()), nil - - case serverIP: - ip, _, err := net.SplitHostPort(state.W.LocalAddr().String()) - if err != nil { - ip = state.W.RemoteAddr().String() - } - return ipToWire(state.Family(), ip) - - case serverPort: - _, port, err := net.SplitHostPort(state.W.LocalAddr().String()) - if err != nil { - port = "0" - } - return portToWire(port) - } - - return nil, fmt.Errorf("unable to extract data for variable %s", varName) -} - -// uint16ToWire writes unit16 to wire/binary format -func uint16ToWire(data uint16) []byte { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, uint16(data)) - return buf -} - -// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. -func ipToWire(family int, ipAddr string) ([]byte, error) { - - switch family { - case 1: - return net.ParseIP(ipAddr).To4(), nil - case 2: - return net.ParseIP(ipAddr).To16(), nil - } - return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family) -} - -// portToWire writes port to wire/binary format, 2 bytes -func portToWire(portStr string) ([]byte, error) { - - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, err - } - return uint16ToWire(uint16(port)), nil -} - -// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. -func family(ip net.Addr) int { - var a net.IP - if i, ok := ip.(*net.UDPAddr); ok { - a = i.IP - } - if i, ok := ip.(*net.TCPAddr); ok { - a = i.IP - } - if a.To4() != nil { - return 1 - } - return 2 -} diff --git a/plugin/pkg/variables/variables_test.go b/plugin/pkg/variables/variables_test.go deleted file mode 100644 index e0ff64c19..000000000 --- a/plugin/pkg/variables/variables_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package variables - -import ( - "bytes" - "testing" - - "github.com/coredns/coredns/plugin/test" - "github.com/coredns/coredns/request" - - "github.com/miekg/dns" -) - -func TestGetValue(t *testing.T) { - // test.ResponseWriter has the following values: - // The remote will always be 10.240.0.1 and port 40212. - // The local address is always 127.0.0.1 and port 53. - tests := []struct { - varName string - expectedValue []byte - shouldErr bool - }{ - { - queryName, - []byte("example.com."), - false, - }, - { - queryType, - []byte{0x00, 0x01}, - false, - }, - { - clientIP, - []byte{10, 240, 0, 1}, - false, - }, - { - clientPort, - []byte{0x9D, 0x14}, - false, - }, - { - protocol, - []byte("udp"), - false, - }, - { - serverIP, - []byte{127, 0, 0, 1}, - false, - }, - { - serverPort, - []byte{0, 53}, - false, - }, - { - "wrong_var", - []byte{}, - true, - }, - } - - for i, tc := range tests { - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.Question[0].Qclass = dns.ClassINET - state := request.Request{W: &test.ResponseWriter{}, Req: m} - - value, err := GetValue(state, tc.varName) - - if tc.shouldErr && err == nil { - t.Errorf("Test %d: Expected error, but didn't recieve", i) - } - if !tc.shouldErr && err != nil { - t.Errorf("Test %d: Expected no error, but got error: %v", i, err.Error()) - } - - if !bytes.Equal(tc.expectedValue, value) { - t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValue, value) - } - } -} diff --git a/plugin/rewrite/README.md b/plugin/rewrite/README.md index 4e2e49a3a..680e69722 100644 --- a/plugin/rewrite/README.md +++ b/plugin/rewrite/README.md @@ -206,17 +206,13 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq } ~~~ -* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables by default: +* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables: {qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}. -Any plugin that can provide it's own additional variables by implementing metadata.Provider interface. If you are going to use metadata variables then metadata plugin must be enabled. Example: -~~~ corefile -. { - metadata - rewrite edns0 local set 0xffee {client_ip} -} +~~~ +rewrite edns0 local set 0xffee {client_ip} ~~~ ### EDNS0_NSID diff --git a/plugin/rewrite/class.go b/plugin/rewrite/class.go index b04dabce2..2e54f515c 100644 --- a/plugin/rewrite/class.go +++ b/plugin/rewrite/class.go @@ -1,7 +1,6 @@ package rewrite import ( - "context" "fmt" "strings" @@ -28,7 +27,7 @@ func newClassRule(nextAction string, args ...string) (Rule, error) { } // Rewrite rewrites the the current request. -func (rule *classRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if rule.fromClass > 0 && rule.toClass > 0 { if r.Question[0].Qclass == rule.fromClass { r.Question[0].Qclass = rule.toClass diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go index f59cee1e0..a651744d2 100644 --- a/plugin/rewrite/edns0.go +++ b/plugin/rewrite/edns0.go @@ -2,15 +2,13 @@ package rewrite import ( - "context" + "encoding/binary" "encoding/hex" "fmt" "net" "strconv" "strings" - "github.com/coredns/coredns/plugin/metadata" - "github.com/coredns/coredns/plugin/pkg/variables" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -49,7 +47,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT { } // Rewrite will alter the request EDNS0 NSID option -func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -86,7 +84,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule { } // Rewrite will alter the request EDNS0 local options -func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -149,9 +147,7 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) { } //Check for variable option if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { - // Remove first and last runes - variable := args[3][1 : len(args[3])-1] - return newEdns0VariableRule(mode, action, args[2], variable) + return newEdns0VariableRule(mode, action, args[2], args[3]) } return newEdns0LocalRule(mode, action, args[2], args[3]) case "nsid": @@ -191,29 +187,102 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu if err != nil { return nil, err } + //Validate + if !isValidVariable(variable) { + return nil, fmt.Errorf("unsupported variable name %q", variable) + } return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil } +// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. +func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) { + + switch family { + case 1: + return net.ParseIP(ipAddr).To4(), nil + case 2: + return net.ParseIP(ipAddr).To16(), nil + } + return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family) +} + +// uint16ToWire writes unit16 to wire/binary format +func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(data)) + return buf +} + +// portToWire writes port to wire/binary format, 2 bytes +func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) { + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + return rule.uint16ToWire(uint16(port)), nil +} + +// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. +func (rule *edns0VariableRule) family(ip net.Addr) int { + var a net.IP + if i, ok := ip.(*net.UDPAddr); ok { + a = i.IP + } + if i, ok := ip.(*net.TCPAddr); ok { + a = i.IP + } + if a.To4() != nil { + return 1 + } + return 2 +} + // ruleData returns the data specified by the variable -func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { - if md, ok := metadata.FromContext(ctx); ok { - if value, ok := md.Value(rule.variable); ok { - if v, ok := value.([]byte); ok { - return v, nil - } +func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { + + req := request.Request{W: w, Req: r} + switch rule.variable { + case queryName: + //Query name is written as ascii string + return []byte(req.QName()), nil + + case queryType: + return rule.uint16ToWire(req.QType()), nil + + case clientIP: + return rule.ipToWire(req.Family(), req.IP()) + + case clientPort: + return rule.portToWire(req.Port()) + + case protocol: + // Proto is written as ascii string + return []byte(req.Proto()), nil + + case serverIP: + ip, _, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + ip = w.RemoteAddr().String() } - } else { // No metadata available means metadata plugin is disabled. Try to get the value directly. - state := request.Request{W: w, Req: r} // TODO(miek): every rule needs to take a request.Request. - return variables.GetValue(state, rule.variable) + return rule.ipToWire(rule.family(w.RemoteAddr()), ip) + + case serverPort: + _, port, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + port = "0" + } + return rule.portToWire(port) } + return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable) } // Rewrite will alter the request EDNS0 local options with specified variables -func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored - data, err := rule.ruleData(ctx, w, r) + data, err := rule.ruleData(w, r) if err != nil || data == nil { return result } @@ -256,6 +325,21 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule { return ResponseRule{} } +func isValidVariable(variable string) bool { + switch variable { + case + queryName, + queryType, + clientIP, + clientPort, + protocol, + serverIP, + serverPort: + return true + } + return false +} + // ends0SubnetRule is a rewrite rule for EDNS0 subnet options type edns0SubnetRule struct { mode string @@ -316,7 +400,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, ecs * } // Rewrite will alter the request EDNS0 subnet option -func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -362,6 +446,17 @@ const ( Append = "append" ) +// Supported local EDNS0 variables +const ( + queryName = "{qname}" + queryType = "{qtype}" + clientIP = "{client_ip}" + clientPort = "{client_port}" + protocol = "{protocol}" + serverIP = "{server_ip}" + serverPort = "{server_port}" +) + // Subnet maximum bit mask length const ( maxV4BitMaskLen = 32 diff --git a/plugin/rewrite/name.go b/plugin/rewrite/name.go index e06c39078..eb2ac7285 100644 --- a/plugin/rewrite/name.go +++ b/plugin/rewrite/name.go @@ -1,7 +1,6 @@ package rewrite import ( - "context" "fmt" "regexp" "strconv" @@ -58,7 +57,7 @@ const ( // Rewrite rewrites the current request based upon exact match of the name // in the question section of the request -func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if rule.From == r.Question[0].Name { r.Question[0].Name = rule.To return RewriteDone @@ -67,7 +66,7 @@ func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns. } // Rewrite rewrites the current request when the name begins with the matching string -func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if strings.HasPrefix(r.Question[0].Name, rule.Prefix) { r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix) return RewriteDone @@ -76,7 +75,7 @@ func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r } // Rewrite rewrites the current request when the name ends with the matching string -func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if strings.HasSuffix(r.Question[0].Name, rule.Suffix) { r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement return RewriteDone @@ -86,7 +85,7 @@ func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r // Rewrite rewrites the current request based upon partial match of the // name in the question section of the request -func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if strings.Contains(r.Question[0].Name, rule.Substring) { r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1) return RewriteDone @@ -96,7 +95,7 @@ func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter // Rewrite rewrites the current request when the name in the question // section of the request matches a regular expression -func (rule *regexNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *regexNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name) if len(regexGroups) == 0 { return RewriteIgnored diff --git a/plugin/rewrite/rewrite.go b/plugin/rewrite/rewrite.go index 3ec58d32c..e340fa3ca 100644 --- a/plugin/rewrite/rewrite.go +++ b/plugin/rewrite/rewrite.go @@ -39,7 +39,7 @@ type Rewrite struct { func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { wr := NewResponseReverter(w, r) for _, rule := range rw.Rules { - switch result := rule.Rewrite(ctx, w, r); result { + switch result := rule.Rewrite(w, r); result { case RewriteDone: respRule := rule.GetResponseRule() if respRule.Active == true { @@ -68,7 +68,7 @@ func (rw Rewrite) Name() string { return "rewrite" } // Rule describes a rewrite rule. type Rule interface { // Rewrite rewrites the current request. - Rewrite(context.Context, dns.ResponseWriter, *dns.Msg) Result + Rewrite(dns.ResponseWriter, *dns.Msg) Result // Mode returns the processing mode stop or continue. Mode() string // GetResponseRule returns the rule to rewrite response with, if any. diff --git a/plugin/rewrite/rewrite_test.go b/plugin/rewrite/rewrite_test.go index b35543b9b..56c446f49 100644 --- a/plugin/rewrite/rewrite_test.go +++ b/plugin/rewrite/rewrite_test.go @@ -71,7 +71,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "foo"}, true, nil}, - {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil}, {[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, @@ -79,7 +79,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, - {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil}, {[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, @@ -87,7 +87,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, - {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil}, {[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, diff --git a/plugin/rewrite/type.go b/plugin/rewrite/type.go index c5c545485..ec36b0b0a 100644 --- a/plugin/rewrite/type.go +++ b/plugin/rewrite/type.go @@ -2,7 +2,6 @@ package rewrite import ( - "context" "fmt" "strings" @@ -29,7 +28,7 @@ func newTypeRule(nextAction string, args ...string) (Rule, error) { } // Rewrite rewrites the the current request. -func (rule *typeRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if rule.fromType > 0 && rule.toType > 0 { if r.Question[0].Qtype == rule.fromType { r.Question[0].Qtype = rule.toType |