diff options
-rw-r--r-- | core/dnsserver/server-grpc.go | 39 | ||||
-rw-r--r-- | core/dnsserver/server-tls.go | 34 |
2 files changed, 29 insertions, 44 deletions
diff --git a/core/dnsserver/server-grpc.go b/core/dnsserver/server-grpc.go index da6910c4b..ba9519cdb 100644 --- a/core/dnsserver/server-grpc.go +++ b/core/dnsserver/server-grpc.go @@ -8,7 +8,7 @@ import ( "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" "github.com/miekg/dns" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/peer" @@ -20,19 +20,25 @@ import ( type ServergRPC struct { *Server grpcServer *grpc.Server - listenAddr net.Addr + tlsConfig *tls.Config } // NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it. func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) { - s, err := NewServer(addr, group) if err != nil { return nil, err } - gs := &ServergRPC{Server: s} - return gs, nil + // The *tls* plugin must make sure that multiple conflicting + // TLS configuration return an error: it can only be specified once. + var tlsConfig *tls.Config + for _, conf := range s.zones { + // Should we error if some configs *don't* have TLS? + tlsConfig = conf.TLSConfig + } + + return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil } // Serve implements caddy.TCPServer interface. @@ -53,6 +59,9 @@ func (s *ServergRPC) Serve(l net.Listener) error { pb.RegisterDnsServiceServer(s.grpcServer, s) + if s.tlsConfig != nil { + l = tls.NewListener(l, s.tlsConfig) + } return s.grpcServer.Serve(l) } @@ -62,25 +71,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil } // Listen implements caddy.TCPServer interface. func (s *ServergRPC) Listen() (net.Listener, error) { - // The *tls* plugin must make sure that multiple conflicting - // TLS configuration return an error: it can only be specified once. - tlsConfig := new(tls.Config) - for _, conf := range s.zones { - // Should we error if some configs *don't* have TLS? - tlsConfig = conf.TLSConfig - } - - var ( - l net.Listener - err error - ) - - if tlsConfig == nil { - l, err = net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):]) - } else { - l, err = tls.Listen("tcp", s.Addr[len(TransportGRPC+"://"):], tlsConfig) - } - + l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):]) if err != nil { return nil, err } diff --git a/core/dnsserver/server-tls.go b/core/dnsserver/server-tls.go index 2880b0183..ce3ab8185 100644 --- a/core/dnsserver/server-tls.go +++ b/core/dnsserver/server-tls.go @@ -12,6 +12,7 @@ import ( // ServerTLS represents an instance of a TLS-over-DNS-server. type ServerTLS struct { *Server + tlsConfig *tls.Config } // NewServerTLS returns a new CoreDNS TLS server and compiles all plugin in to it. @@ -20,14 +21,25 @@ func NewServerTLS(addr string, group []*Config) (*ServerTLS, error) { if err != nil { return nil, err } + // The *tls* plugin must make sure that multiple conflicting + // TLS configuration return an error: it can only be specified once. + var tlsConfig *tls.Config + for _, conf := range s.zones { + // Should we error if some configs *don't* have TLS? + tlsConfig = conf.TLSConfig + } - return &ServerTLS{Server: s}, nil + return &ServerTLS{Server: s, tlsConfig: tlsConfig}, nil } // Serve implements caddy.TCPServer interface. func (s *ServerTLS) Serve(l net.Listener) error { s.m.Lock() + if s.tlsConfig != nil { + l = tls.NewListener(l, s.tlsConfig) + } + // Only fill out the TCP server for this one. s.server[tcp] = &dns.Server{Listener: l, Net: "tcp-tls", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { ctx := context.Background() @@ -43,25 +55,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil } // Listen implements caddy.TCPServer interface. func (s *ServerTLS) Listen() (net.Listener, error) { - // The *tls* plugin must make sure that multiple conflicting - // TLS configuration return an error: it can only be specified once. - tlsConfig := new(tls.Config) - for _, conf := range s.zones { - // Should we error if some configs *don't* have TLS? - tlsConfig = conf.TLSConfig - } - - var ( - l net.Listener - err error - ) - - if tlsConfig == nil { - l, err = net.Listen("tcp", s.Addr[len(TransportTLS+"://"):]) - } else { - l, err = tls.Listen("tcp", s.Addr[len(TransportTLS+"://"):], tlsConfig) - } - + l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):]) if err != nil { return nil, err } |