diff options
Diffstat (limited to 'plugin/proxy/dns.go')
-rw-r--r-- | plugin/proxy/dns.go | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/plugin/proxy/dns.go b/plugin/proxy/dns.go new file mode 100644 index 000000000..4d8038422 --- /dev/null +++ b/plugin/proxy/dns.go @@ -0,0 +1,106 @@ +package proxy + +import ( + "context" + "net" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type dnsEx struct { + Timeout time.Duration + Options +} + +// Options define the options understood by dns.Exchange. +type Options struct { + ForceTCP bool // If true use TCP for upstream no matter what +} + +func newDNSEx() *dnsEx { + return newDNSExWithOption(Options{}) +} + +func newDNSExWithOption(opt Options) *dnsEx { + return &dnsEx{Timeout: defaultTimeout * time.Second, Options: opt} +} + +func (d *dnsEx) Transport() string { + if d.Options.ForceTCP { + return "tcp" + } + + // The protocol will be determined by `state.Proto()` during Exchange. + return "" +} +func (d *dnsEx) Protocol() string { return "dns" } +func (d *dnsEx) OnShutdown(p *Proxy) error { return nil } +func (d *dnsEx) OnStartup(p *Proxy) error { return nil } + +// Exchange implements the Exchanger interface. +func (d *dnsEx) Exchange(ctx context.Context, addr string, state request.Request) (*dns.Msg, error) { + proto := state.Proto() + if d.Options.ForceTCP { + proto = "tcp" + } + co, err := net.DialTimeout(proto, addr, d.Timeout) + if err != nil { + return nil, err + } + + reply, _, err := d.ExchangeConn(state.Req, co) + + co.Close() + + if reply != nil && reply.Truncated { + // Suppress proxy error for truncated responses + err = nil + } + + if err != nil { + return nil, err + } + // Make sure it fits in the DNS response. + reply, _ = state.Scrub(reply) + reply.Compress = true + reply.Id = state.Req.Id + + return reply, nil +} + +func (d *dnsEx) ExchangeConn(m *dns.Msg, co net.Conn) (*dns.Msg, time.Duration, error) { + start := time.Now() + r, err := exchange(m, co) + rtt := time.Since(start) + + return r, rtt, err +} + +func exchange(m *dns.Msg, co net.Conn) (*dns.Msg, error) { + opt := m.IsEdns0() + + udpsize := uint16(dns.MinMsgSize) + // If EDNS0 is used use that for size. + if opt != nil && opt.UDPSize() >= dns.MinMsgSize { + udpsize = opt.UDPSize() + } + + dnsco := &dns.Conn{Conn: co, UDPSize: udpsize} + + writeDeadline := time.Now().Add(defaultTimeout) + dnsco.SetWriteDeadline(writeDeadline) + dnsco.WriteMsg(m) + + readDeadline := time.Now().Add(defaultTimeout) + co.SetReadDeadline(readDeadline) + r, err := dnsco.ReadMsg() + + dnsco.Close() + if r == nil { + return nil, err + } + return r, err +} |