diff options
-rw-r--r-- | core/dnsserver/https.go | 18 | ||||
-rw-r--r-- | core/dnsserver/server-https.go | 9 | ||||
-rw-r--r-- | plugin/pkg/nonwriter/nonwriter.go | 13 | ||||
-rw-r--r-- | test/file_upstream_test.go | 60 |
4 files changed, 81 insertions, 19 deletions
diff --git a/core/dnsserver/https.go b/core/dnsserver/https.go index 028b74709..915d366ca 100644 --- a/core/dnsserver/https.go +++ b/core/dnsserver/https.go @@ -4,8 +4,10 @@ import ( "encoding/base64" "fmt" "io/ioutil" + "net" "net/http" + "github.com/coredns/coredns/plugin/pkg/nonwriter" "github.com/miekg/dns" ) @@ -54,3 +56,19 @@ func base64ToMsg(b64 string) (*dns.Msg, error) { } var b64Enc = base64.RawURLEncoding + +// DoHWriter is a nonwriter.Writer that adds more specific LocalAddr and RemoteAddr methods. +type DoHWriter struct { + nonwriter.Writer + + // raddr is the remote's address. This can be optionally set. + raddr net.Addr + // laddr is our address. This can be optionally set. + laddr net.Addr +} + +// RemoteAddr returns the remote address. +func (d *DoHWriter) RemoteAddr() net.Addr { return d.raddr } + +// LocalAddr returns the local address. +func (d *DoHWriter) LocalAddr() net.Addr { return d.laddr } diff --git a/core/dnsserver/server-https.go b/core/dnsserver/server-https.go index f460f0ff4..c9f0da0cd 100644 --- a/core/dnsserver/server-https.go +++ b/core/dnsserver/server-https.go @@ -8,7 +8,6 @@ import ( "net/http" "strconv" - "github.com/coredns/coredns/plugin/pkg/nonwriter" "github.com/miekg/dns" ) @@ -119,12 +118,10 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Create a non-writer with the correct addresses in it. - dw := &nonwriter.Writer{Laddr: s.listenAddr} + // Create a DoHWriter with the correct addresses in it. h, p, _ := net.SplitHostPort(r.RemoteAddr) - po, _ := strconv.Atoi(p) - ip := net.ParseIP(h) - dw.Raddr = &net.TCPAddr{IP: ip, Port: po} + port, _ := strconv.Atoi(p) + dw := &DoHWriter{laddr: s.listenAddr, raddr: &net.TCPAddr{IP: net.ParseIP(h), Port: port}} // We just call the normal chain handler - all error handling is done there. // We should expect a packet to be returned that we can send to the client. diff --git a/plugin/pkg/nonwriter/nonwriter.go b/plugin/pkg/nonwriter/nonwriter.go index b157e4242..411e98a94 100644 --- a/plugin/pkg/nonwriter/nonwriter.go +++ b/plugin/pkg/nonwriter/nonwriter.go @@ -2,8 +2,6 @@ package nonwriter import ( - "net" - "github.com/miekg/dns" ) @@ -11,11 +9,6 @@ import ( type Writer struct { dns.ResponseWriter Msg *dns.Msg - - // Raddr is the remote's address. This can be optionally set. - Raddr net.Addr - // Laddr is our address. This can be optionally set. - Laddr net.Addr } // New makes and returns a new NonWriter. @@ -26,9 +19,3 @@ func (w *Writer) WriteMsg(res *dns.Msg) error { w.Msg = res return nil } - -// RemoteAddr returns the remote address. -func (w *Writer) RemoteAddr() net.Addr { return w.Raddr } - -// LocalAddr returns the local address. -func (w *Writer) LocalAddr() net.Addr { return w.Laddr } diff --git a/test/file_upstream_test.go b/test/file_upstream_test.go new file mode 100644 index 000000000..5a24e12c4 --- /dev/null +++ b/test/file_upstream_test.go @@ -0,0 +1,60 @@ +package test + +import ( + "testing" + + "github.com/miekg/dns" +) + +// TODO(miek): this test needs to be fleshed out. + +func TestFileUpstream(t *testing.T) { + name, rm, err := TempFile(".", `$ORIGIN example.org. +@ 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. ( + 2017042745 ; serial + 7200 ; refresh (2 hours) + 3600 ; retry (1 hour) + 1209600 ; expire (2 weeks) + 3600 ; minimum (1 hour) + ) + + 3600 IN NS a.iana-servers.net. + 3600 IN NS b.iana-servers.net. + +www 3600 IN CNAME www.example.net. +`) + if err != nil { + t.Fatalf("Failed to create zone: %s", err) + } + defer rm() + + // Corefile with for example without proxy section. + corefile := `example.org:0 { + file ` + name + ` { + upstream + } + hosts { + 10.0.0.1 www.example.net. + fallthrough + } +} +` + i, udp, _, err := CoreDNSServerAndPorts(corefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer i.Stop() + + m := new(dns.Msg) + m.SetQuestion("www.example.org.", dns.TypeA) + m.SetEdns0(4096, true) + + r, err := dns.Exchange(m, udp) + if err != nil { + t.Fatalf("Could not exchange msg: %s", err) + } + if r.Rcode == dns.RcodeServerFailure { + t.Fatalf("Rcode should not be dns.RcodeServerFailure") + } + t.Logf("%s", r) +} |