aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/dnsserver/https.go51
-rw-r--r--core/dnsserver/https_test.go66
-rw-r--r--core/dnsserver/server_https.go21
-rw-r--r--plugin/pkg/doh/doh.go119
-rw-r--r--plugin/pkg/doh/doh_test.go52
5 files changed, 175 insertions, 134 deletions
diff --git a/core/dnsserver/https.go b/core/dnsserver/https.go
index 915d366ca..532124575 100644
--- a/core/dnsserver/https.go
+++ b/core/dnsserver/https.go
@@ -1,62 +1,11 @@
package dnsserver
import (
- "encoding/base64"
- "fmt"
- "io/ioutil"
"net"
- "net/http"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
- "github.com/miekg/dns"
)
-// mimeTypeDOH is the DoH mimetype that should be used.
-const mimeTypeDOH = "application/dns-message"
-
-// pathDOH is the URL path that should be used.
-const pathDOH = "/dns-query"
-
-// postRequestToMsg extracts the dns message from the request body.
-func postRequestToMsg(req *http.Request) (*dns.Msg, error) {
- defer req.Body.Close()
-
- buf, err := ioutil.ReadAll(req.Body)
- if err != nil {
- return nil, err
- }
- m := new(dns.Msg)
- err = m.Unpack(buf)
- return m, err
-}
-
-// getRequestToMsg extract the dns message from the GET request.
-func getRequestToMsg(req *http.Request) (*dns.Msg, error) {
- values := req.URL.Query()
- b64, ok := values["dns"]
- if !ok {
- return nil, fmt.Errorf("no 'dns' query parameter found")
- }
- if len(b64) != 1 {
- return nil, fmt.Errorf("multiple 'dns' query values found")
- }
- return base64ToMsg(b64[0])
-}
-
-func base64ToMsg(b64 string) (*dns.Msg, error) {
- buf, err := b64Enc.DecodeString(b64)
- if err != nil {
- return nil, err
- }
-
- m := new(dns.Msg)
- err = m.Unpack(buf)
-
- return m, err
-}
-
-var b64Enc = base64.RawURLEncoding
-
// DoHWriter is a nonwriter.Writer that adds more specific LocalAddr and RemoteAddr methods.
type DoHWriter struct {
nonwriter.Writer
diff --git a/core/dnsserver/https_test.go b/core/dnsserver/https_test.go
deleted file mode 100644
index a0ddc4b25..000000000
--- a/core/dnsserver/https_test.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package dnsserver
-
-import (
- "bytes"
- "encoding/base64"
- "net/http"
- "testing"
-
- "github.com/miekg/dns"
-)
-
-func TestPostRequest(t *testing.T) {
- const ex = "example.org."
-
- m := new(dns.Msg)
- m.SetQuestion(ex, dns.TypeDNSKEY)
-
- out, _ := m.Pack()
- req, err := http.NewRequest(http.MethodPost, "https://"+ex+pathDOH+"?bla=foo:443", bytes.NewReader(out))
- if err != nil {
- t.Errorf("Failure to make request: %s", err)
- }
- req.Header.Set("content-type", mimeTypeDOH)
- req.Header.Set("accept", mimeTypeDOH)
-
- m, err = postRequestToMsg(req)
- if err != nil {
- t.Fatalf("Failure to get message from request: %s", err)
- }
-
- if x := m.Question[0].Name; x != ex {
- t.Errorf("Qname expected %s, got %s", ex, x)
- }
- if x := m.Question[0].Qtype; x != dns.TypeDNSKEY {
- t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY)
- }
-}
-
-func TestGetRequest(t *testing.T) {
- const ex = "example.org."
-
- m := new(dns.Msg)
- m.SetQuestion(ex, dns.TypeDNSKEY)
-
- out, _ := m.Pack()
- b64 := base64.RawURLEncoding.EncodeToString(out)
-
- req, err := http.NewRequest(http.MethodGet, "https://"+ex+pathDOH+"?dns="+b64, nil)
- if err != nil {
- t.Errorf("Failure to make request: %s", err)
- }
- req.Header.Set("content-type", mimeTypeDOH)
- req.Header.Set("accept", mimeTypeDOH)
-
- m, err = getRequestToMsg(req)
- if err != nil {
- t.Fatalf("Failure to get message from request: %s", err)
- }
-
- if x := m.Question[0].Name; x != ex {
- t.Errorf("Qname expected %s, got %s", ex, x)
- }
- if x := m.Question[0].Qtype; x != dns.TypeDNSKEY {
- t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY)
- }
-}
diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go
index cf5d08a45..9b1eaaa7e 100644
--- a/core/dnsserver/server_https.go
+++ b/core/dnsserver/server_https.go
@@ -10,9 +10,8 @@ import (
"time"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
+ "github.com/coredns/coredns/plugin/pkg/doh"
"github.com/coredns/coredns/plugin/pkg/response"
-
- "github.com/miekg/dns"
)
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
@@ -99,24 +98,12 @@ func (s *ServerHTTPS) Stop() error {
// chain, converts it back and write it to the client.
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- msg := new(dns.Msg)
- var err error
-
- if r.URL.Path != pathDOH {
+ if r.URL.Path != doh.Path {
http.Error(w, "", http.StatusNotFound)
return
}
- switch r.Method {
- case http.MethodPost:
- msg, err = postRequestToMsg(r)
- case http.MethodGet:
- msg, err = getRequestToMsg(r)
- default:
- http.Error(w, "", http.StatusMethodNotAllowed)
- return
- }
-
+ msg, err := doh.RequestToMsg(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -136,7 +123,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mt, _ := response.Typify(dw.Msg, time.Now().UTC())
age := dnsutil.MinimalTTL(dw.Msg, mt)
- w.Header().Set("Content-Type", mimeTypeDOH)
+ w.Header().Set("Content-Type", doh.MimeType)
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%f", age.Seconds()))
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
w.WriteHeader(http.StatusOK)
diff --git a/plugin/pkg/doh/doh.go b/plugin/pkg/doh/doh.go
new file mode 100644
index 000000000..e0a398e9c
--- /dev/null
+++ b/plugin/pkg/doh/doh.go
@@ -0,0 +1,119 @@
+package doh
+
+import (
+ "bytes"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+
+ "github.com/miekg/dns"
+)
+
+// MimeType is the DoH mimetype that should be used.
+const MimeType = "application/dns-message"
+
+// Path is the URL path that should be used.
+const Path = "/dns-query"
+
+// NewRequest returns a new DoH request given a method, URL (without any paths, so exclude /dns-query) and dns.Msg.
+func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) {
+ buf, err := m.Pack()
+ if err != nil {
+ return nil, err
+ }
+
+ switch method {
+ case http.MethodGet:
+ b64 := base64.RawURLEncoding.EncodeToString(buf)
+
+ req, err := http.NewRequest(http.MethodGet, "https://"+url+Path+"?dns="+b64, nil)
+ if err != nil {
+ return req, err
+ }
+
+ req.Header.Set("content-type", MimeType)
+ req.Header.Set("accept", MimeType)
+ return req, nil
+
+ case http.MethodPost:
+ req, err := http.NewRequest(http.MethodPost, "https://"+url+Path+"?bla=foo:443", bytes.NewReader(buf))
+ if err != nil {
+ return req, err
+ }
+
+ req.Header.Set("content-type", MimeType)
+ req.Header.Set("accept", MimeType)
+ return req, nil
+
+ default:
+ return nil, fmt.Errorf("method not allowed: %s", method)
+ }
+
+}
+
+// ResponseToMsg converts a http.Repsonse to a dns message.
+func ResponseToMsg(resp *http.Response) (*dns.Msg, error) {
+ defer resp.Body.Close()
+
+ return toMsg(resp.Body)
+}
+
+// RequestToMsg converts a http.Request to a dns message.
+func RequestToMsg(req *http.Request) (*dns.Msg, error) {
+ switch req.Method {
+ case http.MethodGet:
+ return requestToMsgGet(req)
+
+ case http.MethodPost:
+ return requestToMsgPost(req)
+
+ default:
+ return nil, fmt.Errorf("method not allowed: %s", req.Method)
+ }
+
+}
+
+// requestToMsgPost extracts the dns message from the request body.
+func requestToMsgPost(req *http.Request) (*dns.Msg, error) {
+ defer req.Body.Close()
+ return toMsg(req.Body)
+}
+
+// requestToMsgGet extract the dns message from the GET request.
+func requestToMsgGet(req *http.Request) (*dns.Msg, error) {
+ values := req.URL.Query()
+ b64, ok := values["dns"]
+ if !ok {
+ return nil, fmt.Errorf("no 'dns' query parameter found")
+ }
+ if len(b64) != 1 {
+ return nil, fmt.Errorf("multiple 'dns' query values found")
+ }
+ return base64ToMsg(b64[0])
+}
+
+func toMsg(r io.ReadCloser) (*dns.Msg, error) {
+ buf, err := ioutil.ReadAll(r)
+ if err != nil {
+ return nil, err
+ }
+ m := new(dns.Msg)
+ err = m.Unpack(buf)
+ return m, err
+}
+
+func base64ToMsg(b64 string) (*dns.Msg, error) {
+ buf, err := b64Enc.DecodeString(b64)
+ if err != nil {
+ return nil, err
+ }
+
+ m := new(dns.Msg)
+ err = m.Unpack(buf)
+
+ return m, err
+}
+
+var b64Enc = base64.RawURLEncoding
diff --git a/plugin/pkg/doh/doh_test.go b/plugin/pkg/doh/doh_test.go
new file mode 100644
index 000000000..449166151
--- /dev/null
+++ b/plugin/pkg/doh/doh_test.go
@@ -0,0 +1,52 @@
+package doh
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestPostRequest(t *testing.T) {
+ m := new(dns.Msg)
+ m.SetQuestion("example.org.", dns.TypeDNSKEY)
+
+ req, err := NewRequest(http.MethodPost, "https://example.org:443", m)
+ if err != nil {
+ t.Errorf("Failure to make request: %s", err)
+ }
+
+ m, err = RequestToMsg(req)
+ if err != nil {
+ t.Fatalf("Failure to get message from request: %s", err)
+ }
+
+ if x := m.Question[0].Name; x != "example.org." {
+ t.Errorf("Qname expected %s, got %s", "example.org.", x)
+ }
+ if x := m.Question[0].Qtype; x != dns.TypeDNSKEY {
+ t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY)
+ }
+}
+
+func TestGetRequest(t *testing.T) {
+ m := new(dns.Msg)
+ m.SetQuestion("example.org.", dns.TypeDNSKEY)
+
+ req, err := NewRequest(http.MethodGet, "https://example.org:443", m)
+ if err != nil {
+ t.Errorf("Failure to make request: %s", err)
+ }
+
+ m, err = RequestToMsg(req)
+ if err != nil {
+ t.Fatalf("Failure to get message from request: %s", err)
+ }
+
+ if x := m.Question[0].Name; x != "example.org." {
+ t.Errorf("Qname expected %s, got %s", "example.org.", x)
+ }
+ if x := m.Question[0].Qtype; x != dns.TypeDNSKEY {
+ t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY)
+ }
+}