author    Pat Downey <patdowney@users.noreply.github.com> 2023-03-24 12:55:51 +0000
committer GitHub <noreply@github.com> 2023-03-24 08:55:51 -0400
commit    f823825f8a34edb85d5d18cd5d2f6f850adf408e (patch)
tree      79d241ab9b4c7c343d806f4041c8efccbe3f9ca0 /plugin/pkg/proxy
parent    47dceabfc6465ba6c5d41472d6602d4ad5c9fb1b (diff)
plugin/forward: Allow Proxy to be used outside of forward plugin. (#5951)
* plugin/forward: Move Proxy into plugin/pkg/proxy, to allow forward.Proxy to be used outside of the forward plugin.

Signed-off-by: Patrick Downey <patrick.downey@dioadconsulting.com>
Diffstat (limited to 'plugin/pkg/proxy')
-rw-r--r--  plugin/pkg/proxy/connect.go         | 152
-rw-r--r--  plugin/pkg/proxy/errors.go          |  26
-rw-r--r--  plugin/pkg/proxy/health.go          | 131
-rw-r--r--  plugin/pkg/proxy/health_test.go     | 153
-rw-r--r--  plugin/pkg/proxy/metrics.go         |  49
-rw-r--r--  plugin/pkg/proxy/persistent.go      | 156
-rw-r--r--  plugin/pkg/proxy/persistent_test.go | 109
-rw-r--r--  plugin/pkg/proxy/proxy.go           |  98
-rw-r--r--  plugin/pkg/proxy/proxy_test.go      |  99
-rw-r--r--  plugin/pkg/proxy/type.go            |  39
10 files changed, 1012 insertions, 0 deletions
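
Since this change exports the proxy API from plugin/pkg/proxy, other plugins can create and drive a Proxy directly. A minimal sketch (not part of this commit; the upstream address and health-check interval are illustrative), mirroring the usage in proxy_test.go below:

    package example

    import (
        "context"
        "time"

        "github.com/coredns/coredns/plugin/pkg/proxy"
        "github.com/coredns/coredns/plugin/pkg/transport"
        "github.com/coredns/coredns/request"

        "github.com/miekg/dns"
    )

    // queryUpstream forwards one request through a standalone Proxy.
    func queryUpstream(ctx context.Context, w dns.ResponseWriter, m *dns.Msg) (*dns.Msg, error) {
        p := proxy.NewProxy("8.8.8.8:53", transport.DNS) // plain-DNS upstream (illustrative)
        p.Start(5 * time.Second)                         // start health checking every 5s
        defer p.Stop()

        state := request.Request{W: w, Req: m}
        return p.Connect(ctx, state, proxy.Options{})
    }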
diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go
new file mode 100644
index 000000000..29274d92d
--- /dev/null
+++ b/plugin/pkg/proxy/connect.go
@@ -0,0 +1,152 @@
+// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
+// client returns, the upstream's Conn will be precached. Depending on how you benchmark, this looks to be
+// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
+// in-band health checking.
+package proxy
+
+import (
+ "context"
+ "io"
+ "strconv"
+ "sync/atomic"
+ "time"
+
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+// limitTimeout is a utility function to auto-tune timeout values. The
+// average observed time is moved towards the last observed delay, moderated by a weight.
+// The next timeout to use is double the computed average, clamped between minValue and maxValue.
+func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
+ rt := time.Duration(atomic.LoadInt64(currentAvg))
+ if rt < minValue {
+ return minValue
+ }
+ if rt < maxValue/2 {
+ return 2 * rt
+ }
+ return maxValue
+}
+
+func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
+ dt := time.Duration(atomic.LoadInt64(currentAvg))
+ atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
+}
+
+func (t *Transport) dialTimeout() time.Duration {
+ return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
+}
+
+func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
+ averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
+}
+
+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
+func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
+ // If TLS has been configured, use it.
+ if t.tlsConfig != nil {
+ proto = "tcp-tls"
+ }
+
+ t.dial <- proto
+ pc := <-t.ret
+
+ if pc != nil {
+ ConnCacheHitsCount.WithLabelValues(t.addr, proto).Add(1)
+ return pc, true, nil
+ }
+ ConnCacheMissesCount.WithLabelValues(t.addr, proto).Add(1)
+
+ reqTime := time.Now()
+ timeout := t.dialTimeout()
+ if proto == "tcp-tls" {
+ conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+ }
+ conn, err := dns.DialTimeout(proto, t.addr, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+}
+
+// Connect sends the request to the upstream and waits for a response.
+func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
+ start := time.Now()
+
+ proto := ""
+ switch {
+ case opts.ForceTCP: // TCP flag has precedence over UDP flag
+ proto = "tcp"
+ case opts.PreferUDP:
+ proto = "udp"
+ default:
+ proto = state.Proto()
+ }
+
+ pc, cached, err := p.transport.Dial(proto)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set buffer size correctly for this client.
+ pc.c.UDPSize = uint16(state.Size())
+ if pc.c.UDPSize < 512 {
+ pc.c.UDPSize = 512
+ }
+
+ pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
+ // Record the original Id before forwarding upstream.
+ originId := state.Req.Id
+ state.Req.Id = dns.Id()
+ defer func() {
+ state.Req.Id = originId
+ }()
+
+ if err := pc.c.WriteMsg(state.Req); err != nil {
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ return nil, err
+ }
+
+ var ret *dns.Msg
+ pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
+ for {
+ ret, err = pc.c.ReadMsg()
+ if err != nil {
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ // Restore the original Id after the upstream exchange.
+ if ret != nil {
+ ret.Id = originId
+ }
+ return ret, err
+ }
+ // drop out-of-order responses
+ if state.Req.Id == ret.Id {
+ break
+ }
+ }
+ // Restore the original Id after the upstream exchange.
+ ret.Id = originId
+
+ p.transport.Yield(pc)
+
+ rc, ok := dns.RcodeToString[ret.Rcode]
+ if !ok {
+ rc = strconv.Itoa(ret.Rcode)
+ }
+
+ RequestCount.WithLabelValues(p.addr).Add(1)
+ RcodeCount.WithLabelValues(rc, p.addr).Add(1)
+ RequestDuration.WithLabelValues(p.addr, rc).Observe(time.Since(start).Seconds())
+
+ return ret, nil
+}
+
+const cumulativeAvgWeight = 4
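
The dial timeout above is auto-tuned: averageTimeout keeps a cumulative moving average of observed dial times, and limitTimeout doubles it, clamped between minDialTimeout and maxDialTimeout. Those helpers are unexported; the standalone sketch below (illustrative values, atomics omitted) just shows the arithmetic:

    package main

    import (
        "fmt"
        "time"
    )

    const weight = 4 // cumulativeAvgWeight above

    func main() {
        // Seeded at maxDialTimeout/2, as newTransport does in persistent.go.
        avg := int64(15 * time.Second)
        for _, observed := range []time.Duration{100 * time.Millisecond, 90 * time.Millisecond, 110 * time.Millisecond} {
            // averageTimeout: move 1/weight of the way toward each observation.
            avg += (int64(observed) - avg) / weight
        }
        // limitTimeout: the next timeout is 2*avg, clamped to [minDialTimeout, maxDialTimeout].
        next := 2 * time.Duration(avg)
        fmt.Println(next)
    }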
diff --git a/plugin/pkg/proxy/errors.go b/plugin/pkg/proxy/errors.go
new file mode 100644
index 000000000..461236423
--- /dev/null
+++ b/plugin/pkg/proxy/errors.go
@@ -0,0 +1,26 @@
+package proxy
+
+import (
+ "errors"
+)
+
+var (
+ // ErrNoHealthy means no healthy proxies left.
+ ErrNoHealthy = errors.New("no healthy proxies")
+ // ErrNoForward means no forwarder defined.
+ ErrNoForward = errors.New("no forwarder defined")
+ // ErrCachedClosed means cached connection was closed by peer.
+ ErrCachedClosed = errors.New("cached connection was closed by peer")
+)
+
+// Options holds various Options that can be set.
+type Options struct {
+ // ForceTCP use TCP protocol for upstream DNS request. Has precedence over PreferUDP flag
+ ForceTCP bool
+ // PreferUDP use UDP protocol for upstream DNS request.
+ PreferUDP bool
+ // HCRecursionDesired sets recursion desired flag for Proxy healthcheck requests
+ HCRecursionDesired bool
+ // HCDomain sets domain for Proxy healthcheck requests
+ HCDomain string
+}
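
ErrCachedClosed is returned by Connect when a pooled connection turns out to have been closed by the peer; since the request never reached a live upstream, callers (such as the forward plugin) typically just retry. A hedged sketch (the retry bound is illustrative, not from this commit):

    package example

    import (
        "context"
        "errors"

        "github.com/coredns/coredns/plugin/pkg/proxy"
        "github.com/coredns/coredns/request"

        "github.com/miekg/dns"
    )

    func exchangeWithRetry(ctx context.Context, p *proxy.Proxy, state request.Request) (*dns.Msg, error) {
        var ret *dns.Msg
        var err error
        for attempt := 0; attempt < 3; attempt++ {
            ret, err = p.Connect(ctx, state, proxy.Options{})
            if !errors.Is(err, proxy.ErrCachedClosed) {
                break // a real answer or a real error; only stale-pool errors are retried
            }
        }
        return ret, err
    }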
diff --git a/plugin/pkg/proxy/health.go b/plugin/pkg/proxy/health.go
new file mode 100644
index 000000000..e87104a13
--- /dev/null
+++ b/plugin/pkg/proxy/health.go
@@ -0,0 +1,131 @@
+package proxy
+
+import (
+ "crypto/tls"
+ "sync/atomic"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/log"
+ "github.com/coredns/coredns/plugin/pkg/transport"
+
+ "github.com/miekg/dns"
+)
+
+// HealthChecker checks the upstream health.
+type HealthChecker interface {
+ Check(*Proxy) error
+ SetTLSConfig(*tls.Config)
+ GetTLSConfig() *tls.Config
+ SetRecursionDesired(bool)
+ GetRecursionDesired() bool
+ SetDomain(domain string)
+ GetDomain() string
+ SetTCPTransport()
+ GetReadTimeout() time.Duration
+ SetReadTimeout(time.Duration)
+ GetWriteTimeout() time.Duration
+ SetWriteTimeout(time.Duration)
+}
+
+// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
+type dnsHc struct {
+ c *dns.Client
+ recursionDesired bool
+ domain string
+}
+
+// NewHealthChecker returns a new HealthChecker based on transport.
+func NewHealthChecker(trans string, recursionDesired bool, domain string) HealthChecker {
+ switch trans {
+ case transport.DNS, transport.TLS:
+ c := new(dns.Client)
+ c.Net = "udp"
+ c.ReadTimeout = 1 * time.Second
+ c.WriteTimeout = 1 * time.Second
+
+ return &dnsHc{
+ c: c,
+ recursionDesired: recursionDesired,
+ domain: domain,
+ }
+ }
+
+ log.Warningf("No healthchecker for transport %q", trans)
+ return nil
+}
+
+func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
+ h.c.Net = "tcp-tls"
+ h.c.TLSConfig = cfg
+}
+
+func (h *dnsHc) GetTLSConfig() *tls.Config {
+ return h.c.TLSConfig
+}
+
+func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
+ h.recursionDesired = recursionDesired
+}
+func (h *dnsHc) GetRecursionDesired() bool {
+ return h.recursionDesired
+}
+
+func (h *dnsHc) SetDomain(domain string) {
+ h.domain = domain
+}
+func (h *dnsHc) GetDomain() string {
+ return h.domain
+}
+
+func (h *dnsHc) SetTCPTransport() {
+ h.c.Net = "tcp"
+}
+
+func (h *dnsHc) GetReadTimeout() time.Duration {
+ return h.c.ReadTimeout
+}
+
+func (h *dnsHc) SetReadTimeout(t time.Duration) {
+ h.c.ReadTimeout = t
+}
+
+func (h *dnsHc) GetWriteTimeout() time.Duration {
+ return h.c.WriteTimeout
+}
+
+func (h *dnsHc) SetWriteTimeout(t time.Duration) {
+ h.c.WriteTimeout = t
+}
+
+// For HC we send an NS query for the configured domain, with the configured recursion flag, to the
+// upstream. Dial timeouts and empty replies are considered failures; anything else is a healthy upstream.
+
+// Check is used as the up.Func in the up.Probe.
+func (h *dnsHc) Check(p *Proxy) error {
+ err := h.send(p.addr)
+ if err != nil {
+ HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
+ atomic.AddUint32(&p.fails, 1)
+ return err
+ }
+
+ atomic.StoreUint32(&p.fails, 0)
+ return nil
+}
+
+func (h *dnsHc) send(addr string) error {
+ ping := new(dns.Msg)
+ ping.SetQuestion(h.domain, dns.TypeNS)
+ ping.MsgHdr.RecursionDesired = h.recursionDesired
+
+ m, _, err := h.c.Exchange(ping, addr)
+ // If we got a header back, we're alright; we basically only care about I/O errors.
+ if err != nil && m != nil {
+ // Silly check, something sane came back.
+ if m.Response || m.Opcode == dns.OpcodeQuery {
+ err = nil
+ }
+ }
+
+ return err
+}
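
Using the checker outside the forward plugin is straightforward; a sketch (address and timeouts illustrative), along the lines of the tests below:

    package example

    import (
        "time"

        "github.com/coredns/coredns/plugin/pkg/log"
        "github.com/coredns/coredns/plugin/pkg/proxy"
        "github.com/coredns/coredns/plugin/pkg/transport"
    )

    func checkOnce(p *proxy.Proxy) {
        hc := proxy.NewHealthChecker(transport.DNS, true, ".") // recursion desired, probe "."
        hc.SetReadTimeout(500 * time.Millisecond)
        hc.SetWriteTimeout(500 * time.Millisecond)
        if err := hc.Check(p); err != nil {
            // Dial timeout or empty reply; Check has also bumped the proxy's fail counter.
            log.Warningf("upstream %s unhealthy: %v", p.Addr(), err)
        }
    }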
diff --git a/plugin/pkg/proxy/health_test.go b/plugin/pkg/proxy/health_test.go
new file mode 100644
index 000000000..c1b5270ad
--- /dev/null
+++ b/plugin/pkg/proxy/health_test.go
@@ -0,0 +1,153 @@
+package proxy
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/pkg/transport"
+
+ "github.com/miekg/dns"
+)
+
+func TestHealth(t *testing.T) {
+ i := uint32(0)
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired == true {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(transport.DNS, true, "")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr, transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthTCP(t *testing.T) {
+ i := uint32(0)
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired == true {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(transport.DNS, true, "")
+ hc.SetTCPTransport()
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr, transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthNoRecursion(t *testing.T) {
+ i := uint32(0)
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == "." && r.RecursionDesired == false {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(transport.DNS, false, "")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr, transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(20 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1)
+ }
+}
+
+func TestHealthTimeout(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ // timeout
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(transport.DNS, false, "")
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr, transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err == nil {
+ t.Errorf("expected error")
+ }
+}
+
+func TestHealthDomain(t *testing.T) {
+ hcDomain := "example.org."
+
+ i := uint32(0)
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ if r.Question[0].Name == hcDomain && r.RecursionDesired == true {
+ atomic.AddUint32(&i, 1)
+ }
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ hc := NewHealthChecker(transport.DNS, true, hcDomain)
+ hc.SetReadTimeout(10 * time.Millisecond)
+ hc.SetWriteTimeout(10 * time.Millisecond)
+
+ p := NewProxy(s.Addr, transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+ err := hc.Check(p)
+ if err != nil {
+ t.Errorf("check failed: %v", err)
+ }
+
+ time.Sleep(12 * time.Millisecond)
+ i1 := atomic.LoadUint32(&i)
+ if i1 != 1 {
+ t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1)
+ }
+}
diff --git a/plugin/pkg/proxy/metrics.go b/plugin/pkg/proxy/metrics.go
new file mode 100644
index 000000000..148bc6edd
--- /dev/null
+++ b/plugin/pkg/proxy/metrics.go
@@ -0,0 +1,49 @@
+package proxy
+
+import (
+ "github.com/coredns/coredns/plugin"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+// Variables declared for monitoring.
+var (
+ RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "proxy",
+ Name: "requests_total",
+ Help: "Counter of requests made per upstream.",
+ }, []string{"to"})
+ RcodeCount = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "proxy",
+ Name: "responses_total",
+ Help: "Counter of responses received per upstream.",
+ }, []string{"rcode", "to"})
+ RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "proxy",
+ Name: "request_duration_seconds",
+ Buckets: plugin.TimeBuckets,
+ Help: "Histogram of the time each request took.",
+ }, []string{"to", "rcode"})
+ HealthcheckFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "proxy",
+ Name: "healthcheck_failures_total",
+ Help: "Counter of the number of failed healthchecks.",
+ }, []string{"to"})
+ ConnCacheHitsCount = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "proxy",
+ Name: "conn_cache_hits_total",
+ Help: "Counter of connection cache hits per upstream and protocol.",
+ }, []string{"to", "proto"})
+ ConnCacheMissesCount = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "proxy",
+ Name: "conn_cache_misses_total",
+ Help: "Counter of connection cache misses per upstream and protocol.",
+ }, []string{"to", "proto"})
+)
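
These collectors are registered via promauto at package load. In tests, a labeled child can be read back with the client_golang testutil helpers; a sketch (not in this commit; label values are illustrative):

    package proxy

    import (
        "testing"

        "github.com/prometheus/client_golang/prometheus/testutil"
    )

    func TestConnCacheHitsMetric(t *testing.T) {
        c := ConnCacheHitsCount.WithLabelValues("127.0.0.1:53", "udp")
        c.Add(1)
        if got := testutil.ToFloat64(c); got < 1 {
            t.Fatalf("expected at least one recorded cache hit, got %f", got)
        }
    }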
diff --git a/plugin/pkg/proxy/persistent.go b/plugin/pkg/proxy/persistent.go
new file mode 100644
index 000000000..0908ce96c
--- /dev/null
+++ b/plugin/pkg/proxy/persistent.go
@@ -0,0 +1,156 @@
+package proxy
+
+import (
+ "crypto/tls"
+ "sort"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// a persistConn holds the dns.Conn and the last used time.
+type persistConn struct {
+ c *dns.Conn
+ used time.Time
+}
+
+// Transport holds the persistent cache.
+type Transport struct {
+ avgDialTime int64 // rolling average of recent dial times
+ conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
+ expire time.Duration // After this duration a connection is expired.
+ addr string
+ tlsConfig *tls.Config
+
+ dial chan string
+ yield chan *persistConn
+ ret chan *persistConn
+ stop chan bool
+}
+
+func newTransport(addr string) *Transport {
+ t := &Transport{
+ avgDialTime: int64(maxDialTimeout / 2),
+ conns: [typeTotalCount][]*persistConn{},
+ expire: defaultExpire,
+ addr: addr,
+ dial: make(chan string),
+ yield: make(chan *persistConn),
+ ret: make(chan *persistConn),
+ stop: make(chan bool),
+ }
+ return t
+}
+
+// connManager manages the persistent connection cache for UDP, TCP and TLS.
+func (t *Transport) connManager() {
+ ticker := time.NewTicker(defaultExpire)
+ defer ticker.Stop()
+Wait:
+ for {
+ select {
+ case proto := <-t.dial:
+ transtype := stringToTransportType(proto)
+ // take the last used conn - complexity O(1)
+ if stack := t.conns[transtype]; len(stack) > 0 {
+ pc := stack[len(stack)-1]
+ if time.Since(pc.used) < t.expire {
+ // Found one, remove from pool and return this conn.
+ t.conns[transtype] = stack[:len(stack)-1]
+ t.ret <- pc
+ continue Wait
+ }
+ // clear entire cache if the last conn is expired
+ t.conns[transtype] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
+ }
+ t.ret <- nil
+
+ case pc := <-t.yield:
+ transtype := t.transportTypeFromConn(pc)
+ t.conns[transtype] = append(t.conns[transtype], pc)
+
+ case <-ticker.C:
+ t.cleanup(false)
+
+ case <-t.stop:
+ t.cleanup(true)
+ close(t.ret)
+ return
+ }
+ }
+}
+
+// closeConns closes connections.
+func closeConns(conns []*persistConn) {
+ for _, pc := range conns {
+ pc.c.Close()
+ }
+}
+
+// cleanup removes connections from cache.
+func (t *Transport) cleanup(all bool) {
+ staleTime := time.Now().Add(-t.expire)
+ for transtype, stack := range t.conns {
+ if len(stack) == 0 {
+ continue
+ }
+ if all {
+ t.conns[transtype] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
+ continue
+ }
+ if stack[0].used.After(staleTime) {
+ continue
+ }
+
+ // connections in stack are sorted by "used"
+ good := sort.Search(len(stack), func(i int) bool {
+ return stack[i].used.After(staleTime)
+ })
+ t.conns[transtype] = stack[good:]
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack[:good])
+ }
+}
+
+// It is hard to pin a value to this; the important thing is to not block forever. Losing a cached connection is not terrible.
+const yieldTimeout = 25 * time.Millisecond
+
+// Yield returns the connection to transport for reuse.
+func (t *Transport) Yield(pc *persistConn) {
+ pc.used = time.Now() // update used time
+
+ // Make this non-blocking, because in the case of a very busy forwarder we would *block* on this yield. This
+ // would block the outer goroutine and requests would just pile up. We time out when the send fails, as returning
+ // these connections is an optimization anyway.
+ select {
+ case t.yield <- pc:
+ return
+ case <-time.After(yieldTimeout):
+ return
+ }
+}
+
+// Start starts the transport's connection manager.
+func (t *Transport) Start() { go t.connManager() }
+
+// Stop stops the transport's connection manager.
+func (t *Transport) Stop() { close(t.stop) }
+
+// SetExpire sets the connection expire time in transport.
+func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+const (
+ defaultExpire = 10 * time.Second
+ minDialTimeout = 1 * time.Second
+ maxDialTimeout = 30 * time.Second
+)
diff --git a/plugin/pkg/proxy/persistent_test.go b/plugin/pkg/proxy/persistent_test.go
new file mode 100644
index 000000000..c78bd7f1f
--- /dev/null
+++ b/plugin/pkg/proxy/persistent_test.go
@@ -0,0 +1,109 @@
+package proxy
+
+import (
+ "testing"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+
+ "github.com/miekg/dns"
+)
+
+func TestCached(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, cache1, _ := tr.Dial("udp")
+ c2, cache2, _ := tr.Dial("udp")
+
+ if cache1 || cache2 {
+ t.Errorf("Expected non-cached connection")
+ }
+
+ tr.Yield(c1)
+ tr.Yield(c2)
+ c3, cached3, _ := tr.Dial("udp")
+ if !cached3 {
+ t.Error("Expected cached connection (c3)")
+ }
+ if c2 != c3 {
+ t.Error("Expected c2 == c3")
+ }
+
+ tr.Yield(c3)
+
+ // dial another protocol
+ c4, cached4, _ := tr.Dial("tcp")
+ if cached4 {
+ t.Errorf("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestCleanupByTimer(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+ tr.SetExpire(100 * time.Millisecond)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, _, _ := tr.Dial("udp")
+ c2, _, _ := tr.Dial("udp")
+ tr.Yield(c1)
+ time.Sleep(10 * time.Millisecond)
+ tr.Yield(c2)
+
+ time.Sleep(120 * time.Millisecond)
+ c3, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c3)")
+ }
+ tr.Yield(c3)
+
+ time.Sleep(120 * time.Millisecond)
+ c4, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestCleanupAll(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr)
+
+ c1, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+ c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+ c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+
+ tr.conns[typeUDP] = []*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}}
+
+ if len(tr.conns[typeUDP]) != 3 {
+ t.Error("Expected 3 connections")
+ }
+ tr.cleanup(true)
+
+ if len(tr.conns[typeUDP]) > 0 {
+ t.Error("Expected no cached connections")
+ }
+}
diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go
new file mode 100644
index 000000000..be521fe05
--- /dev/null
+++ b/plugin/pkg/proxy/proxy.go
@@ -0,0 +1,98 @@
+package proxy
+
+import (
+ "crypto/tls"
+ "runtime"
+ "sync/atomic"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/log"
+ "github.com/coredns/coredns/plugin/pkg/up"
+)
+
+// Proxy defines an upstream host.
+type Proxy struct {
+ fails uint32
+ addr string
+
+ transport *Transport
+
+ readTimeout time.Duration
+
+ // health checking
+ probe *up.Probe
+ health HealthChecker
+}
+
+// NewProxy returns a new proxy.
+func NewProxy(addr, trans string) *Proxy {
+ p := &Proxy{
+ addr: addr,
+ fails: 0,
+ probe: up.New(),
+ readTimeout: 2 * time.Second,
+ transport: newTransport(addr),
+ }
+ p.health = NewHealthChecker(trans, true, ".")
+ runtime.SetFinalizer(p, (*Proxy).finalizer)
+ return p
+}
+
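+// Addr returns the address of the upstream this proxy forwards to.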
+func (p *Proxy) Addr() string { return p.addr }
+
+// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client.
+func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
+ p.transport.SetTLSConfig(cfg)
+ p.health.SetTLSConfig(cfg)
+}
+
+// SetExpire sets the expire duration in the lower p.transport.
+func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
+
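+// GetHealthchecker returns the proxy's health checker.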
+func (p *Proxy) GetHealthchecker() HealthChecker {
+ return p.health
+}
+
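+// Fails returns the number of consecutive failed health checks for this proxy.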
+func (p *Proxy) Fails() uint32 {
+ return atomic.LoadUint32(&p.fails)
+}
+
+// Healthcheck kicks off a round of health checks for this proxy.
+func (p *Proxy) Healthcheck() {
+ if p.health == nil {
+ log.Warning("No healthchecker")
+ return
+ }
+
+ p.probe.Do(func() error {
+ return p.health.Check(p)
+ })
+}
+
+// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
+func (p *Proxy) Down(maxfails uint32) bool {
+ if maxfails == 0 {
+ return false
+ }
+
+ fails := atomic.LoadUint32(&p.fails)
+ return fails > maxfails
+}
+
+// Stop stops the health checking goroutine.
+func (p *Proxy) Stop() { p.probe.Stop() }
+func (p *Proxy) finalizer() { p.transport.Stop() }
+
+// Start starts the proxy's healthchecking.
+func (p *Proxy) Start(duration time.Duration) {
+ p.probe.Start(duration)
+ p.transport.Start()
+}
+
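+// SetReadTimeout sets the read timeout used when waiting for an upstream reply.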
+func (p *Proxy) SetReadTimeout(duration time.Duration) {
+ p.readTimeout = duration
+}
+
+const (
+ maxTimeout = 2 * time.Second
+)
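
Down, Healthcheck and Fails combine into the circuit-breaking loop a caller runs around Connect; a hedged sketch, loosely following how the forward plugin drives these methods (error choice and flow are illustrative):

    package example

    import (
        "context"

        "github.com/coredns/coredns/plugin/pkg/proxy"
        "github.com/coredns/coredns/request"

        "github.com/miekg/dns"
    )

    func exchange(ctx context.Context, p *proxy.Proxy, state request.Request, maxfails uint32) (*dns.Msg, error) {
        if p.Down(maxfails) {
            return nil, proxy.ErrNoHealthy // too many consecutive failed health checks
        }
        ret, err := p.Connect(ctx, state, proxy.Options{})
        if err != nil {
            p.Healthcheck() // kick off an out-of-band probe; a success resets the fail counter
        }
        return ret, err
    }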
diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go
new file mode 100644
index 000000000..274e9679d
--- /dev/null
+++ b/plugin/pkg/proxy/proxy_test.go
@@ -0,0 +1,99 @@
+package proxy
+
+import (
+ "context"
+ "crypto/tls"
+ "testing"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/pkg/transport"
+ "github.com/coredns/coredns/plugin/test"
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+func TestProxy(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr, transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+ p.Start(5 * time.Second)
+ m := new(dns.Msg)
+
+ m.SetQuestion("example.org.", dns.TypeA)
+
+ rec := dnstest.NewRecorder(&test.ResponseWriter{})
+ req := request.Request{Req: m, W: rec}
+
+ resp, err := p.Connect(context.Background(), req, Options{PreferUDP: true})
+ if err != nil {
+ t.Errorf("Failed to connect to testdnsserver: %s", err)
+ }
+
+ if x := resp.Answer[0].Header().Name; x != "example.org." {
+ t.Errorf("Expected %s, got %s", "example.org.", x)
+ }
+}
+
+func TestProxyTLSFail(t *testing.T) {
+ // This is a UDP/TCP test server, so we shouldn't reach it with TLS.
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr, transport.TLS)
+ p.readTimeout = 10 * time.Millisecond
+ p.SetTLSConfig(&tls.Config{})
+ p.Start(5 * time.Second)
+ m := new(dns.Msg)
+
+ m.SetQuestion("example.org.", dns.TypeA)
+
+ rec := dnstest.NewRecorder(&test.ResponseWriter{})
+ req := request.Request{Req: m, W: rec}
+
+ _, err := p.Connect(context.Background(), req, Options{})
+ if err == nil {
+ t.Fatal("Expected *not* to receive reply, but got one")
+ }
+}
+
+func TestProtocolSelection(t *testing.T) {
+ p := NewProxy("bad_address", transport.DNS)
+ p.readTimeout = 10 * time.Millisecond
+
+ stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
+ stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}
+ ctx := context.TODO()
+
+ go func() {
+ p.Connect(ctx, stateUDP, Options{})
+ p.Connect(ctx, stateUDP, Options{ForceTCP: true})
+ p.Connect(ctx, stateUDP, Options{PreferUDP: true})
+ p.Connect(ctx, stateUDP, Options{PreferUDP: true, ForceTCP: true})
+ p.Connect(ctx, stateTCP, Options{})
+ p.Connect(ctx, stateTCP, Options{ForceTCP: true})
+ p.Connect(ctx, stateTCP, Options{PreferUDP: true})
+ p.Connect(ctx, stateTCP, Options{PreferUDP: true, ForceTCP: true})
+ }()
+
+ for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} {
+ proto := <-p.transport.dial
+ p.transport.ret <- nil
+ if proto != exp {
+ t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto)
+ }
+ }
+}
diff --git a/plugin/pkg/proxy/type.go b/plugin/pkg/proxy/type.go
new file mode 100644
index 000000000..10f3a4639
--- /dev/null
+++ b/plugin/pkg/proxy/type.go
@@ -0,0 +1,39 @@
+package proxy
+
+import (
+ "net"
+)
+
+type transportType int
+
+const (
+ typeUDP transportType = iota
+ typeTCP
+ typeTLS
+ typeTotalCount // keep this last
+)
+
+func stringToTransportType(s string) transportType {
+ switch s {
+ case "udp":
+ return typeUDP
+ case "tcp":
+ return typeTCP
+ case "tcp-tls":
+ return typeTLS
+ }
+
+ return typeUDP
+}
+
+func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
+ if _, ok := pc.c.Conn.(*net.UDPConn); ok {
+ return typeUDP
+ }
+
+ if t.tlsConfig == nil {
+ return typeTCP
+ }
+
+ return typeTLS
+}
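
The mapping is total: unknown protocol strings fall back to typeUDP. A table-driven test sketch pinning that down (in-package, since these identifiers are unexported; not part of this commit):

    package proxy

    import "testing"

    func TestStringToTransportType(t *testing.T) {
        cases := map[string]transportType{
            "udp":     typeUDP,
            "tcp":     typeTCP,
            "tcp-tls": typeTLS,
            "bogus":   typeUDP, // unknown strings default to UDP
        }
        for in, want := range cases {
            if got := stringToTransportType(in); got != want {
                t.Errorf("stringToTransportType(%q) = %d, want %d", in, got, want)
            }
        }
    }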