aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--middleware/edns.go34
-rw-r--r--middleware/edns_test.go37
-rw-r--r--middleware/rcode.go14
-rw-r--r--middleware/recorder.go20
-rw-r--r--server/server.go17
5 files changed, 99 insertions, 23 deletions
diff --git a/middleware/edns.go b/middleware/edns.go
new file mode 100644
index 000000000..aaab502e0
--- /dev/null
+++ b/middleware/edns.go
@@ -0,0 +1,34 @@
+package middleware
+
+import (
+ "errors"
+
+ "github.com/miekg/dns"
+)
+
+// Edns0Version checks the EDNS version in the request. If error
+// is nil everything is OK and we can invoke the middleware. If non-nil, the
+// returned Msg is valid to be returned to the client (and should). For some
+// reason this response should not contain a question RR in the question section.
+func Edns0Version(req *dns.Msg) (*dns.Msg, error) {
+ opt := req.IsEdns0()
+ if opt == nil {
+ return nil, nil
+ }
+ if opt.Version() == 0 {
+ return nil, nil
+ }
+ m := new(dns.Msg)
+ m.SetReply(req)
+ // zero out question section, wtf.
+ m.Question = nil
+
+ o := new(dns.OPT)
+ o.Hdr.Name = "."
+ o.Hdr.Rrtype = dns.TypeOPT
+ o.SetVersion(0)
+ o.SetExtendedRcode(dns.RcodeBadVers)
+ m.Extra = []dns.RR{o}
+
+ return m, errors.New("EDNS0 BADVERS")
+}
diff --git a/middleware/edns_test.go b/middleware/edns_test.go
new file mode 100644
index 000000000..7b4e6fc66
--- /dev/null
+++ b/middleware/edns_test.go
@@ -0,0 +1,37 @@
+package middleware
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestEdns0Version(t *testing.T) {
+ m := ednsMsg()
+ m.Extra[0].(*dns.OPT).SetVersion(2)
+
+ _, err := Edns0Version(m)
+ if err == nil {
+ t.Errorf("expected wrong version, but got OK")
+ }
+}
+
+func TestEdns0VersionNoEdns(t *testing.T) {
+ m := ednsMsg()
+ m.Extra = nil
+
+ _, err := Edns0Version(m)
+ if err != nil {
+ t.Errorf("expected no error, but got one: %s", err)
+ }
+}
+
+func ednsMsg() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion("example.com.", dns.TypeA)
+ o := new(dns.OPT)
+ o.Hdr.Name = "."
+ o.Hdr.Rrtype = dns.TypeOPT
+ m.Extra = append(m.Extra, o)
+ return m
+}
diff --git a/middleware/rcode.go b/middleware/rcode.go
new file mode 100644
index 000000000..989f90fdd
--- /dev/null
+++ b/middleware/rcode.go
@@ -0,0 +1,14 @@
+package middleware
+
+import (
+ "strconv"
+
+ "github.com/miekg/dns"
+)
+
+func RcodeToString(rcode int) string {
+ if str, ok := dns.RcodeToString[rcode]; ok {
+ return str
+ }
+ return "RCODE" + strconv.Itoa(rcode)
+}
diff --git a/middleware/recorder.go b/middleware/recorder.go
index feede34ae..d1e466ec3 100644
--- a/middleware/recorder.go
+++ b/middleware/recorder.go
@@ -1,7 +1,6 @@
package middleware
import (
- "strconv"
"time"
"github.com/miekg/dns"
@@ -54,27 +53,16 @@ func (r *ResponseRecorder) Write(buf []byte) (int, error) {
}
// Size returns the size.
-func (r *ResponseRecorder) Size() int {
- return r.size
-}
+func (r *ResponseRecorder) Size() int { return r.size }
// Rcode returns the rcode.
-func (r *ResponseRecorder) Rcode() string {
- if rcode, ok := dns.RcodeToString[r.rcode]; ok {
- return rcode
- }
- return "RCODE" + strconv.Itoa(r.rcode)
-}
+func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) }
// Start returns the start time of the ResponseRecorder.
-func (r *ResponseRecorder) Start() time.Time {
- return r.start
-}
+func (r *ResponseRecorder) Start() time.Time { return r.start }
// Msg returns the written message from the ResponseRecorder.
-func (r *ResponseRecorder) Msg() *dns.Msg {
- return r.msg
-}
+func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg }
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
diff --git a/server/server.go b/server/server.go
index 67cc35ba5..7ea931daa 100644
--- a/server/server.go
+++ b/server/server.go
@@ -12,10 +12,10 @@ import (
"net"
"os"
"runtime"
- "strconv"
"sync"
"time"
+ "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/chaos"
"github.com/miekg/coredns/middleware/prometheus"
@@ -279,6 +279,14 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}()
+ if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once.
+ qtype := dns.Type(r.Question[0].Qtype).String()
+ rc := middleware.RcodeToString(dns.RcodeBadVers)
+ metrics.Report(dropped, qtype, rc, m.Len(), time.Now())
+ w.WriteMsg(m)
+ return
+ }
+
// Execute the optional request callback if it exists
if s.ReqCallback != nil && s.ReqCallback(w, r) {
return
@@ -332,12 +340,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// of the specified HTTP status code.
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
qtype := dns.Type(r.Question[0].Qtype).String()
-
- // this code is duplicated a few times, TODO(miek)
- rc := dns.RcodeToString[rcode]
- if rc == "" {
- rc = "RCODE" + strconv.Itoa(rcode)
- }
+ rc := middleware.RcodeToString(rcode)
answer := new(dns.Msg)
answer.SetRcode(r, rcode)