aboutsummaryrefslogtreecommitdiff
path: root/core/dnsserver
diff options
context:
space:
mode:
Diffstat (limited to 'core/dnsserver')
-rw-r--r--core/dnsserver/server.go39
-rw-r--r--core/dnsserver/server_test.go15
2 files changed, 50 insertions, 4 deletions
diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go
index a08802204..f10bdb78b 100644
--- a/core/dnsserver/server.go
+++ b/core/dnsserver/server.go
@@ -95,7 +95,7 @@ func NewServer(addr string, group []*Config) (*Server, error) {
func (s *Server) Serve(l net.Listener) error {
s.m.Lock()
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
- ctx := context.Background()
+ ctx := context.WithValue(context.Background(), Key{}, s)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
@@ -108,7 +108,7 @@ func (s *Server) Serve(l net.Listener) error {
func (s *Server) ServePacket(p net.PacketConn) error {
s.m.Lock()
s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
- ctx := context.Background()
+ ctx := context.WithValue(context.Background(), Key{}, s)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
@@ -207,6 +207,12 @@ func (s *Server) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
return
}
+ ctx, err := incrementDepthAndCheck(ctx)
+ if err != nil {
+ DefaultErrorFunc(w, r, dns.RcodeServerFailure)
+ return
+ }
+
q := r.Question[0].Name
b := make([]byte, len(q))
var off int
@@ -329,11 +335,36 @@ func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rc int) {
w.WriteMsg(answer)
}
+// incrementDepthAndCheck increments the loop counter in the context, and returns an error if
+// the counter exceeds the max number of re-entries
+func incrementDepthAndCheck(ctx context.Context) (context.Context, error) {
+ // Loop counter for self directed lookups
+ loop := ctx.Value(loopKey{})
+ if loop == nil {
+ ctx = context.WithValue(ctx, loopKey{}, 0)
+ return ctx, nil
+ }
+
+ iloop := loop.(int) + 1
+ if iloop > maxreentries {
+ return ctx, fmt.Errorf("too deep")
+ }
+ ctx = context.WithValue(ctx, loopKey{}, iloop)
+ return ctx, nil
+}
+
const (
- tcp = 0
- udp = 1
+ tcp = 0
+ udp = 1
+ maxreentries = 10
)
+// Key is the context key for the current server
+type Key struct{}
+
+// loopKey is the context key for counting self loops
+type loopKey struct{}
+
// enableChaos is a map with plugin names for which we should open CH class queries as
// we block these by default.
var enableChaos = map[string]bool{
diff --git a/core/dnsserver/server_test.go b/core/dnsserver/server_test.go
index a7e663399..e7986c397 100644
--- a/core/dnsserver/server_test.go
+++ b/core/dnsserver/server_test.go
@@ -48,6 +48,21 @@ func TestNewServer(t *testing.T) {
}
}
+func TestIncrementDepthAndCheck(t *testing.T) {
+ ctx := context.Background()
+ var err error
+ for i := 0; i <= maxreentries; i++ {
+ ctx, err = incrementDepthAndCheck(ctx)
+ if err != nil {
+ t.Errorf("Expected no error for depthCheck (i=%v), got %s", i, err)
+ }
+ }
+ _, err = incrementDepthAndCheck(ctx)
+ if err == nil {
+ t.Errorf("Expected error for depthCheck (i=%v)", maxreentries+1)
+ }
+}
+
func BenchmarkCoreServeDNS(b *testing.B) {
s, err := NewServer("127.0.0.1:53", []*Config{testConfig("dns", testPlugin{})})
if err != nil {