diff options
-rw-r--r-- | plugin/dnstap/context_test.go | 2 | ||||
-rw-r--r-- | plugin/dnstap/gocontext.go | 23 | ||||
-rw-r--r-- | plugin/dnstap/handler.go | 9 |
3 files changed, 26 insertions, 8 deletions
diff --git a/plugin/dnstap/context_test.go b/plugin/dnstap/context_test.go index 04728e032..64418f59b 100644 --- a/plugin/dnstap/context_test.go +++ b/plugin/dnstap/context_test.go @@ -6,7 +6,7 @@ import ( ) func TestDnstapContext(t *testing.T) { - ctx := tapContext{context.TODO(), Dnstap{}} + ctx := ContextWithTapper(context.TODO(), Dnstap{}) tapper := TapperFromContext(ctx) if tapper == nil { diff --git a/plugin/dnstap/gocontext.go b/plugin/dnstap/gocontext.go new file mode 100644 index 000000000..a8cc2c2b4 --- /dev/null +++ b/plugin/dnstap/gocontext.go @@ -0,0 +1,23 @@ +package dnstap + +import "context" + +type contextKey struct{} + +var dnstapKey = contextKey{} + +// ContextWithTapper returns a new `context.Context` that holds a reference to +// `t`'s Tapper. +func ContextWithTapper(ctx context.Context, t Tapper) context.Context { + return context.WithValue(ctx, dnstapKey, t) +} + +// TapperFromContext returns the `Tapper` previously associated with `ctx`, or +// `nil` if no such `Tapper` could be found. +func TapperFromContext(ctx context.Context) Tapper { + val := ctx.Value(dnstapKey) + if sp, ok := val.(Tapper); ok { + return sp + } + return nil +} diff --git a/plugin/dnstap/handler.go b/plugin/dnstap/handler.go index b09c70406..1178dad79 100644 --- a/plugin/dnstap/handler.go +++ b/plugin/dnstap/handler.go @@ -44,12 +44,6 @@ const ( DnstapSendOption ContextKey = "dnstap-send-option" ) -// TapperFromContext will return a Tapper if the dnstap plugin is enabled. -func TapperFromContext(ctx context.Context) (t Tapper) { - t, _ = ctx.(Tapper) - return -} - // TapMessage implements Tapper. func (h Dnstap) TapMessage(m *tap.Message) { t := tap.Dnstap_MESSAGE @@ -71,6 +65,7 @@ func (h Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) // message to be sent out sendOption := taprw.SendOption{Cq: true, Cr: true} newCtx := context.WithValue(ctx, DnstapSendOption, &sendOption) + newCtx = ContextWithTapper(newCtx, h) rw := &taprw.ResponseWriter{ ResponseWriter: w, @@ -80,7 +75,7 @@ func (h Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) QueryEpoch: time.Now(), } - code, err := plugin.NextOrFailure(h.Name(), h.Next, tapContext{newCtx, h}, rw, r) + code, err := plugin.NextOrFailure(h.Name(), h.Next, newCtx, rw, r) if err != nil { // ignore dnstap errors return code, err |