aboutsummaryrefslogtreecommitdiff
path: root/request
diff options
context:
space:
mode:
Diffstat (limited to 'request')
-rw-r--r--request/request.go26
-rw-r--r--request/request_test.go25
-rw-r--r--request/writer.go20
3 files changed, 34 insertions, 37 deletions
diff --git a/request/request.go b/request/request.go
index bcf6570be..c4e4eea3c 100644
--- a/request/request.go
+++ b/request/request.go
@@ -226,19 +226,11 @@ func (r *Request) SizeAndDo(m *dns.Msg) bool {
return true
}
-// Result is the result of Scrub.
-type Result int
-
-const (
- // ScrubIgnored is returned when Scrub did nothing to the message.
- ScrubIgnored Result = iota
- // ScrubExtra is returned when the reply has been scrubbed by removing RRs from the additional section.
- ScrubExtra
- // ScrubAnswer is returned when the reply has been scrubbed by removing RRs from the answer section.
- ScrubAnswer
-)
+// Scrub is a noop function, added for backwards compatibility reasons. The original Scrub is now called
+// automatically by the server on writing the reply. See ScrubWriter.
+func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, int) { return reply, 0 }
-// Scrub scrubs the reply message so that it will fit the client's buffer. It will first
+// scrub scrubs the reply message so that it will fit the client's buffer. It will first
// check if the reply fits without compression and then *with* compression.
// Scrub will then use binary search to find a save cut off point in the additional section.
// If even *without* the additional section the reply still doesn't fit we
@@ -246,19 +238,19 @@ const (
// we set the TC bit on the reply; indicating the client should retry over TCP.
// Note, the TC bit will be set regardless of protocol, even TCP message will
// get the bit, the client should then retry with pigeons.
-func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
+func (r *Request) scrub(reply *dns.Msg) *dns.Msg {
size := r.Size()
reply.Compress = false
rl := reply.Len()
if size >= rl {
- return reply, ScrubIgnored
+ return reply
}
reply.Compress = true
rl = reply.Len()
if size >= rl {
- return reply, ScrubIgnored
+ return reply
}
// Account for the OPT record that gets added in SizeAndDo(), subtract that length.
@@ -298,7 +290,7 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
if rl < size {
r.SizeAndDo(reply)
- return reply, ScrubExtra
+ return reply
}
ra := len(reply.Answer)
@@ -330,7 +322,7 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
r.SizeAndDo(reply)
reply.Truncated = true
- return reply, ScrubAnswer
+ return reply
}
// Type returns the type of the question as a string. If the request is malformed the empty string is returned.
diff --git a/request/request_test.go b/request/request_test.go
index cad2dce0d..6685ad3b3 100644
--- a/request/request_test.go
+++ b/request/request_test.go
@@ -73,10 +73,7 @@ func TestRequestScrubAnswer(t *testing.T) {
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
- _, got := req.Scrub(reply)
- if want := ScrubAnswer; want != got {
- t.Errorf("Want scrub result %d, got %d", want, got)
- }
+ req.scrub(reply)
if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
}
@@ -97,10 +94,7 @@ func TestRequestScrubExtra(t *testing.T) {
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
- _, got := req.Scrub(reply)
- if want := ScrubExtra; want != got {
- t.Errorf("Want scrub result %d, got %d", want, got)
- }
+ req.scrub(reply)
if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
}
@@ -122,10 +116,7 @@ func TestRequestScrubExtraEdns0(t *testing.T) {
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
- _, got := req.Scrub(reply)
- if want := ScrubExtra; want != got {
- t.Errorf("Want scrub result %d, got %d", want, got)
- }
+ req.scrub(reply)
if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
}
@@ -155,10 +146,7 @@ func TestRequestScrubExtraRegression(t *testing.T) {
fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i)))
}
- _, got := req.Scrub(reply)
- if want := ScrubExtra; want != got {
- t.Errorf("Want scrub result %d, got %d", want, got)
- }
+ reply = req.scrub(reply)
if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
}
@@ -183,10 +171,7 @@ func TestRequestScrubAnswerExact(t *testing.T) {
reply.Answer = append(reply.Answer, test.A(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i)))
}
- _, got := req.Scrub(reply)
- if want := ScrubAnswer; want != got {
- t.Errorf("Want scrub result %d, got %d", want, got)
- }
+ req.scrub(reply)
if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
}
diff --git a/request/writer.go b/request/writer.go
new file mode 100644
index 000000000..ef0c14417
--- /dev/null
+++ b/request/writer.go
@@ -0,0 +1,20 @@
+package request
+
+import "github.com/miekg/dns"
+
+// ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer.
+type ScrubWriter struct {
+ dns.ResponseWriter
+ req *dns.Msg // original request
+}
+
+// NewScrubWriter returns a new and initialized ScrubWriter.
+func NewScrubWriter(req *dns.Msg, w dns.ResponseWriter) *ScrubWriter { return &ScrubWriter{w, req} }
+
+// WriteMsg overrides the default implementation of the underlaying dns.ResponseWriter and calls
+// scrub on the message m and will then write it to the client.
+func (s *ScrubWriter) WriteMsg(m *dns.Msg) error {
+ state := Request{Req: s.req, W: s.ResponseWriter}
+ new, _ := state.Scrub(m)
+ return s.ResponseWriter.WriteMsg(new)
+}