diff options
Diffstat (limited to 'request')
-rw-r--r-- | request/request.go | 26 | ||||
-rw-r--r-- | request/request_test.go | 25 | ||||
-rw-r--r-- | request/writer.go | 20 |
3 files changed, 34 insertions, 37 deletions
diff --git a/request/request.go b/request/request.go index bcf6570be..c4e4eea3c 100644 --- a/request/request.go +++ b/request/request.go @@ -226,19 +226,11 @@ func (r *Request) SizeAndDo(m *dns.Msg) bool { return true } -// Result is the result of Scrub. -type Result int - -const ( - // ScrubIgnored is returned when Scrub did nothing to the message. - ScrubIgnored Result = iota - // ScrubExtra is returned when the reply has been scrubbed by removing RRs from the additional section. - ScrubExtra - // ScrubAnswer is returned when the reply has been scrubbed by removing RRs from the answer section. - ScrubAnswer -) +// Scrub is a noop function, added for backwards compatibility reasons. The original Scrub is now called +// automatically by the server on writing the reply. See ScrubWriter. +func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, int) { return reply, 0 } -// Scrub scrubs the reply message so that it will fit the client's buffer. It will first +// scrub scrubs the reply message so that it will fit the client's buffer. It will first // check if the reply fits without compression and then *with* compression. // Scrub will then use binary search to find a save cut off point in the additional section. // If even *without* the additional section the reply still doesn't fit we @@ -246,19 +238,19 @@ const ( // we set the TC bit on the reply; indicating the client should retry over TCP. // Note, the TC bit will be set regardless of protocol, even TCP message will // get the bit, the client should then retry with pigeons. -func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) { +func (r *Request) scrub(reply *dns.Msg) *dns.Msg { size := r.Size() reply.Compress = false rl := reply.Len() if size >= rl { - return reply, ScrubIgnored + return reply } reply.Compress = true rl = reply.Len() if size >= rl { - return reply, ScrubIgnored + return reply } // Account for the OPT record that gets added in SizeAndDo(), subtract that length. @@ -298,7 +290,7 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) { if rl < size { r.SizeAndDo(reply) - return reply, ScrubExtra + return reply } ra := len(reply.Answer) @@ -330,7 +322,7 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) { r.SizeAndDo(reply) reply.Truncated = true - return reply, ScrubAnswer + return reply } // Type returns the type of the question as a string. If the request is malformed the empty string is returned. diff --git a/request/request_test.go b/request/request_test.go index cad2dce0d..6685ad3b3 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -73,10 +73,7 @@ func TestRequestScrubAnswer(t *testing.T) { fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) } - _, got := req.Scrub(reply) - if want := ScrubAnswer; want != got { - t.Errorf("Want scrub result %d, got %d", want, got) - } + req.scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -97,10 +94,7 @@ func TestRequestScrubExtra(t *testing.T) { fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) } - _, got := req.Scrub(reply) - if want := ScrubExtra; want != got { - t.Errorf("Want scrub result %d, got %d", want, got) - } + req.scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -122,10 +116,7 @@ func TestRequestScrubExtraEdns0(t *testing.T) { fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) } - _, got := req.Scrub(reply) - if want := ScrubExtra; want != got { - t.Errorf("Want scrub result %d, got %d", want, got) - } + req.scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -155,10 +146,7 @@ func TestRequestScrubExtraRegression(t *testing.T) { fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i))) } - _, got := req.Scrub(reply) - if want := ScrubExtra; want != got { - t.Errorf("Want scrub result %d, got %d", want, got) - } + reply = req.scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -183,10 +171,7 @@ func TestRequestScrubAnswerExact(t *testing.T) { reply.Answer = append(reply.Answer, test.A(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i))) } - _, got := req.Scrub(reply) - if want := ScrubAnswer; want != got { - t.Errorf("Want scrub result %d, got %d", want, got) - } + req.scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } diff --git a/request/writer.go b/request/writer.go new file mode 100644 index 000000000..ef0c14417 --- /dev/null +++ b/request/writer.go @@ -0,0 +1,20 @@ +package request + +import "github.com/miekg/dns" + +// ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer. +type ScrubWriter struct { + dns.ResponseWriter + req *dns.Msg // original request +} + +// NewScrubWriter returns a new and initialized ScrubWriter. +func NewScrubWriter(req *dns.Msg, w dns.ResponseWriter) *ScrubWriter { return &ScrubWriter{w, req} } + +// WriteMsg overrides the default implementation of the underlaying dns.ResponseWriter and calls +// scrub on the message m and will then write it to the client. +func (s *ScrubWriter) WriteMsg(m *dns.Msg) error { + state := Request{Req: s.req, W: s.ResponseWriter} + new, _ := state.Scrub(m) + return s.ResponseWriter.WriteMsg(new) +} |