diff options
author | 2023-03-03 15:44:38 +0100 | |
---|---|---|
committer | 2023-03-03 09:44:38 -0500 | |
commit | 80b40c159e856230dae6eea2eed44ee5fc6bea86 (patch) | |
tree | b69e3e32ec8770c2b5f4afdfc945f8ddbe008736 /plugin | |
parent | 03fb2fa747a762d1be6a4c0cdafa146edbffa6eb (diff) | |
download | coredns-80b40c159e856230dae6eea2eed44ee5fc6bea86.tar.gz coredns-80b40c159e856230dae6eea2eed44ee5fc6bea86.tar.zst coredns-80b40c159e856230dae6eea2eed44ee5fc6bea86.zip |
DoH: Allow http as the protocol (#5762)
This change avoids the hard coding of HTTPS, allowing flexibility in whether HTTP or HTTPS is used.
Signed-off-by: Sebastian Dahlgren <sebdah@fb.com>
Diffstat (limited to 'plugin')
-rw-r--r-- | plugin/pkg/doh/doh.go | 23 | ||||
-rw-r--r-- | plugin/pkg/doh/doh_test.go | 76 |
2 files changed, 55 insertions, 44 deletions
diff --git a/plugin/pkg/doh/doh.go b/plugin/pkg/doh/doh.go index 9d5305b34..faddfc8aa 100644 --- a/plugin/pkg/doh/doh.go +++ b/plugin/pkg/doh/doh.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/miekg/dns" ) @@ -16,18 +17,30 @@ const MimeType = "application/dns-message" // Path is the URL path that should be used. const Path = "/dns-query" -// NewRequest returns a new DoH request given a method, URL (without any paths, so exclude /dns-query) and dns.Msg. +// NewRequest returns a new DoH request given a HTTP method, URL and dns.Msg. +// +// The URL should not have a path, so please exclude /dns-query. The URL will +// be prefixed with https:// by default, unless it's already prefixed with +// either http:// or https://. func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) { buf, err := m.Pack() if err != nil { return nil, err } + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = fmt.Sprintf("https://%s", url) + } + switch method { case http.MethodGet: b64 := base64.RawURLEncoding.EncodeToString(buf) - req, err := http.NewRequest(http.MethodGet, "https://"+url+Path+"?dns="+b64, nil) + req, err := http.NewRequest( + http.MethodGet, + fmt.Sprintf("%s%s?dns=%s", url, Path, b64), + nil, + ) if err != nil { return req, err } @@ -37,7 +50,11 @@ func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) { return req, nil case http.MethodPost: - req, err := http.NewRequest(http.MethodPost, "https://"+url+Path+"?bla=foo:443", bytes.NewReader(buf)) + req, err := http.NewRequest( + http.MethodPost, + fmt.Sprintf("%s%s?bla=foo:443", url, Path), + bytes.NewReader(buf), + ) if err != nil { return req, err } diff --git a/plugin/pkg/doh/doh_test.go b/plugin/pkg/doh/doh_test.go index 449166151..047d0136d 100644 --- a/plugin/pkg/doh/doh_test.go +++ b/plugin/pkg/doh/doh_test.go @@ -7,46 +7,40 @@ import ( "github.com/miekg/dns" ) -func TestPostRequest(t *testing.T) { - m := new(dns.Msg) - m.SetQuestion("example.org.", dns.TypeDNSKEY) - - req, err := NewRequest(http.MethodPost, "https://example.org:443", m) - if err != nil { - t.Errorf("Failure to make request: %s", err) - } - - m, err = RequestToMsg(req) - if err != nil { - t.Fatalf("Failure to get message from request: %s", err) - } - - if x := m.Question[0].Name; x != "example.org." { - t.Errorf("Qname expected %s, got %s", "example.org.", x) - } - if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { - t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) - } -} - -func TestGetRequest(t *testing.T) { - m := new(dns.Msg) - m.SetQuestion("example.org.", dns.TypeDNSKEY) - - req, err := NewRequest(http.MethodGet, "https://example.org:443", m) - if err != nil { - t.Errorf("Failure to make request: %s", err) - } - - m, err = RequestToMsg(req) - if err != nil { - t.Fatalf("Failure to get message from request: %s", err) - } - - if x := m.Question[0].Name; x != "example.org." { - t.Errorf("Qname expected %s, got %s", "example.org.", x) - } - if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { - t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) +func TestDoH(t *testing.T) { + tests := map[string]struct { + method string + url string + }{ + "POST request over HTTPS": {method: http.MethodPost, url: "https://example.org:443"}, + "POST request over HTTP": {method: http.MethodPost, url: "http://example.org:443"}, + "POST request without protocol": {method: http.MethodPost, url: "example.org:443"}, + "GET request over HTTPS": {method: http.MethodGet, url: "https://example.org:443"}, + "GET request over HTTP": {method: http.MethodGet, url: "http://example.org"}, + "GET request without protocol": {method: http.MethodGet, url: "example.org:443"}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeDNSKEY) + + req, err := NewRequest(test.method, test.url, m) + if err != nil { + t.Errorf("Failure to make request: %s", err) + } + + m, err = RequestToMsg(req) + if err != nil { + t.Fatalf("Failure to get message from request: %s", err) + } + + if x := m.Question[0].Name; x != "example.org." { + t.Errorf("Qname expected %s, got %s", "example.org.", x) + } + if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { + t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) + } + }) } } |