diff options
Diffstat (limited to 'middleware/proxy')
-rw-r--r-- | middleware/proxy/grpc.go | 28 | ||||
-rw-r--r-- | middleware/proxy/grpc_test.go | 54 | ||||
-rw-r--r-- | middleware/proxy/proxy.go | 4 | ||||
-rw-r--r-- | middleware/proxy/setup.go | 3 |
4 files changed, 84 insertions, 5 deletions
diff --git a/middleware/proxy/grpc.go b/middleware/proxy/grpc.go index aaf908d2a..c480d3cf2 100644 --- a/middleware/proxy/grpc.go +++ b/middleware/proxy/grpc.go @@ -5,16 +5,22 @@ import ( "crypto/tls" "log" + "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" + "github.com/coredns/coredns/middleware/proxy/pb" + "github.com/coredns/coredns/middleware/trace" "github.com/coredns/coredns/request" "github.com/miekg/dns" + + opentracing "github.com/opentracing/opentracing-go" + "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) type grpcClient struct { - dialOpt grpc.DialOption + dialOpts []grpc.DialOption clients map[string]pb.DnsServiceClient conns []*grpc.ClientConn upstream *staticUpstream @@ -24,9 +30,9 @@ func newGrpcClient(tls *tls.Config, u *staticUpstream) *grpcClient { g := &grpcClient{upstream: u} if tls == nil { - g.dialOpt = grpc.WithInsecure() + g.dialOpts = append(g.dialOpts, grpc.WithInsecure()) } else { - g.dialOpt = grpc.WithTransportCredentials(credentials.NewTLS(tls)) + g.dialOpts = append(g.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tls))) } g.clients = map[string]pb.DnsServiceClient{} @@ -54,18 +60,32 @@ func (g *grpcClient) Exchange(ctx context.Context, addr string, state request.Re func (g *grpcClient) Protocol() string { return "grpc" } func (g *grpcClient) OnShutdown(p *Proxy) error { + g.clients = map[string]pb.DnsServiceClient{} for i, conn := range g.conns { err := conn.Close() if err != nil { log.Printf("[WARNING] Error closing connection %d: %s\n", i, err) } } + g.conns = []*grpc.ClientConn{} return nil } func (g *grpcClient) OnStartup(p *Proxy) error { + dialOpts := g.dialOpts + if p.Trace != nil { + if t, ok := p.Trace.(trace.Trace); ok { + onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp interface{}) bool { + return parentSpanCtx != nil + } + intercept := otgrpc.OpenTracingClientInterceptor(t.Tracer(), otgrpc.IncludingSpans(onlyIfParent)) + dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(intercept)) + } else { + log.Printf("[WARNING] Wrong type for trace middleware reference: %s", p.Trace) + } + } for _, host := range g.upstream.Hosts { - conn, err := grpc.Dial(host.Name, g.dialOpt) + conn, err := grpc.Dial(host.Name, dialOpts...) if err != nil { log.Printf("[WARNING] Skipping gRPC host '%s' due to Dial error: %s\n", host.Name, err) } else { diff --git a/middleware/proxy/grpc_test.go b/middleware/proxy/grpc_test.go new file mode 100644 index 000000000..0eade58a9 --- /dev/null +++ b/middleware/proxy/grpc_test.go @@ -0,0 +1,54 @@ +package proxy + +import ( + "testing" + "time" +) + +func pool() []*UpstreamHost { + return []*UpstreamHost{ + { + Name: "localhost:10053", + }, + { + Name: "localhost:10054", + }, + } +} + +func TestStartupShutdown(t *testing.T) { + upstream := &staticUpstream{ + from: ".", + Hosts: pool(), + Policy: &Random{}, + Spray: nil, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + g := newGrpcClient(nil, upstream) + upstream.ex = g + + p := &Proxy{Trace: nil} + p.Upstreams = &[]Upstream{upstream} + + err := g.OnStartup(p) + if err != nil { + t.Errorf("Error starting grpc client exchanger: %s", err) + return + } + if len(g.clients) != len(pool()) { + t.Errorf("Expected %d grpc clients but found %d", len(pool()), len(g.clients)) + } + + err = g.OnShutdown(p) + if err != nil { + t.Errorf("Error stopping grpc client exchanger: %s", err) + return + } + if len(g.clients) != 0 { + t.Errorf("Shutdown didn't remove clients, found %d", len(g.clients)) + } + if len(g.conns) != 0 { + t.Errorf("Shutdown didn't remove conns, found %d", len(g.conns)) + } +} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 090c070cb..9457fb2a1 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -28,6 +28,10 @@ type Proxy struct { // midway. Upstreams *[]Upstream + + // Trace is the Trace middleware, if it is installed + // This is used by the grpc exchanger to trace through the grpc calls + Trace middleware.Handler } // Upstream manages a pool of proxy upstream hosts. Select should return a diff --git a/middleware/proxy/setup.go b/middleware/proxy/setup.go index 3e4f262b7..36401188f 100644 --- a/middleware/proxy/setup.go +++ b/middleware/proxy/setup.go @@ -20,7 +20,8 @@ func setup(c *caddy.Controller) error { return middleware.Error("proxy", err) } - P := &Proxy{} + t := dnsserver.GetMiddleware(c, "trace") + P := &Proxy{Trace: t} dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler { P.Next = next P.Upstreams = &upstreams |