diff options
Diffstat (limited to 'plugin/proxy/grpc.go')
-rw-r--r-- | plugin/proxy/grpc.go | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/plugin/proxy/grpc.go b/plugin/proxy/grpc.go new file mode 100644 index 000000000..f98fd2e91 --- /dev/null +++ b/plugin/proxy/grpc.go @@ -0,0 +1,96 @@ +package proxy + +import ( + "context" + "crypto/tls" + "log" + + "github.com/coredns/coredns/pb" + "github.com/coredns/coredns/plugin/pkg/trace" + "github.com/coredns/coredns/request" + + "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" + "github.com/miekg/dns" + opentracing "github.com/opentracing/opentracing-go" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +type grpcClient struct { + dialOpts []grpc.DialOption + clients map[string]pb.DnsServiceClient + conns []*grpc.ClientConn + upstream *staticUpstream +} + +func newGrpcClient(tls *tls.Config, u *staticUpstream) *grpcClient { + g := &grpcClient{upstream: u} + + if tls == nil { + g.dialOpts = append(g.dialOpts, grpc.WithInsecure()) + } else { + g.dialOpts = append(g.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tls))) + } + g.clients = map[string]pb.DnsServiceClient{} + + return g +} + +func (g *grpcClient) Exchange(ctx context.Context, addr string, state request.Request) (*dns.Msg, error) { + msg, err := state.Req.Pack() + if err != nil { + return nil, err + } + + reply, err := g.clients[addr].Query(ctx, &pb.DnsPacket{Msg: msg}) + if err != nil { + return nil, err + } + d := new(dns.Msg) + err = d.Unpack(reply.Msg) + if err != nil { + return nil, err + } + return d, nil +} + +func (g *grpcClient) Transport() string { return "tcp" } + +func (g *grpcClient) Protocol() string { return "grpc" } + +func (g *grpcClient) OnShutdown(p *Proxy) error { + g.clients = map[string]pb.DnsServiceClient{} + for i, conn := range g.conns { + err := conn.Close() + if err != nil { + log.Printf("[WARNING] Error closing connection %d: %s\n", i, err) + } + } + g.conns = []*grpc.ClientConn{} + return nil +} + +func (g *grpcClient) OnStartup(p *Proxy) error { + dialOpts := g.dialOpts + if p.Trace != nil { + if t, ok := p.Trace.(trace.Trace); ok { + onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp interface{}) bool { + return parentSpanCtx != nil + } + intercept := otgrpc.OpenTracingClientInterceptor(t.Tracer(), otgrpc.IncludingSpans(onlyIfParent)) + dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(intercept)) + } else { + log.Printf("[WARNING] Wrong type for trace plugin reference: %s", p.Trace) + } + } + for _, host := range g.upstream.Hosts { + conn, err := grpc.Dial(host.Name, dialOpts...) + if err != nil { + log.Printf("[WARNING] Skipping gRPC host '%s' due to Dial error: %s\n", host.Name, err) + } else { + g.clients[host.Name] = pb.NewDnsServiceClient(conn) + g.conns = append(g.conns, conn) + } + } + return nil +} |