aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/dnsserver/server_https.go4
-rw-r--r--plugin/trace/trace.go11
-rw-r--r--plugin/trace/trace_test.go40
3 files changed, 54 insertions, 1 deletions
diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go
index ba6097215..5c884e56b 100644
--- a/core/dnsserver/server_https.go
+++ b/core/dnsserver/server_https.go
@@ -27,6 +27,9 @@ type ServerHTTPS struct {
validRequest func(*http.Request) bool
}
+// HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH)
+type HTTPRequestKey struct{}
+
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
s, err := NewServer(addr, group)
@@ -153,6 +156,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// We should expect a packet to be returned that we can send to the client.
ctx := context.WithValue(context.Background(), Key{}, s.Server)
ctx = context.WithValue(ctx, LoopKey{}, 0)
+ ctx = context.WithValue(ctx, HTTPRequestKey{}, r)
s.ServeDNS(ctx, dw, msg)
// See section 4.2.1 of RFC 8484.
diff --git a/plugin/trace/trace.go b/plugin/trace/trace.go
index 87cb65e68..6bfd94dae 100644
--- a/plugin/trace/trace.go
+++ b/plugin/trace/trace.go
@@ -4,9 +4,11 @@ package trace
import (
"context"
"fmt"
+ "net/http"
"sync"
"sync/atomic"
+ "github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/dnstest"
@@ -140,8 +142,15 @@ func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
+ var spanCtx ot.SpanContext
+ if val := ctx.Value(dnsserver.HTTPRequestKey{}); val != nil {
+ if httpReq, ok := val.(*http.Request); ok {
+ spanCtx, _ = t.Tracer().Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(httpReq.Header))
+ }
+ }
+
req := request.Request{W: w, Req: r}
- span = t.Tracer().StartSpan(defaultTopLevelSpanName)
+ span = t.Tracer().StartSpan(defaultTopLevelSpanName, otext.RPCServerOption(spanCtx))
defer span.Finish()
switch spanCtx := span.Context().(type) {
diff --git a/plugin/trace/trace_test.go b/plugin/trace/trace_test.go
index dae546f8d..940eb6b02 100644
--- a/plugin/trace/trace_test.go
+++ b/plugin/trace/trace_test.go
@@ -3,9 +3,11 @@ package trace
import (
"context"
"errors"
+ "net/http/httptest"
"testing"
"github.com/coredns/caddy"
+ "github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/rcode"
@@ -13,6 +15,7 @@ import (
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
+ "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
)
@@ -131,3 +134,40 @@ func TestTrace(t *testing.T) {
})
}
}
+
+func TestTrace_DOH_TraceHeaderExtraction(t *testing.T) {
+ w := dnstest.NewRecorder(&test.ResponseWriter{})
+ m := mocktracer.New()
+ tr := &trace{
+ Next: test.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ if plugin.ClientWrite(dns.RcodeSuccess) {
+ m := new(dns.Msg)
+ m.SetRcode(r, dns.RcodeSuccess)
+ w.WriteMsg(m)
+ }
+ return dns.RcodeSuccess, nil
+ }),
+ every: 1,
+ tracer: m,
+ }
+ q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)
+
+ req := httptest.NewRequest("POST", "/dns-query", nil)
+
+ outsideSpan := m.StartSpan("test-header-span")
+ outsideSpan.Tracer().Inject(outsideSpan.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
+ defer outsideSpan.Finish()
+
+ ctx := context.TODO()
+ ctx = context.WithValue(ctx, dnsserver.HTTPRequestKey{}, req)
+
+ tr.ServeDNS(ctx, w, q)
+
+ fs := m.FinishedSpans()
+ rootCoreDNSspan := fs[1]
+ rootCoreDNSTraceID := rootCoreDNSspan.Context().(mocktracer.MockSpanContext).TraceID
+ outsideSpanTraceID := outsideSpan.Context().(mocktracer.MockSpanContext).TraceID
+ if rootCoreDNSTraceID != outsideSpanTraceID {
+ t.Errorf("Unexpected traceID: rootSpan.TraceID: want %v, got %v", rootCoreDNSTraceID, outsideSpanTraceID)
+ }
+}