diff options
Diffstat (limited to 'core/dnsserver')
-rw-r--r-- | core/dnsserver/server.go | 11 | ||||
-rw-r--r-- | core/dnsserver/server_grpc.go | 1 | ||||
-rw-r--r-- | core/dnsserver/server_https.go | 1 | ||||
-rw-r--r-- | core/dnsserver/server_tls.go | 1 |
4 files changed, 12 insertions, 2 deletions
diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index eb23346e0..c7304d763 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -110,6 +110,7 @@ func (s *Server) Serve(l net.Listener) error { s.m.Lock() s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { ctx := context.WithValue(context.Background(), Key{}, s) + ctx = context.WithValue(ctx, LoopKey{}, 0) s.ServeDNS(ctx, w, r) })} s.m.Unlock() @@ -123,6 +124,7 @@ func (s *Server) ServePacket(p net.PacketConn) error { s.m.Lock() s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { ctx := context.WithValue(context.Background(), Key{}, s) + ctx = context.WithValue(ctx, LoopKey{}, 0) s.ServeDNS(ctx, w, r) })} s.m.Unlock() @@ -347,8 +349,13 @@ const ( udp = 1 ) -// Key is the context key for the current server added to the context. -type Key struct{} +type ( + // Key is the context key for the current server added to the context. + Key struct{} + + // LoopKey is the context key to detect server wide loops. + LoopKey struct{} +) // EnableChaos is a map with plugin names for which we should open CH class queries as we block these by default. var EnableChaos = map[string]struct{}{ diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go index 7873a47ad..37cc237b7 100644 --- a/core/dnsserver/server_grpc.go +++ b/core/dnsserver/server_grpc.go @@ -134,6 +134,7 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: a, Msg: msg} dnsCtx := context.WithValue(ctx, Key{}, s.Server) + dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0) s.ServeDNS(dnsCtx, w, msg) packed, err := w.Msg.Pack() diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index 057dac49c..7292311e8 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -145,6 +145,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { // We just call the normal chain handler - all error handling is done there. // We should expect a packet to be returned that we can send to the client. ctx := context.WithValue(context.Background(), Key{}, s.Server) + ctx = context.WithValue(ctx, LoopKey{}, 0) s.ServeDNS(ctx, dw, msg) // See section 4.2.1 of RFC 8484. diff --git a/core/dnsserver/server_tls.go b/core/dnsserver/server_tls.go index 3f45e1568..1c53c4e3c 100644 --- a/core/dnsserver/server_tls.go +++ b/core/dnsserver/server_tls.go @@ -50,6 +50,7 @@ func (s *ServerTLS) Serve(l net.Listener) error { // Only fill out the TCP server for this one. s.server[tcp] = &dns.Server{Listener: l, Net: "tcp-tls", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { ctx := context.WithValue(context.Background(), Key{}, s.Server) + ctx = context.WithValue(ctx, LoopKey{}, 0) s.ServeDNS(ctx, w, r) })} s.m.Unlock() |