aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--plugin/metadata/metadata.go13
-rw-r--r--plugin/metadata/metadata_test.go10
-rw-r--r--plugin/metadata/provider.go (renamed from plugin/metadata/metadataer.go)36
-rw-r--r--plugin/metadata/provider_test.go (renamed from plugin/metadata/metadataer_test.go)19
-rw-r--r--plugin/pkg/variables/variables.go25
-rw-r--r--plugin/pkg/variables/variables_test.go5
-rw-r--r--plugin/rewrite/edns0.go3
7 files changed, 53 insertions, 58 deletions
diff --git a/plugin/metadata/metadata.go b/plugin/metadata/metadata.go
index 1e840d3fd..e7560d403 100644
--- a/plugin/metadata/metadata.go
+++ b/plugin/metadata/metadata.go
@@ -24,15 +24,16 @@ func (m *Metadata) Name() string { return "metadata" }
// ServeDNS implements the plugin.Handler interface.
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
- md, ctx := newMD(ctx)
+ ctx = context.WithValue(ctx, metadataKey{}, M{})
+ md, _ := FromContext(ctx)
state := request.Request{W: w, Req: r}
if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
- // Go through all Providers and collect metadata
+ // Go through all Providers and collect metadata.
for _, provider := range m.Providers {
for _, varName := range provider.MetadataVarNames() {
- if val, ok := provider.Metadata(ctx, w, r, varName); ok {
- md.setValue(varName, val)
+ if val, ok := provider.Metadata(ctx, state, varName); ok {
+ md.SetValue(varName, val)
}
}
}
@@ -47,8 +48,8 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
func (m *Metadata) MetadataVarNames() []string { return variables.All }
// Metadata implements the plugin.Provider interface.
-func (m *Metadata) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, varName string) (interface{}, bool) {
- if val, err := variables.GetValue(varName, w, r); err == nil {
+func (m *Metadata) Metadata(ctx context.Context, state request.Request, varName string) (interface{}, bool) {
+ if val, err := variables.GetValue(state, varName); err == nil {
return val, true
}
return nil, false
diff --git a/plugin/metadata/metadata_test.go b/plugin/metadata/metadata_test.go
index 413ba874e..8bbff4c34 100644
--- a/plugin/metadata/metadata_test.go
+++ b/plugin/metadata/metadata_test.go
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/coredns/coredns/plugin/test"
+ "github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
@@ -20,12 +21,12 @@ func (m testProvider) MetadataVarNames() []string {
return keys
}
-func (m testProvider) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, key string) (val interface{}, ok bool) {
+func (m testProvider) Metadata(ctx context.Context, state request.Request, key string) (val interface{}, ok bool) {
value, ok := m[key]
return value, ok
}
-// testHandler implements plugin.Handler
+// testHandler implements plugin.Handler.
type testHandler struct{ ctx context.Context }
func (m *testHandler) Name() string { return "testHandler" }
@@ -35,7 +36,7 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
return 0, nil
}
-func TestMetadataServDns(t *testing.T) {
+func TestMetadataServeDNS(t *testing.T) {
expectedMetadata := []testProvider{
testProvider{"testkey1": "testvalue1"},
testProvider{"testkey2": 2, "testkey3": "testvalue3"},
@@ -45,9 +46,8 @@ func TestMetadataServDns(t *testing.T) {
for _, e := range expectedMetadata {
providers = append(providers, e)
}
- // Fake handler which stores the resulting context
- next := &testHandler{}
+ next := &testHandler{} // fake handler which stores the resulting context
metadata := Metadata{
Zones: []string{"."},
Providers: providers,
diff --git a/plugin/metadata/metadataer.go b/plugin/metadata/provider.go
index bff12e92d..e13f9c896 100644
--- a/plugin/metadata/metadataer.go
+++ b/plugin/metadata/provider.go
@@ -3,7 +3,7 @@ package metadata
import (
"context"
- "github.com/miekg/dns"
+ "github.com/coredns/coredns/request"
)
// Provider interface needs to be implemented by each plugin willing to provide
@@ -16,38 +16,32 @@ type Provider interface {
// Metadata is expected to return a value with metadata information by the key
// from 4th argument. Value can be later retrieved from context by any other plugin.
// If value is not available by some reason returned boolean value should be false.
- Metadata(context.Context, dns.ResponseWriter, *dns.Msg, string) (interface{}, bool)
+ Metadata(ctx context.Context, state request.Request, variable string) (interface{}, bool)
}
-// MD is metadata information storage
-type MD map[string]interface{}
+// M is metadata information storage.
+type M map[string]interface{}
-// metadataKey defines the type of key that is used to save metadata into the context
-type metadataKey struct{}
-
-// newMD initializes MD and attaches it to context
-func newMD(ctx context.Context) (MD, context.Context) {
- m := MD{}
- return m, context.WithValue(ctx, metadataKey{}, m)
-}
-
-// FromContext retrieves MD struct from context.
-func FromContext(ctx context.Context) (md MD, ok bool) {
+// FromContext retrieves the metadata from the context.
+func FromContext(ctx context.Context) (M, bool) {
if metadata := ctx.Value(metadataKey{}); metadata != nil {
- if md, ok := metadata.(MD); ok {
- return md, true
+ if m, ok := metadata.(M); ok {
+ return m, true
}
}
- return MD{}, false
+ return M{}, false
}
// Value returns metadata value by key.
-func (m MD) Value(key string) (value interface{}, ok bool) {
+func (m M) Value(key string) (value interface{}, ok bool) {
value, ok = m[key]
return value, ok
}
-// setValue adds metadata value.
-func (m MD) setValue(key string, val interface{}) {
+// SetValue sets the metadata value under key.
+func (m M) SetValue(key string, val interface{}) {
m[key] = val
}
+
+// metadataKey defines the type of key that is used to save metadata into the context.
+type metadataKey struct{}
diff --git a/plugin/metadata/metadataer_test.go b/plugin/metadata/provider_test.go
index 53096feb8..1a074aeaa 100644
--- a/plugin/metadata/metadataer_test.go
+++ b/plugin/metadata/provider_test.go
@@ -25,23 +25,24 @@ func TestMD(t *testing.T) {
// Using one same md and ctx for all test cases
ctx := context.TODO()
- md, ctx := newMD(ctx)
+ ctx = context.WithValue(ctx, metadataKey{}, M{})
+ m, _ := FromContext(ctx)
for i, tc := range tests {
for k, v := range tc.addValues {
- md.setValue(k, v)
+ m.SetValue(k, v)
}
- if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(md)) {
- t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, md)
+ if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(m)) {
+ t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, m)
}
- // Make sure that MD is recieved from context successfullly
- mdFromContext, ok := FromContext(ctx)
+ // Make sure that md is recieved from context successfullly
+ mFromContext, ok := FromContext(ctx)
if !ok {
- t.Errorf("Test %d: MD is not recieved from the context", i)
+ t.Errorf("Test %d: md is not recieved from the context", i)
}
- if !reflect.DeepEqual(md, mdFromContext) {
- t.Errorf("Test %d: MD recieved from context differs from initial. Initial: %v, from context: %v", i, md, mdFromContext)
+ if !reflect.DeepEqual(m, mFromContext) {
+ t.Errorf("Test %d: md recieved from context differs from initial. Initial: %v, from context: %v", i, m, mFromContext)
}
}
}
diff --git a/plugin/pkg/variables/variables.go b/plugin/pkg/variables/variables.go
index da1dccbee..8e1cdbe77 100644
--- a/plugin/pkg/variables/variables.go
+++ b/plugin/pkg/variables/variables.go
@@ -7,8 +7,6 @@ import (
"strconv"
"github.com/coredns/coredns/request"
-
- "github.com/miekg/dns"
)
const (
@@ -26,35 +24,32 @@ var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverI
// GetValue calculates and returns the data specified by the variable name.
// Supported varNames are listed in allProvidedVars.
-func GetValue(varName string, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
- req := request.Request{W: w, Req: r}
+func GetValue(state request.Request, varName string) ([]byte, error) {
switch varName {
case queryName:
- //Query name is written as ascii string
- return []byte(req.QName()), nil
+ return []byte(state.QName()), nil
case queryType:
- return uint16ToWire(req.QType()), nil
+ return uint16ToWire(state.QType()), nil
case clientIP:
- return ipToWire(req.Family(), req.IP())
+ return ipToWire(state.Family(), state.IP())
case clientPort:
- return portToWire(req.Port())
+ return portToWire(state.Port())
case protocol:
- // Proto is written as ascii string
- return []byte(req.Proto()), nil
+ return []byte(state.Proto()), nil
case serverIP:
- ip, _, err := net.SplitHostPort(w.LocalAddr().String())
+ ip, _, err := net.SplitHostPort(state.W.LocalAddr().String())
if err != nil {
- ip = w.RemoteAddr().String()
+ ip = state.W.RemoteAddr().String()
}
- return ipToWire(family(w.RemoteAddr()), ip)
+ return ipToWire(state.Family(), ip)
case serverPort:
- _, port, err := net.SplitHostPort(w.LocalAddr().String())
+ _, port, err := net.SplitHostPort(state.W.LocalAddr().String())
if err != nil {
port = "0"
}
diff --git a/plugin/pkg/variables/variables_test.go b/plugin/pkg/variables/variables_test.go
index 939add323..e0ff64c19 100644
--- a/plugin/pkg/variables/variables_test.go
+++ b/plugin/pkg/variables/variables_test.go
@@ -5,6 +5,8 @@ import (
"testing"
"github.com/coredns/coredns/plugin/test"
+ "github.com/coredns/coredns/request"
+
"github.com/miekg/dns"
)
@@ -63,8 +65,9 @@ func TestGetValue(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET
+ state := request.Request{W: &test.ResponseWriter{}, Req: m}
- value, err := GetValue(tc.varName, &test.ResponseWriter{}, m)
+ value, err := GetValue(state, tc.varName)
if tc.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but didn't recieve", i)
diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go
index f8b65d468..2391936c7 100644
--- a/plugin/rewrite/edns0.go
+++ b/plugin/rewrite/edns0.go
@@ -202,7 +202,8 @@ func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWrite
}
}
} else { // No metadata available means metadata plugin is disabled. Try to get the value directly.
- return variables.GetValue(rule.variable, w, r)
+ state := request.Request{W: w, Req: r} // TODO(miek): every rule needs to take a request.Request.
+ return variables.GetValue(state, rule.variable)
}
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
}