aboutsummaryrefslogtreecommitdiff
path: root/plugin
diff options
context:
space:
mode:
Diffstat (limited to 'plugin')
-rw-r--r--plugin/metadata/metadata.go23
-rw-r--r--plugin/metadata/metadata_test.go52
-rw-r--r--plugin/metadata/provider.go102
-rw-r--r--plugin/metadata/provider_test.go48
-rw-r--r--plugin/metadata/setup.go10
-rw-r--r--plugin/pkg/variables/variables.go104
-rw-r--r--plugin/pkg/variables/variables_test.go83
-rw-r--r--plugin/rewrite/README.md10
-rw-r--r--plugin/rewrite/class.go3
-rw-r--r--plugin/rewrite/edns0.go135
-rw-r--r--plugin/rewrite/name.go11
-rw-r--r--plugin/rewrite/rewrite.go4
-rw-r--r--plugin/rewrite/rewrite_test.go6
-rw-r--r--plugin/rewrite/type.go3
14 files changed, 227 insertions, 367 deletions
diff --git a/plugin/metadata/metadata.go b/plugin/metadata/metadata.go
index e7560d403..4abe57ddf 100644
--- a/plugin/metadata/metadata.go
+++ b/plugin/metadata/metadata.go
@@ -4,7 +4,6 @@ import (
"context"
"github.com/coredns/coredns/plugin"
- "github.com/coredns/coredns/plugin/pkg/variables"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@@ -24,18 +23,13 @@ func (m *Metadata) Name() string { return "metadata" }
// ServeDNS implements the plugin.Handler interface.
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
- ctx = context.WithValue(ctx, metadataKey{}, M{})
- md, _ := FromContext(ctx)
+ ctx = context.WithValue(ctx, key{}, md{})
state := request.Request{W: w, Req: r}
if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
// Go through all Providers and collect metadata.
- for _, provider := range m.Providers {
- for _, varName := range provider.MetadataVarNames() {
- if val, ok := provider.Metadata(ctx, state, varName); ok {
- md.SetValue(varName, val)
- }
- }
+ for _, p := range m.Providers {
+ ctx = p.Metadata(ctx, state)
}
}
@@ -43,14 +37,3 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
return rcode, err
}
-
-// MetadataVarNames implements the plugin.Provider interface.
-func (m *Metadata) MetadataVarNames() []string { return variables.All }
-
-// Metadata implements the plugin.Provider interface.
-func (m *Metadata) Metadata(ctx context.Context, state request.Request, varName string) (interface{}, bool) {
- if val, err := variables.GetValue(state, varName); err == nil {
- return val, true
- }
- return nil, false
-}
diff --git a/plugin/metadata/metadata_test.go b/plugin/metadata/metadata_test.go
index 8bbff4c34..7ded05c03 100644
--- a/plugin/metadata/metadata_test.go
+++ b/plugin/metadata/metadata_test.go
@@ -10,26 +10,18 @@ import (
"github.com/miekg/dns"
)
-// testProvider implements fake Providers. Plugins which inmplement Provider interface
-type testProvider map[string]interface{}
+type testProvider map[string]Func
-func (m testProvider) MetadataVarNames() []string {
- keys := []string{}
- for k := range m {
- keys = append(keys, k)
+func (tp testProvider) Metadata(ctx context.Context, state request.Request) context.Context {
+ for k, v := range tp {
+ SetValueFunc(ctx, k, v)
}
- return keys
+ return ctx
}
-func (m testProvider) Metadata(ctx context.Context, state request.Request, key string) (val interface{}, ok bool) {
- value, ok := m[key]
- return value, ok
-}
-
-// testHandler implements plugin.Handler.
type testHandler struct{ ctx context.Context }
-func (m *testHandler) Name() string { return "testHandler" }
+func (m *testHandler) Name() string { return "test" }
func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m.ctx = ctx
@@ -38,8 +30,8 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
func TestMetadataServeDNS(t *testing.T) {
expectedMetadata := []testProvider{
- testProvider{"testkey1": "testvalue1"},
- testProvider{"testkey2": 2, "testkey3": "testvalue3"},
+ testProvider{"test/key1": func() string { return "testvalue1" }},
+ testProvider{"test/key2": func() string { return "two" }, "test/key3": func() string { return "testvalue3" }},
}
// Create fake Providers based on expectedMetadata
providers := []Provider{}
@@ -48,32 +40,22 @@ func TestMetadataServeDNS(t *testing.T) {
}
next := &testHandler{} // fake handler which stores the resulting context
- metadata := Metadata{
+ m := Metadata{
Zones: []string{"."},
Providers: providers,
Next: next,
}
- metadata.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg))
- // Verify that next plugin can find metadata in context from all Providers
+ ctx := context.TODO()
+ m.ServeDNS(ctx, &test.ResponseWriter{}, new(dns.Msg))
+ nctx := next.ctx
+
for _, expected := range expectedMetadata {
- md, ok := FromContext(next.ctx)
- if !ok {
- t.Fatalf("Metadata is expected but not present inside the context")
- }
- for expKey, expVal := range expected {
- metadataVal, valOk := md.Value(expKey)
- if !valOk {
- t.Fatalf("Value by key %v can't be retrieved", expKey)
- }
- if metadataVal != expVal {
- t.Errorf("Expected value %v, but got %v", expVal, metadataVal)
+ for label, expVal := range expected {
+ val := ValueFunc(nctx, label)
+ if val() != expVal() {
+ t.Errorf("Expected value %s for %s, but got %s", expVal(), label, val())
}
}
- wrongKey := "wrong_key"
- metadataVal, ok := md.Value(wrongKey)
- if ok {
- t.Fatalf("Value by key %v is not expected to be recieved, but got: %v", wrongKey, metadataVal)
- }
}
}
diff --git a/plugin/metadata/provider.go b/plugin/metadata/provider.go
index e13f9c896..eb7bb9755 100644
--- a/plugin/metadata/provider.go
+++ b/plugin/metadata/provider.go
@@ -1,3 +1,33 @@
+// Package metadata provides an API that allows plugins to add metadata to the context.
+// Each metadata is stored under a label that has the form <plugin>/<name>. Each metadata
+// is returned as a Func. When Func is called the metadata is returned. If Func is expensive to
+// execute it is its responsibility to provide some form of caching. During the handling of a
+// query it is expected the metadata stays constant.
+//
+// Basic example:
+//
+// Implement the Provder interface for a plugin:
+//
+// func (p P) Metadata(ctx context.Context, state request.Request) context.Context {
+// cached := ""
+// f := func() string {
+// if cached != "" {
+// return cached
+// }
+// cached = expensiveFunc()
+// return cached
+// }
+// metadata.SetValueFunc(ctx, "test/something", f)
+// return ctx
+// }
+//
+// Check the metadata from another plugin:
+//
+// // ...
+// valueFunc := metadata.ValueFunc(ctx, "test/something")
+// value := valueFunc()
+// // use 'value'
+//
package metadata
import (
@@ -8,40 +38,62 @@ import (
// Provider interface needs to be implemented by each plugin willing to provide
// metadata information for other plugins.
-// Note: this method should work quickly, because it is called for every request
-// from the metadata plugin.
type Provider interface {
- // List of variables which are provided by current Provider. Must remain constant.
- MetadataVarNames() []string
- // Metadata is expected to return a value with metadata information by the key
- // from 4th argument. Value can be later retrieved from context by any other plugin.
- // If value is not available by some reason returned boolean value should be false.
- Metadata(ctx context.Context, state request.Request, variable string) (interface{}, bool)
+ // Metadata adds metadata to the context and returns a (potentially) new context.
+ // Note: this method should work quickly, because it is called for every request
+ // from the metadata plugin.
+ Metadata(ctx context.Context, state request.Request) context.Context
}
-// M is metadata information storage.
-type M map[string]interface{}
+// Func is the type of function in the metadata, when called they return the value of the label.
+type Func func() string
-// FromContext retrieves the metadata from the context.
-func FromContext(ctx context.Context) (M, bool) {
- if metadata := ctx.Value(metadataKey{}); metadata != nil {
- if m, ok := metadata.(M); ok {
- return m, true
+// Labels returns all metadata keys stored in the context. These label names should be named
+// as: plugin/NAME, where NAME is something descriptive.
+func Labels(ctx context.Context) []string {
+ if metadata := ctx.Value(key{}); metadata != nil {
+ if m, ok := metadata.(md); ok {
+ return keys(m)
}
}
- return M{}, false
+ return nil
}
-// Value returns metadata value by key.
-func (m M) Value(key string) (value interface{}, ok bool) {
- value, ok = m[key]
- return value, ok
+// ValueFunc returns the value function of label. If none can be found nil is returned. Calling the
+// function returns the value of the label.
+func ValueFunc(ctx context.Context, label string) Func {
+ if metadata := ctx.Value(key{}); metadata != nil {
+ if m, ok := metadata.(md); ok {
+ return m[label]
+ }
+ }
+ return nil
}
-// SetValue sets the metadata value under key.
-func (m M) SetValue(key string, val interface{}) {
- m[key] = val
+// SetValueFunc set the metadata label to the value function. If no metadata can be found this is a noop and
+// false is returned. Any existing value is overwritten.
+func SetValueFunc(ctx context.Context, label string, f Func) bool {
+ if metadata := ctx.Value(key{}); metadata != nil {
+ if m, ok := metadata.(md); ok {
+ m[label] = f
+ return true
+ }
+ }
+ return false
}
-// metadataKey defines the type of key that is used to save metadata into the context.
-type metadataKey struct{}
+// md is metadata information storage.
+type md map[string]Func
+
+// key defines the type of key that is used to save metadata into the context.
+type key struct{}
+
+func keys(m map[string]Func) []string {
+ s := make([]string, len(m))
+ i := 0
+ for k := range m {
+ s[i] = k
+ i++
+ }
+ return s
+}
diff --git a/plugin/metadata/provider_test.go b/plugin/metadata/provider_test.go
deleted file mode 100644
index 1a074aeaa..000000000
--- a/plugin/metadata/provider_test.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package metadata
-
-import (
- "context"
- "reflect"
- "testing"
-)
-
-func TestMD(t *testing.T) {
- tests := []struct {
- addValues map[string]interface{}
- expectedValues map[string]interface{}
- }{
- {
- // Add initial metadata key/vals
- map[string]interface{}{"key1": "val1", "key2": 2},
- map[string]interface{}{"key1": "val1", "key2": 2},
- },
- {
- // Add additional key/vals.
- map[string]interface{}{"key3": 3, "key4": 4.5},
- map[string]interface{}{"key1": "val1", "key2": 2, "key3": 3, "key4": 4.5},
- },
- }
-
- // Using one same md and ctx for all test cases
- ctx := context.TODO()
- ctx = context.WithValue(ctx, metadataKey{}, M{})
- m, _ := FromContext(ctx)
-
- for i, tc := range tests {
- for k, v := range tc.addValues {
- m.SetValue(k, v)
- }
- if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(m)) {
- t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, m)
- }
-
- // Make sure that md is recieved from context successfullly
- mFromContext, ok := FromContext(ctx)
- if !ok {
- t.Errorf("Test %d: md is not recieved from the context", i)
- }
- if !reflect.DeepEqual(m, mFromContext) {
- t.Errorf("Test %d: md recieved from context differs from initial. Initial: %v, from context: %v", i, m, mFromContext)
- }
- }
-}
diff --git a/plugin/metadata/setup.go b/plugin/metadata/setup.go
index 33a153a2c..282bcf7d9 100644
--- a/plugin/metadata/setup.go
+++ b/plugin/metadata/setup.go
@@ -1,8 +1,6 @@
package metadata
import (
- "fmt"
-
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
@@ -28,16 +26,8 @@ func setup(c *caddy.Controller) error {
c.OnStartup(func() error {
plugins := dnsserver.GetConfig(c).Handlers()
- // Collect all plugins which implement Provider interface
- metadataVariables := map[string]bool{}
for _, p := range plugins {
if met, ok := p.(Provider); ok {
- for _, varName := range met.MetadataVarNames() {
- if _, ok := metadataVariables[varName]; ok {
- return fmt.Errorf("Metadata variable '%v' has duplicates", varName)
- }
- metadataVariables[varName] = true
- }
m.Providers = append(m.Providers, met)
}
}
diff --git a/plugin/pkg/variables/variables.go b/plugin/pkg/variables/variables.go
deleted file mode 100644
index 8e1cdbe77..000000000
--- a/plugin/pkg/variables/variables.go
+++ /dev/null
@@ -1,104 +0,0 @@
-package variables
-
-import (
- "encoding/binary"
- "fmt"
- "net"
- "strconv"
-
- "github.com/coredns/coredns/request"
-)
-
-const (
- queryName = "qname"
- queryType = "qtype"
- clientIP = "client_ip"
- clientPort = "client_port"
- protocol = "protocol"
- serverIP = "server_ip"
- serverPort = "server_port"
-)
-
-// All is a list of available variables provided by GetMetadataValue
-var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort}
-
-// GetValue calculates and returns the data specified by the variable name.
-// Supported varNames are listed in allProvidedVars.
-func GetValue(state request.Request, varName string) ([]byte, error) {
- switch varName {
- case queryName:
- return []byte(state.QName()), nil
-
- case queryType:
- return uint16ToWire(state.QType()), nil
-
- case clientIP:
- return ipToWire(state.Family(), state.IP())
-
- case clientPort:
- return portToWire(state.Port())
-
- case protocol:
- return []byte(state.Proto()), nil
-
- case serverIP:
- ip, _, err := net.SplitHostPort(state.W.LocalAddr().String())
- if err != nil {
- ip = state.W.RemoteAddr().String()
- }
- return ipToWire(state.Family(), ip)
-
- case serverPort:
- _, port, err := net.SplitHostPort(state.W.LocalAddr().String())
- if err != nil {
- port = "0"
- }
- return portToWire(port)
- }
-
- return nil, fmt.Errorf("unable to extract data for variable %s", varName)
-}
-
-// uint16ToWire writes unit16 to wire/binary format
-func uint16ToWire(data uint16) []byte {
- buf := make([]byte, 2)
- binary.BigEndian.PutUint16(buf, uint16(data))
- return buf
-}
-
-// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
-func ipToWire(family int, ipAddr string) ([]byte, error) {
-
- switch family {
- case 1:
- return net.ParseIP(ipAddr).To4(), nil
- case 2:
- return net.ParseIP(ipAddr).To16(), nil
- }
- return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
-}
-
-// portToWire writes port to wire/binary format, 2 bytes
-func portToWire(portStr string) ([]byte, error) {
-
- port, err := strconv.ParseUint(portStr, 10, 16)
- if err != nil {
- return nil, err
- }
- return uint16ToWire(uint16(port)), nil
-}
-
-// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
-func family(ip net.Addr) int {
- var a net.IP
- if i, ok := ip.(*net.UDPAddr); ok {
- a = i.IP
- }
- if i, ok := ip.(*net.TCPAddr); ok {
- a = i.IP
- }
- if a.To4() != nil {
- return 1
- }
- return 2
-}
diff --git a/plugin/pkg/variables/variables_test.go b/plugin/pkg/variables/variables_test.go
deleted file mode 100644
index e0ff64c19..000000000
--- a/plugin/pkg/variables/variables_test.go
+++ /dev/null
@@ -1,83 +0,0 @@
-package variables
-
-import (
- "bytes"
- "testing"
-
- "github.com/coredns/coredns/plugin/test"
- "github.com/coredns/coredns/request"
-
- "github.com/miekg/dns"
-)
-
-func TestGetValue(t *testing.T) {
- // test.ResponseWriter has the following values:
- // The remote will always be 10.240.0.1 and port 40212.
- // The local address is always 127.0.0.1 and port 53.
- tests := []struct {
- varName string
- expectedValue []byte
- shouldErr bool
- }{
- {
- queryName,
- []byte("example.com."),
- false,
- },
- {
- queryType,
- []byte{0x00, 0x01},
- false,
- },
- {
- clientIP,
- []byte{10, 240, 0, 1},
- false,
- },
- {
- clientPort,
- []byte{0x9D, 0x14},
- false,
- },
- {
- protocol,
- []byte("udp"),
- false,
- },
- {
- serverIP,
- []byte{127, 0, 0, 1},
- false,
- },
- {
- serverPort,
- []byte{0, 53},
- false,
- },
- {
- "wrong_var",
- []byte{},
- true,
- },
- }
-
- for i, tc := range tests {
- m := new(dns.Msg)
- m.SetQuestion("example.com.", dns.TypeA)
- m.Question[0].Qclass = dns.ClassINET
- state := request.Request{W: &test.ResponseWriter{}, Req: m}
-
- value, err := GetValue(state, tc.varName)
-
- if tc.shouldErr && err == nil {
- t.Errorf("Test %d: Expected error, but didn't recieve", i)
- }
- if !tc.shouldErr && err != nil {
- t.Errorf("Test %d: Expected no error, but got error: %v", i, err.Error())
- }
-
- if !bytes.Equal(tc.expectedValue, value) {
- t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValue, value)
- }
- }
-}
diff --git a/plugin/rewrite/README.md b/plugin/rewrite/README.md
index 4e2e49a3a..680e69722 100644
--- a/plugin/rewrite/README.md
+++ b/plugin/rewrite/README.md
@@ -206,17 +206,13 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq
}
~~~
-* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables by default:
+* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables:
{qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}.
-Any plugin that can provide it's own additional variables by implementing metadata.Provider interface. If you are going to use metadata variables then metadata plugin must be enabled.
Example:
-~~~ corefile
-. {
- metadata
- rewrite edns0 local set 0xffee {client_ip}
-}
+~~~
+rewrite edns0 local set 0xffee {client_ip}
~~~
### EDNS0_NSID
diff --git a/plugin/rewrite/class.go b/plugin/rewrite/class.go
index b04dabce2..2e54f515c 100644
--- a/plugin/rewrite/class.go
+++ b/plugin/rewrite/class.go
@@ -1,7 +1,6 @@
package rewrite
import (
- "context"
"fmt"
"strings"
@@ -28,7 +27,7 @@ func newClassRule(nextAction string, args ...string) (Rule, error) {
}
// Rewrite rewrites the the current request.
-func (rule *classRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromClass > 0 && rule.toClass > 0 {
if r.Question[0].Qclass == rule.fromClass {
r.Question[0].Qclass = rule.toClass
diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go
index f59cee1e0..a651744d2 100644
--- a/plugin/rewrite/edns0.go
+++ b/plugin/rewrite/edns0.go
@@ -2,15 +2,13 @@
package rewrite
import (
- "context"
+ "encoding/binary"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
- "github.com/coredns/coredns/plugin/metadata"
- "github.com/coredns/coredns/plugin/pkg/variables"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@@ -49,7 +47,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT {
}
// Rewrite will alter the request EDNS0 NSID option
-func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@@ -86,7 +84,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule {
}
// Rewrite will alter the request EDNS0 local options
-func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@@ -149,9 +147,7 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) {
}
//Check for variable option
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
- // Remove first and last runes
- variable := args[3][1 : len(args[3])-1]
- return newEdns0VariableRule(mode, action, args[2], variable)
+ return newEdns0VariableRule(mode, action, args[2], args[3])
}
return newEdns0LocalRule(mode, action, args[2], args[3])
case "nsid":
@@ -191,29 +187,102 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu
if err != nil {
return nil, err
}
+ //Validate
+ if !isValidVariable(variable) {
+ return nil, fmt.Errorf("unsupported variable name %q", variable)
+ }
return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil
}
+// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
+func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) {
+
+ switch family {
+ case 1:
+ return net.ParseIP(ipAddr).To4(), nil
+ case 2:
+ return net.ParseIP(ipAddr).To16(), nil
+ }
+ return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
+}
+
+// uint16ToWire writes unit16 to wire/binary format
+func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte {
+ buf := make([]byte, 2)
+ binary.BigEndian.PutUint16(buf, uint16(data))
+ return buf
+}
+
+// portToWire writes port to wire/binary format, 2 bytes
+func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) {
+
+ port, err := strconv.ParseUint(portStr, 10, 16)
+ if err != nil {
+ return nil, err
+ }
+ return rule.uint16ToWire(uint16(port)), nil
+}
+
+// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
+func (rule *edns0VariableRule) family(ip net.Addr) int {
+ var a net.IP
+ if i, ok := ip.(*net.UDPAddr); ok {
+ a = i.IP
+ }
+ if i, ok := ip.(*net.TCPAddr); ok {
+ a = i.IP
+ }
+ if a.To4() != nil {
+ return 1
+ }
+ return 2
+}
+
// ruleData returns the data specified by the variable
-func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
- if md, ok := metadata.FromContext(ctx); ok {
- if value, ok := md.Value(rule.variable); ok {
- if v, ok := value.([]byte); ok {
- return v, nil
- }
+func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
+
+ req := request.Request{W: w, Req: r}
+ switch rule.variable {
+ case queryName:
+ //Query name is written as ascii string
+ return []byte(req.QName()), nil
+
+ case queryType:
+ return rule.uint16ToWire(req.QType()), nil
+
+ case clientIP:
+ return rule.ipToWire(req.Family(), req.IP())
+
+ case clientPort:
+ return rule.portToWire(req.Port())
+
+ case protocol:
+ // Proto is written as ascii string
+ return []byte(req.Proto()), nil
+
+ case serverIP:
+ ip, _, err := net.SplitHostPort(w.LocalAddr().String())
+ if err != nil {
+ ip = w.RemoteAddr().String()
}
- } else { // No metadata available means metadata plugin is disabled. Try to get the value directly.
- state := request.Request{W: w, Req: r} // TODO(miek): every rule needs to take a request.Request.
- return variables.GetValue(state, rule.variable)
+ return rule.ipToWire(rule.family(w.RemoteAddr()), ip)
+
+ case serverPort:
+ _, port, err := net.SplitHostPort(w.LocalAddr().String())
+ if err != nil {
+ port = "0"
+ }
+ return rule.portToWire(port)
}
+
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
}
// Rewrite will alter the request EDNS0 local options with specified variables
-func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
- data, err := rule.ruleData(ctx, w, r)
+ data, err := rule.ruleData(w, r)
if err != nil || data == nil {
return result
}
@@ -256,6 +325,21 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule {
return ResponseRule{}
}
+func isValidVariable(variable string) bool {
+ switch variable {
+ case
+ queryName,
+ queryType,
+ clientIP,
+ clientPort,
+ protocol,
+ serverIP,
+ serverPort:
+ return true
+ }
+ return false
+}
+
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options
type edns0SubnetRule struct {
mode string
@@ -316,7 +400,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, ecs *
}
// Rewrite will alter the request EDNS0 subnet option
-func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@@ -362,6 +446,17 @@ const (
Append = "append"
)
+// Supported local EDNS0 variables
+const (
+ queryName = "{qname}"
+ queryType = "{qtype}"
+ clientIP = "{client_ip}"
+ clientPort = "{client_port}"
+ protocol = "{protocol}"
+ serverIP = "{server_ip}"
+ serverPort = "{server_port}"
+)
+
// Subnet maximum bit mask length
const (
maxV4BitMaskLen = 32
diff --git a/plugin/rewrite/name.go b/plugin/rewrite/name.go
index e06c39078..eb2ac7285 100644
--- a/plugin/rewrite/name.go
+++ b/plugin/rewrite/name.go
@@ -1,7 +1,6 @@
package rewrite
import (
- "context"
"fmt"
"regexp"
"strconv"
@@ -58,7 +57,7 @@ const (
// Rewrite rewrites the current request based upon exact match of the name
// in the question section of the request
-func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.From == r.Question[0].Name {
r.Question[0].Name = rule.To
return RewriteDone
@@ -67,7 +66,7 @@ func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.
}
// Rewrite rewrites the current request when the name begins with the matching string
-func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if strings.HasPrefix(r.Question[0].Name, rule.Prefix) {
r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix)
return RewriteDone
@@ -76,7 +75,7 @@ func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r
}
// Rewrite rewrites the current request when the name ends with the matching string
-func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if strings.HasSuffix(r.Question[0].Name, rule.Suffix) {
r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement
return RewriteDone
@@ -86,7 +85,7 @@ func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r
// Rewrite rewrites the current request based upon partial match of the
// name in the question section of the request
-func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if strings.Contains(r.Question[0].Name, rule.Substring) {
r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1)
return RewriteDone
@@ -96,7 +95,7 @@ func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter
// Rewrite rewrites the current request when the name in the question
// section of the request matches a regular expression
-func (rule *regexNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *regexNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name)
if len(regexGroups) == 0 {
return RewriteIgnored
diff --git a/plugin/rewrite/rewrite.go b/plugin/rewrite/rewrite.go
index 3ec58d32c..e340fa3ca 100644
--- a/plugin/rewrite/rewrite.go
+++ b/plugin/rewrite/rewrite.go
@@ -39,7 +39,7 @@ type Rewrite struct {
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules {
- switch result := rule.Rewrite(ctx, w, r); result {
+ switch result := rule.Rewrite(w, r); result {
case RewriteDone:
respRule := rule.GetResponseRule()
if respRule.Active == true {
@@ -68,7 +68,7 @@ func (rw Rewrite) Name() string { return "rewrite" }
// Rule describes a rewrite rule.
type Rule interface {
// Rewrite rewrites the current request.
- Rewrite(context.Context, dns.ResponseWriter, *dns.Msg) Result
+ Rewrite(dns.ResponseWriter, *dns.Msg) Result
// Mode returns the processing mode stop or continue.
Mode() string
// GetResponseRule returns the rule to rewrite response with, if any.
diff --git a/plugin/rewrite/rewrite_test.go b/plugin/rewrite/rewrite_test.go
index b35543b9b..56c446f49 100644
--- a/plugin/rewrite/rewrite_test.go
+++ b/plugin/rewrite/rewrite_test.go
@@ -71,7 +71,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "foo"}, true, nil},
- {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
+ {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
@@ -79,7 +79,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
- {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
+ {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
@@ -87,7 +87,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
- {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
+ {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
diff --git a/plugin/rewrite/type.go b/plugin/rewrite/type.go
index c5c545485..ec36b0b0a 100644
--- a/plugin/rewrite/type.go
+++ b/plugin/rewrite/type.go
@@ -2,7 +2,6 @@
package rewrite
import (
- "context"
"fmt"
"strings"
@@ -29,7 +28,7 @@ func newTypeRule(nextAction string, args ...string) (Rule, error) {
}
// Rewrite rewrites the the current request.
-func (rule *typeRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
+func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromType > 0 && rule.toType > 0 {
if r.Question[0].Qtype == rule.fromType {
r.Question[0].Qtype = rule.toType