diff options
-rw-r--r-- | plugin/pkg/proxy/connect.go | 38 | ||||
-rw-r--r-- | plugin/pkg/proxy/proxy_test.go | 95 |
2 files changed, 133 insertions, 0 deletions
diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go index f1cc481a1..27385a467 100644 --- a/plugin/pkg/proxy/connect.go +++ b/plugin/pkg/proxy/connect.go @@ -6,8 +6,10 @@ package proxy import ( "context" + "errors" "io" "strconv" + "strings" "sync/atomic" "time" @@ -117,6 +119,18 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options for { ret, err = pc.c.ReadMsg() if err != nil { + if ret != nil && (state.Req.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) { + // For UDP, if the error is an overflow, we probably have an upstream misbehaving in some way. + // (e.g. sending >512 byte responses without an eDNS0 OPT RR). + // Instead of returning an error, return an empty response with TC bit set. This will make the + // client retry over TCP (if that's supported) or at least receive a clean + // error. The connection is still good so we break before the close. + + // Truncate the response. + ret = truncateResponse(ret) + break + } + pc.c.Close() // not giving it back if err == io.EOF && cached { return nil, ErrCachedClosed @@ -148,3 +162,27 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options } const cumulativeAvgWeight = 4 + +// Function to determine if a response should be truncated. +func shouldTruncateResponse(err error) bool { + // This is to handle a scenario in which upstream sets the TC bit, but doesn't truncate the response + // and we get ErrBuf instead of overflow. + if _, isDNSErr := err.(*dns.Error); isDNSErr && errors.Is(err, dns.ErrBuf) { + return true + } else if strings.Contains(err.Error(), "overflow") { + return true + } + return false +} + +// Function to return an empty response with TC (truncated) bit set. +func truncateResponse(response *dns.Msg) *dns.Msg { + // Clear out Answer, Extra, and Ns sections + response.Answer = nil + response.Extra = nil + response.Ns = nil + + // Set TC bit to indicate truncation. + response.Truncated = true + return response +} diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go index 33a7170c0..03d10ce5f 100644 --- a/plugin/pkg/proxy/proxy_test.go +++ b/plugin/pkg/proxy/proxy_test.go @@ -3,6 +3,7 @@ package proxy import ( "context" "crypto/tls" + "errors" "math" "testing" "time" @@ -128,3 +129,97 @@ func TestProxyIncrementFails(t *testing.T) { }) } } + +func TestCoreDNSOverflow(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + + answers := []dns.RR{ + test.A("example.org. IN A 127.0.0.1"), + test.A("example.org. IN A 127.0.0.2"), + test.A("example.org. IN A 127.0.0.3"), + test.A("example.org. IN A 127.0.0.4"), + test.A("example.org. IN A 127.0.0.5"), + test.A("example.org. IN A 127.0.0.6"), + test.A("example.org. IN A 127.0.0.7"), + test.A("example.org. IN A 127.0.0.8"), + test.A("example.org. IN A 127.0.0.9"), + test.A("example.org. IN A 127.0.0.10"), + test.A("example.org. IN A 127.0.0.11"), + test.A("example.org. IN A 127.0.0.12"), + test.A("example.org. IN A 127.0.0.13"), + test.A("example.org. IN A 127.0.0.14"), + test.A("example.org. IN A 127.0.0.15"), + test.A("example.org. IN A 127.0.0.16"), + test.A("example.org. IN A 127.0.0.17"), + test.A("example.org. IN A 127.0.0.18"), + test.A("example.org. IN A 127.0.0.19"), + test.A("example.org. IN A 127.0.0.20"), + } + ret.Answer = answers + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy("TestCoreDNSOverflow", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + p.Start(5 * time.Second) + defer p.Stop() + + // Test different connection modes + testConnection := func(proto string, options Options, expectTruncated bool) { + t.Helper() + + queryMsg := new(dns.Msg) + queryMsg.SetQuestion("example.org.", dns.TypeA) + + recorder := dnstest.NewRecorder(&test.ResponseWriter{}) + request := request.Request{Req: queryMsg, W: recorder} + + response, err := p.Connect(context.Background(), request, options) + if err != nil { + t.Errorf("Failed to connect to testdnsserver: %s", err) + } + + if response.Truncated != expectTruncated { + t.Errorf("Expected truncated response for %s, but got TC flag %v", proto, response.Truncated) + } + } + + // Test PreferUDP, expect truncated response + testConnection("PreferUDP", Options{PreferUDP: true}, true) + + // Test ForceTCP, expect no truncated response + testConnection("ForceTCP", Options{ForceTCP: true}, false) + + // Test No options specified, expect truncated response + testConnection("NoOptionsSpecified", Options{}, true) + + // Test both TCP and UDP provided, expect no truncated response + testConnection("BothTCPAndUDP", Options{PreferUDP: true, ForceTCP: true}, false) +} + +func TestShouldTruncateResponse(t *testing.T) { + testCases := []struct { + testname string + err error + expected bool + }{ + {"BadAlgorithm", dns.ErrAlg, false}, + {"BufferSizeTooSmall", dns.ErrBuf, true}, + {"OverflowUnpackingA", errors.New("overflow unpacking a"), true}, + {"OverflowingHeaderSize", errors.New("overflowing header size"), true}, + {"OverflowpackingA", errors.New("overflow packing a"), true}, + {"ErrSig", dns.ErrSig, false}, + } + + for _, tc := range testCases { + t.Run(tc.testname, func(t *testing.T) { + result := shouldTruncateResponse(tc.err) + if result != tc.expected { + t.Errorf("For testname '%v', expected %v but got %v", tc.testname, tc.expected, result) + } + }) + } +} |