aboutsummaryrefslogtreecommitdiff
path: root/core/dnsserver/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'core/dnsserver/server.go')
-rw-r--r--core/dnsserver/server.go39
1 files changed, 35 insertions, 4 deletions
diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go
index a08802204..f10bdb78b 100644
--- a/core/dnsserver/server.go
+++ b/core/dnsserver/server.go
@@ -95,7 +95,7 @@ func NewServer(addr string, group []*Config) (*Server, error) {
func (s *Server) Serve(l net.Listener) error {
s.m.Lock()
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
- ctx := context.Background()
+ ctx := context.WithValue(context.Background(), Key{}, s)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
@@ -108,7 +108,7 @@ func (s *Server) Serve(l net.Listener) error {
func (s *Server) ServePacket(p net.PacketConn) error {
s.m.Lock()
s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
- ctx := context.Background()
+ ctx := context.WithValue(context.Background(), Key{}, s)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
@@ -207,6 +207,12 @@ func (s *Server) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
return
}
+ ctx, err := incrementDepthAndCheck(ctx)
+ if err != nil {
+ DefaultErrorFunc(w, r, dns.RcodeServerFailure)
+ return
+ }
+
q := r.Question[0].Name
b := make([]byte, len(q))
var off int
@@ -329,11 +335,36 @@ func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rc int) {
w.WriteMsg(answer)
}
+// incrementDepthAndCheck increments the loop counter in the context, and returns an error if
+// the counter exceeds the max number of re-entries
+func incrementDepthAndCheck(ctx context.Context) (context.Context, error) {
+ // Loop counter for self directed lookups
+ loop := ctx.Value(loopKey{})
+ if loop == nil {
+ ctx = context.WithValue(ctx, loopKey{}, 0)
+ return ctx, nil
+ }
+
+ iloop := loop.(int) + 1
+ if iloop > maxreentries {
+ return ctx, fmt.Errorf("too deep")
+ }
+ ctx = context.WithValue(ctx, loopKey{}, iloop)
+ return ctx, nil
+}
+
const (
- tcp = 0
- udp = 1
+ tcp = 0
+ udp = 1
+ maxreentries = 10
)
+// Key is the context key for the current server
+type Key struct{}
+
+// loopKey is the context key for counting self loops
+type loopKey struct{}
+
// enableChaos is a map with plugin names for which we should open CH class queries as
// we block these by default.
var enableChaos = map[string]bool{