aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/dnsserver/server-grpc.go39
-rw-r--r--core/dnsserver/server-tls.go34
2 files changed, 29 insertions, 44 deletions
diff --git a/core/dnsserver/server-grpc.go b/core/dnsserver/server-grpc.go
index da6910c4b..ba9519cdb 100644
--- a/core/dnsserver/server-grpc.go
+++ b/core/dnsserver/server-grpc.go
@@ -8,7 +8,7 @@ import (
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
- opentracing "github.com/opentracing/opentracing-go"
+ "github.com/opentracing/opentracing-go"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
@@ -20,19 +20,25 @@ import (
type ServergRPC struct {
*Server
grpcServer *grpc.Server
-
listenAddr net.Addr
+ tlsConfig *tls.Config
}
// NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it.
func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
-
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
- gs := &ServergRPC{Server: s}
- return gs, nil
+ // The *tls* plugin must make sure that multiple conflicting
+ // TLS configuration return an error: it can only be specified once.
+ var tlsConfig *tls.Config
+ for _, conf := range s.zones {
+ // Should we error if some configs *don't* have TLS?
+ tlsConfig = conf.TLSConfig
+ }
+
+ return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
}
// Serve implements caddy.TCPServer interface.
@@ -53,6 +59,9 @@ func (s *ServergRPC) Serve(l net.Listener) error {
pb.RegisterDnsServiceServer(s.grpcServer, s)
+ if s.tlsConfig != nil {
+ l = tls.NewListener(l, s.tlsConfig)
+ }
return s.grpcServer.Serve(l)
}
@@ -62,25 +71,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServergRPC) Listen() (net.Listener, error) {
- // The *tls* plugin must make sure that multiple conflicting
- // TLS configuration return an error: it can only be specified once.
- tlsConfig := new(tls.Config)
- for _, conf := range s.zones {
- // Should we error if some configs *don't* have TLS?
- tlsConfig = conf.TLSConfig
- }
-
- var (
- l net.Listener
- err error
- )
-
- if tlsConfig == nil {
- l, err = net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):])
- } else {
- l, err = tls.Listen("tcp", s.Addr[len(TransportGRPC+"://"):], tlsConfig)
- }
-
+ l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):])
if err != nil {
return nil, err
}
diff --git a/core/dnsserver/server-tls.go b/core/dnsserver/server-tls.go
index 2880b0183..ce3ab8185 100644
--- a/core/dnsserver/server-tls.go
+++ b/core/dnsserver/server-tls.go
@@ -12,6 +12,7 @@ import (
// ServerTLS represents an instance of a TLS-over-DNS-server.
type ServerTLS struct {
*Server
+ tlsConfig *tls.Config
}
// NewServerTLS returns a new CoreDNS TLS server and compiles all plugin in to it.
@@ -20,14 +21,25 @@ func NewServerTLS(addr string, group []*Config) (*ServerTLS, error) {
if err != nil {
return nil, err
}
+ // The *tls* plugin must make sure that multiple conflicting
+ // TLS configuration return an error: it can only be specified once.
+ var tlsConfig *tls.Config
+ for _, conf := range s.zones {
+ // Should we error if some configs *don't* have TLS?
+ tlsConfig = conf.TLSConfig
+ }
- return &ServerTLS{Server: s}, nil
+ return &ServerTLS{Server: s, tlsConfig: tlsConfig}, nil
}
// Serve implements caddy.TCPServer interface.
func (s *ServerTLS) Serve(l net.Listener) error {
s.m.Lock()
+ if s.tlsConfig != nil {
+ l = tls.NewListener(l, s.tlsConfig)
+ }
+
// Only fill out the TCP server for this one.
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp-tls", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.Background()
@@ -43,25 +55,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerTLS) Listen() (net.Listener, error) {
- // The *tls* plugin must make sure that multiple conflicting
- // TLS configuration return an error: it can only be specified once.
- tlsConfig := new(tls.Config)
- for _, conf := range s.zones {
- // Should we error if some configs *don't* have TLS?
- tlsConfig = conf.TLSConfig
- }
-
- var (
- l net.Listener
- err error
- )
-
- if tlsConfig == nil {
- l, err = net.Listen("tcp", s.Addr[len(TransportTLS+"://"):])
- } else {
- l, err = tls.Listen("tcp", s.Addr[len(TransportTLS+"://"):], tlsConfig)
- }
-
+ l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):])
if err != nil {
return nil, err
}