diff options
Diffstat (limited to 'middleware/proxy/grpc.go')
-rw-r--r-- | middleware/proxy/grpc.go | 28 |
1 files changed, 24 insertions, 4 deletions
diff --git a/middleware/proxy/grpc.go b/middleware/proxy/grpc.go index aaf908d2a..c480d3cf2 100644 --- a/middleware/proxy/grpc.go +++ b/middleware/proxy/grpc.go @@ -5,16 +5,22 @@ import ( "crypto/tls" "log" + "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" + "github.com/coredns/coredns/middleware/proxy/pb" + "github.com/coredns/coredns/middleware/trace" "github.com/coredns/coredns/request" "github.com/miekg/dns" + + opentracing "github.com/opentracing/opentracing-go" + "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) type grpcClient struct { - dialOpt grpc.DialOption + dialOpts []grpc.DialOption clients map[string]pb.DnsServiceClient conns []*grpc.ClientConn upstream *staticUpstream @@ -24,9 +30,9 @@ func newGrpcClient(tls *tls.Config, u *staticUpstream) *grpcClient { g := &grpcClient{upstream: u} if tls == nil { - g.dialOpt = grpc.WithInsecure() + g.dialOpts = append(g.dialOpts, grpc.WithInsecure()) } else { - g.dialOpt = grpc.WithTransportCredentials(credentials.NewTLS(tls)) + g.dialOpts = append(g.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tls))) } g.clients = map[string]pb.DnsServiceClient{} @@ -54,18 +60,32 @@ func (g *grpcClient) Exchange(ctx context.Context, addr string, state request.Re func (g *grpcClient) Protocol() string { return "grpc" } func (g *grpcClient) OnShutdown(p *Proxy) error { + g.clients = map[string]pb.DnsServiceClient{} for i, conn := range g.conns { err := conn.Close() if err != nil { log.Printf("[WARNING] Error closing connection %d: %s\n", i, err) } } + g.conns = []*grpc.ClientConn{} return nil } func (g *grpcClient) OnStartup(p *Proxy) error { + dialOpts := g.dialOpts + if p.Trace != nil { + if t, ok := p.Trace.(trace.Trace); ok { + onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp interface{}) bool { + return parentSpanCtx != nil + } + intercept := otgrpc.OpenTracingClientInterceptor(t.Tracer(), otgrpc.IncludingSpans(onlyIfParent)) + dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(intercept)) + } else { + log.Printf("[WARNING] Wrong type for trace middleware reference: %s", p.Trace) + } + } for _, host := range g.upstream.Hosts { - conn, err := grpc.Dial(host.Name, g.dialOpt) + conn, err := grpc.Dial(host.Name, dialOpts...) if err != nil { log.Printf("[WARNING] Skipping gRPC host '%s' due to Dial error: %s\n", host.Name, err) } else { |