aboutsummaryrefslogtreecommitdiff
path: root/core/dnsserver
diff options
context:
space:
mode:
Diffstat (limited to 'core/dnsserver')
-rw-r--r--core/dnsserver/server.go11
-rw-r--r--core/dnsserver/server_grpc.go1
-rw-r--r--core/dnsserver/server_https.go1
-rw-r--r--core/dnsserver/server_tls.go1
4 files changed, 12 insertions, 2 deletions
diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go
index eb23346e0..c7304d763 100644
--- a/core/dnsserver/server.go
+++ b/core/dnsserver/server.go
@@ -110,6 +110,7 @@ func (s *Server) Serve(l net.Listener) error {
s.m.Lock()
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s)
+ ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
@@ -123,6 +124,7 @@ func (s *Server) ServePacket(p net.PacketConn) error {
s.m.Lock()
s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s)
+ ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
@@ -347,8 +349,13 @@ const (
udp = 1
)
-// Key is the context key for the current server added to the context.
-type Key struct{}
+type (
+ // Key is the context key for the current server added to the context.
+ Key struct{}
+
+ // LoopKey is the context key to detect server wide loops.
+ LoopKey struct{}
+)
// EnableChaos is a map with plugin names for which we should open CH class queries as we block these by default.
var EnableChaos = map[string]struct{}{
diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go
index 7873a47ad..37cc237b7 100644
--- a/core/dnsserver/server_grpc.go
+++ b/core/dnsserver/server_grpc.go
@@ -134,6 +134,7 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket
w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: a, Msg: msg}
dnsCtx := context.WithValue(ctx, Key{}, s.Server)
+ dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0)
s.ServeDNS(dnsCtx, w, msg)
packed, err := w.Msg.Pack()
diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go
index 057dac49c..7292311e8 100644
--- a/core/dnsserver/server_https.go
+++ b/core/dnsserver/server_https.go
@@ -145,6 +145,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// We just call the normal chain handler - all error handling is done there.
// We should expect a packet to be returned that we can send to the client.
ctx := context.WithValue(context.Background(), Key{}, s.Server)
+ ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, dw, msg)
// See section 4.2.1 of RFC 8484.
diff --git a/core/dnsserver/server_tls.go b/core/dnsserver/server_tls.go
index 3f45e1568..1c53c4e3c 100644
--- a/core/dnsserver/server_tls.go
+++ b/core/dnsserver/server_tls.go
@@ -50,6 +50,7 @@ func (s *ServerTLS) Serve(l net.Listener) error {
// Only fill out the TCP server for this one.
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp-tls", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s.Server)
+ ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()