aboutsummaryrefslogtreecommitdiff
path: root/plugin/proxy/dns.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/proxy/dns.go')
-rw-r--r--plugin/proxy/dns.go106
1 files changed, 106 insertions, 0 deletions
diff --git a/plugin/proxy/dns.go b/plugin/proxy/dns.go
new file mode 100644
index 000000000..4d8038422
--- /dev/null
+++ b/plugin/proxy/dns.go
@@ -0,0 +1,106 @@
+package proxy
+
+import (
+ "context"
+ "net"
+ "time"
+
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+type dnsEx struct {
+ Timeout time.Duration
+ Options
+}
+
+// Options define the options understood by dns.Exchange.
+type Options struct {
+ ForceTCP bool // If true use TCP for upstream no matter what
+}
+
+func newDNSEx() *dnsEx {
+ return newDNSExWithOption(Options{})
+}
+
+func newDNSExWithOption(opt Options) *dnsEx {
+ return &dnsEx{Timeout: defaultTimeout * time.Second, Options: opt}
+}
+
+func (d *dnsEx) Transport() string {
+ if d.Options.ForceTCP {
+ return "tcp"
+ }
+
+ // The protocol will be determined by `state.Proto()` during Exchange.
+ return ""
+}
+func (d *dnsEx) Protocol() string { return "dns" }
+func (d *dnsEx) OnShutdown(p *Proxy) error { return nil }
+func (d *dnsEx) OnStartup(p *Proxy) error { return nil }
+
+// Exchange implements the Exchanger interface.
+func (d *dnsEx) Exchange(ctx context.Context, addr string, state request.Request) (*dns.Msg, error) {
+ proto := state.Proto()
+ if d.Options.ForceTCP {
+ proto = "tcp"
+ }
+ co, err := net.DialTimeout(proto, addr, d.Timeout)
+ if err != nil {
+ return nil, err
+ }
+
+ reply, _, err := d.ExchangeConn(state.Req, co)
+
+ co.Close()
+
+ if reply != nil && reply.Truncated {
+ // Suppress proxy error for truncated responses
+ err = nil
+ }
+
+ if err != nil {
+ return nil, err
+ }
+ // Make sure it fits in the DNS response.
+ reply, _ = state.Scrub(reply)
+ reply.Compress = true
+ reply.Id = state.Req.Id
+
+ return reply, nil
+}
+
+func (d *dnsEx) ExchangeConn(m *dns.Msg, co net.Conn) (*dns.Msg, time.Duration, error) {
+ start := time.Now()
+ r, err := exchange(m, co)
+ rtt := time.Since(start)
+
+ return r, rtt, err
+}
+
+func exchange(m *dns.Msg, co net.Conn) (*dns.Msg, error) {
+ opt := m.IsEdns0()
+
+ udpsize := uint16(dns.MinMsgSize)
+ // If EDNS0 is used use that for size.
+ if opt != nil && opt.UDPSize() >= dns.MinMsgSize {
+ udpsize = opt.UDPSize()
+ }
+
+ dnsco := &dns.Conn{Conn: co, UDPSize: udpsize}
+
+ writeDeadline := time.Now().Add(defaultTimeout)
+ dnsco.SetWriteDeadline(writeDeadline)
+ dnsco.WriteMsg(m)
+
+ readDeadline := time.Now().Add(defaultTimeout)
+ co.SetReadDeadline(readDeadline)
+ r, err := dnsco.ReadMsg()
+
+ dnsco.Close()
+ if r == nil {
+ return nil, err
+ }
+ return r, err
+}