diff options
Diffstat (limited to 'core/dnsserver/server.go')
-rw-r--r-- | core/dnsserver/server.go | 39 |
1 files changed, 35 insertions, 4 deletions
diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index a08802204..f10bdb78b 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -95,7 +95,7 @@ func NewServer(addr string, group []*Config) (*Server, error) { func (s *Server) Serve(l net.Listener) error { s.m.Lock() s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { - ctx := context.Background() + ctx := context.WithValue(context.Background(), Key{}, s) s.ServeDNS(ctx, w, r) })} s.m.Unlock() @@ -108,7 +108,7 @@ func (s *Server) Serve(l net.Listener) error { func (s *Server) ServePacket(p net.PacketConn) error { s.m.Lock() s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { - ctx := context.Background() + ctx := context.WithValue(context.Background(), Key{}, s) s.ServeDNS(ctx, w, r) })} s.m.Unlock() @@ -207,6 +207,12 @@ func (s *Server) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) return } + ctx, err := incrementDepthAndCheck(ctx) + if err != nil { + DefaultErrorFunc(w, r, dns.RcodeServerFailure) + return + } + q := r.Question[0].Name b := make([]byte, len(q)) var off int @@ -329,11 +335,36 @@ func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rc int) { w.WriteMsg(answer) } +// incrementDepthAndCheck increments the loop counter in the context, and returns an error if +// the counter exceeds the max number of re-entries +func incrementDepthAndCheck(ctx context.Context) (context.Context, error) { + // Loop counter for self directed lookups + loop := ctx.Value(loopKey{}) + if loop == nil { + ctx = context.WithValue(ctx, loopKey{}, 0) + return ctx, nil + } + + iloop := loop.(int) + 1 + if iloop > maxreentries { + return ctx, fmt.Errorf("too deep") + } + ctx = context.WithValue(ctx, loopKey{}, iloop) + return ctx, nil +} + const ( - tcp = 0 - udp = 1 + tcp = 0 + udp = 1 + maxreentries = 10 ) +// Key is the context key for the current server +type Key struct{} + +// loopKey is the context key for counting self loops +type loopKey struct{} + // enableChaos is a map with plugin names for which we should open CH class queries as // we block these by default. var enableChaos = map[string]bool{ |