diff options
Diffstat (limited to 'core/dnsserver/server_https.go')
-rw-r--r-- | core/dnsserver/server_https.go | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index 27757861c..057dac49c 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -20,12 +20,13 @@ import ( // ServerHTTPS represents an instance of a DNS-over-HTTPS server. type ServerHTTPS struct { *Server - httpsServer *http.Server - listenAddr net.Addr - tlsConfig *tls.Config + httpsServer *http.Server + listenAddr net.Addr + tlsConfig *tls.Config + validRequest func(*http.Request) bool } -// NewServerHTTPS returns a new CoreDNS GRPC server and compiles all plugins in to it. +// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it. func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { s, err := NewServer(addr, group) if err != nil { @@ -45,12 +46,23 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { // or the upgrade won't happen. tlsConfig.NextProtos = []string{"h2", "http/1.1"} + // Use a custom request validation func or use the standard DoH path check. + var validator func(*http.Request) bool + for _, conf := range s.zones { + validator = conf.HTTPRequestValidateFunc + } + if validator == nil { + validator = func(r *http.Request) bool { return r.URL.Path == doh.Path } + } + srv := &http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, } - sh := &ServerHTTPS{Server: s, tlsConfig: tlsConfig, httpsServer: srv} + sh := &ServerHTTPS{ + Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator, + } sh.httpsServer.Handler = sh return sh, nil @@ -114,7 +126,7 @@ func (s *ServerHTTPS) Stop() error { // chain, converts it back and write it to the client. func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != doh.Path { + if !s.validRequest(r) { http.Error(w, "", http.StatusNotFound) return } |