aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--plugin/pkg/proxy/connect.go38
-rw-r--r--plugin/pkg/proxy/proxy_test.go95
2 files changed, 133 insertions, 0 deletions
diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go
index f1cc481a1..27385a467 100644
--- a/plugin/pkg/proxy/connect.go
+++ b/plugin/pkg/proxy/connect.go
@@ -6,8 +6,10 @@ package proxy
import (
"context"
+ "errors"
"io"
"strconv"
+ "strings"
"sync/atomic"
"time"
@@ -117,6 +119,18 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options
for {
ret, err = pc.c.ReadMsg()
if err != nil {
+ if ret != nil && (state.Req.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
+ // For UDP, if the error is an overflow, we probably have an upstream misbehaving in some way.
+ // (e.g. sending >512 byte responses without an eDNS0 OPT RR).
+ // Instead of returning an error, return an empty response with TC bit set. This will make the
+ // client retry over TCP (if that's supported) or at least receive a clean
+ // error. The connection is still good so we break before the close.
+
+ // Truncate the response.
+ ret = truncateResponse(ret)
+ break
+ }
+
pc.c.Close() // not giving it back
if err == io.EOF && cached {
return nil, ErrCachedClosed
@@ -148,3 +162,27 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options
}
const cumulativeAvgWeight = 4
+
+// Function to determine if a response should be truncated.
+func shouldTruncateResponse(err error) bool {
+ // This is to handle a scenario in which upstream sets the TC bit, but doesn't truncate the response
+ // and we get ErrBuf instead of overflow.
+ if _, isDNSErr := err.(*dns.Error); isDNSErr && errors.Is(err, dns.ErrBuf) {
+ return true
+ } else if strings.Contains(err.Error(), "overflow") {
+ return true
+ }
+ return false
+}
+
+// Function to return an empty response with TC (truncated) bit set.
+func truncateResponse(response *dns.Msg) *dns.Msg {
+ // Clear out Answer, Extra, and Ns sections
+ response.Answer = nil
+ response.Extra = nil
+ response.Ns = nil
+
+ // Set TC bit to indicate truncation.
+ response.Truncated = true
+ return response
+}
diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go
index 33a7170c0..03d10ce5f 100644
--- a/plugin/pkg/proxy/proxy_test.go
+++ b/plugin/pkg/proxy/proxy_test.go
@@ -3,6 +3,7 @@ package proxy
import (
"context"
"crypto/tls"
+ "errors"
"math"
"testing"
"time"
@@ -128,3 +129,97 @@ func TestProxyIncrementFails(t *testing.T) {
})
}
}
+
+func TestCoreDNSOverflow(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+
+ answers := []dns.RR{
+ test.A("example.org. IN A 127.0.0.1"),
+ test.A("example.org. IN A 127.0.0.2"),
+ test.A("example.org. IN A 127.0.0.3"),
+ test.A("example.org. IN A 127.0.0.4"),
+ test.A("example.org. IN A 127.0.0.5"),
+ test.A("example.org. IN A 127.0.0.6"),
+ test.A("example.org. IN A 127.0.0.7"),
+ test.A("example.org. IN A 127.0.0.8"),
+ test.A("example.org. IN A 127.0.0.9"),
+ test.A("example.org. IN A 127.0.0.10"),
+ test.A("example.org. IN A 127.0.0.11"),
+ test.A("example.org. IN A 127.0.0.12"),
+ test.A("example.org. IN A 127.0.0.13"),
+ test.A("example.org. IN A 127.0.0.14"),
+ test.A("example.org. IN A 127.0.0.15"),
+ test.A("example.org. IN A 127.0.0.16"),
+ test.A("example.org. IN A 127.0.0.17"),
+ test.A("example.org. IN A 127.0.0.18"),
+ test.A("example.org. IN A 127.0.0.19"),
+ test.A("example.org. IN A 127.0.0.20"),
+ }
+ ret.Answer = answers
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy("TestCoreDNSOverflow", s.Addr, transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+ p.Start(5 * time.Second)
+ defer p.Stop()
+
+ // Test different connection modes
+ testConnection := func(proto string, options Options, expectTruncated bool) {
+ t.Helper()
+
+ queryMsg := new(dns.Msg)
+ queryMsg.SetQuestion("example.org.", dns.TypeA)
+
+ recorder := dnstest.NewRecorder(&test.ResponseWriter{})
+ request := request.Request{Req: queryMsg, W: recorder}
+
+ response, err := p.Connect(context.Background(), request, options)
+ if err != nil {
+ t.Errorf("Failed to connect to testdnsserver: %s", err)
+ }
+
+ if response.Truncated != expectTruncated {
+ t.Errorf("Expected truncated response for %s, but got TC flag %v", proto, response.Truncated)
+ }
+ }
+
+ // Test PreferUDP, expect truncated response
+ testConnection("PreferUDP", Options{PreferUDP: true}, true)
+
+ // Test ForceTCP, expect no truncated response
+ testConnection("ForceTCP", Options{ForceTCP: true}, false)
+
+ // Test No options specified, expect truncated response
+ testConnection("NoOptionsSpecified", Options{}, true)
+
+ // Test both TCP and UDP provided, expect no truncated response
+ testConnection("BothTCPAndUDP", Options{PreferUDP: true, ForceTCP: true}, false)
+}
+
+func TestShouldTruncateResponse(t *testing.T) {
+ testCases := []struct {
+ testname string
+ err error
+ expected bool
+ }{
+ {"BadAlgorithm", dns.ErrAlg, false},
+ {"BufferSizeTooSmall", dns.ErrBuf, true},
+ {"OverflowUnpackingA", errors.New("overflow unpacking a"), true},
+ {"OverflowingHeaderSize", errors.New("overflowing header size"), true},
+ {"OverflowpackingA", errors.New("overflow packing a"), true},
+ {"ErrSig", dns.ErrSig, false},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.testname, func(t *testing.T) {
+ result := shouldTruncateResponse(tc.err)
+ if result != tc.expected {
+ t.Errorf("For testname '%v', expected %v but got %v", tc.testname, tc.expected, result)
+ }
+ })
+ }
+}