aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/dnsserver/config.go6
-rw-r--r--core/dnsserver/server_https.go24
-rw-r--r--core/dnsserver/server_https_test.go66
3 files changed, 90 insertions, 6 deletions
diff --git a/core/dnsserver/config.go b/core/dnsserver/config.go
index 6a720ef13..4ff2ecda1 100644
--- a/core/dnsserver/config.go
+++ b/core/dnsserver/config.go
@@ -3,6 +3,7 @@ package dnsserver
import (
"crypto/tls"
"fmt"
+ "net/http"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
@@ -31,6 +32,11 @@ type Config struct {
// DNS-over-TLS or DNS-over-gRPC.
Transport string
+ // If this function is not nil it will be used to inspect and validate
+ // HTTP requests. Although this isn't referenced in-tree, external plugins
+ // may depend on it.
+ HTTPRequestValidateFunc func(*http.Request) bool
+
// If this function is not nil it will be used to further filter access
// to this handler. The primary use is to limit access to a reverse zone
// on a non-octet boundary, i.e. /17
diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go
index 27757861c..057dac49c 100644
--- a/core/dnsserver/server_https.go
+++ b/core/dnsserver/server_https.go
@@ -20,12 +20,13 @@ import (
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
type ServerHTTPS struct {
*Server
- httpsServer *http.Server
- listenAddr net.Addr
- tlsConfig *tls.Config
+ httpsServer *http.Server
+ listenAddr net.Addr
+ tlsConfig *tls.Config
+ validRequest func(*http.Request) bool
}
-// NewServerHTTPS returns a new CoreDNS GRPC server and compiles all plugins in to it.
+// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
s, err := NewServer(addr, group)
if err != nil {
@@ -45,12 +46,23 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
// or the upgrade won't happen.
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
+ // Use a custom request validation func or use the standard DoH path check.
+ var validator func(*http.Request) bool
+ for _, conf := range s.zones {
+ validator = conf.HTTPRequestValidateFunc
+ }
+ if validator == nil {
+ validator = func(r *http.Request) bool { return r.URL.Path == doh.Path }
+ }
+
srv := &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
}
- sh := &ServerHTTPS{Server: s, tlsConfig: tlsConfig, httpsServer: srv}
+ sh := &ServerHTTPS{
+ Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator,
+ }
sh.httpsServer.Handler = sh
return sh, nil
@@ -114,7 +126,7 @@ func (s *ServerHTTPS) Stop() error {
// chain, converts it back and write it to the client.
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != doh.Path {
+ if !s.validRequest(r) {
http.Error(w, "", http.StatusNotFound)
return
}
diff --git a/core/dnsserver/server_https_test.go b/core/dnsserver/server_https_test.go
new file mode 100644
index 000000000..6663c1075
--- /dev/null
+++ b/core/dnsserver/server_https_test.go
@@ -0,0 +1,66 @@
+package dnsserver
+
+import (
+ "bytes"
+ "crypto/tls"
+ "net/http"
+ "net/http/httptest"
+ "regexp"
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+var (
+ validPath = regexp.MustCompile("^/(dns-query|(?P<uuid>[0-9a-f]+))$")
+ validator = func(r *http.Request) bool { return validPath.MatchString(r.URL.Path) }
+)
+
+func testServerHTTPS(t *testing.T, path string, validator func(*http.Request) bool) *http.Response {
+ c := Config{
+ Zone: "example.com.",
+ Transport: "https",
+ TLSConfig: &tls.Config{},
+ ListenHosts: []string{"127.0.0.1"},
+ Port: "443",
+ HTTPRequestValidateFunc: validator,
+ }
+ s, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
+ if err != nil {
+ t.Log(err)
+ t.Fatal("could not create HTTPS server")
+ }
+ m := new(dns.Msg)
+ m.SetQuestion("example.org.", dns.TypeDNSKEY)
+ buf, err := m.Pack()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ r := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(buf))
+ w := httptest.NewRecorder()
+ s.ServeHTTP(w, r)
+
+ return w.Result()
+}
+
+func TestCustomHTTPRequestValidator(t *testing.T) {
+ testCases := map[string]struct {
+ path string
+ expected int
+ validator func(*http.Request) bool
+ }{
+ "default": {"/dns-query", http.StatusOK, nil},
+ "custom validator": {"/b10cada", http.StatusOK, validator},
+ "no validator set": {"/adb10c", http.StatusNotFound, nil},
+ "invalid path with validator": {"/helloworld", http.StatusNotFound, validator},
+ }
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ res := testServerHTTPS(t, tc.path, tc.validator)
+ if res.StatusCode != tc.expected {
+ t.Error("unexpected HTTP code", res.StatusCode)
+ }
+ })
+ }
+}