author    | 2023-03-24 12:55:51 +0000
committer | 2023-03-24 08:55:51 -0400
commit    | f823825f8a34edb85d5d18cd5d2f6f850adf408e (patch)
tree      | 79d241ab9b4c7c343d806f4041c8efccbe3f9ca0 /plugin/pkg/proxy
parent    | 47dceabfc6465ba6c5d41472d6602d4ad5c9fb1b (diff)
download  | coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.tar.gz
          | coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.tar.zst
          | coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.zip
plugin/forward: Allow Proxy to be used outside of forward plugin. (#5951)
* plugin/forward: Move Proxy into plugin/pkg/proxy, to allow forward.Proxy to be used outside of the forward plugin.
Signed-off-by: Patrick Downey <patrick.downey@dioadconsulting.com>
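With Proxy now exported from plugin/pkg/proxy, a plugin other than forward can construct and drive an upstream directly. The following minimal sketch is based on the API added in this commit (NewProxy, SetReadTimeout, Start, Connect, Options); the upstream address, the timeouts, and the resolveA helper are illustrative assumptions, not part of the change.

	package example

	import (
		"context"
		"time"

		"github.com/coredns/coredns/plugin/pkg/proxy"
		"github.com/coredns/coredns/plugin/pkg/transport"
		"github.com/coredns/coredns/request"

		"github.com/miekg/dns"
	)

	// resolveA forwards a single A query through a standalone Proxy.
	// The upstream address below is a placeholder, not something this commit prescribes.
	func resolveA(w dns.ResponseWriter, qname string) (*dns.Msg, error) {
		p := proxy.NewProxy("127.0.0.1:53", transport.DNS)
		p.SetReadTimeout(2 * time.Second)
		p.Start(500 * time.Millisecond) // starts health checking and the connection manager
		defer p.Stop()

		m := new(dns.Msg)
		m.SetQuestion(dns.Fqdn(qname), dns.TypeA)

		state := request.Request{W: w, Req: m}
		return p.Connect(context.Background(), state, proxy.Options{PreferUDP: true})
	}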
Diffstat (limited to 'plugin/pkg/proxy')
-rw-r--r-- | plugin/pkg/proxy/connect.go         | 152
-rw-r--r-- | plugin/pkg/proxy/errors.go          |  26
-rw-r--r-- | plugin/pkg/proxy/health.go          | 131
-rw-r--r-- | plugin/pkg/proxy/health_test.go     | 153
-rw-r--r-- | plugin/pkg/proxy/metrics.go         |  49
-rw-r--r-- | plugin/pkg/proxy/persistent.go      | 156
-rw-r--r-- | plugin/pkg/proxy/persistent_test.go | 109
-rw-r--r-- | plugin/pkg/proxy/proxy.go           |  98
-rw-r--r-- | plugin/pkg/proxy/proxy_test.go      |  99
-rw-r--r-- | plugin/pkg/proxy/type.go            |  39
10 files changed, 1012 insertions, 0 deletions
diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go
new file mode 100644
index 000000000..29274d92d
--- /dev/null
+++ b/plugin/pkg/proxy/connect.go
@@ -0,0 +1,152 @@
+// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
+// client returns, the upstream's Conn will be precached. Depending on how you benchmark this looks to be
+// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
+// inband healthchecking.
+package proxy
+
+import (
+	"context"
+	"io"
+	"strconv"
+	"sync/atomic"
+	"time"
+
+	"github.com/coredns/coredns/request"
+
+	"github.com/miekg/dns"
+)
+
+// limitTimeout is a utility function to auto-tune timeout values: the average observed time is
+// moved towards the last observed delay, moderated by a weight. The next timeout to use will be
+// double the computed average, limited by the min and max bounds.
+func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
+	rt := time.Duration(atomic.LoadInt64(currentAvg))
+	if rt < minValue {
+		return minValue
+	}
+	if rt < maxValue/2 {
+		return 2 * rt
+	}
+	return maxValue
+}
+
+func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
+	dt := time.Duration(atomic.LoadInt64(currentAvg))
+	atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
+}
+
+func (t *Transport) dialTimeout() time.Duration {
+	return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
+}
+
+func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
+	averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
+}
+
+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
+func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
+	// If TLS has been configured, use it.
+	if t.tlsConfig != nil {
+		proto = "tcp-tls"
+	}
+
+	t.dial <- proto
+	pc := <-t.ret
+
+	if pc != nil {
+		ConnCacheHitsCount.WithLabelValues(t.addr, proto).Add(1)
+		return pc, true, nil
+	}
+	ConnCacheMissesCount.WithLabelValues(t.addr, proto).Add(1)
+
+	reqTime := time.Now()
+	timeout := t.dialTimeout()
+	if proto == "tcp-tls" {
+		conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
+		t.updateDialTimeout(time.Since(reqTime))
+		return &persistConn{c: conn}, false, err
+	}
+	conn, err := dns.DialTimeout(proto, t.addr, timeout)
+	t.updateDialTimeout(time.Since(reqTime))
+	return &persistConn{c: conn}, false, err
+}
+
+// Connect selects an upstream, sends the request and waits for a response.
+func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
+	start := time.Now()
+
+	proto := ""
+	switch {
+	case opts.ForceTCP: // TCP flag has precedence over UDP flag
+		proto = "tcp"
+	case opts.PreferUDP:
+		proto = "udp"
+	default:
+		proto = state.Proto()
+	}
+
+	pc, cached, err := p.transport.Dial(proto)
+	if err != nil {
+		return nil, err
+	}
+
+	// Set buffer size correctly for this client.
+	pc.c.UDPSize = uint16(state.Size())
+	if pc.c.UDPSize < 512 {
+		pc.c.UDPSize = 512
+	}
+
+	pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
+	// Record the original Id before sending upstream.
+	originId := state.Req.Id
+	state.Req.Id = dns.Id()
+	defer func() {
+		state.Req.Id = originId
+	}()
+
+	if err := pc.c.WriteMsg(state.Req); err != nil {
+		pc.c.Close() // not giving it back
+		if err == io.EOF && cached {
+			return nil, ErrCachedClosed
+		}
+		return nil, err
+	}
+
+	var ret *dns.Msg
+	pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
+	for {
+		ret, err = pc.c.ReadMsg()
+		if err != nil {
+			pc.c.Close() // not giving it back
+			if err == io.EOF && cached {
+				return nil, ErrCachedClosed
+			}
+			// Restore the original Id after the upstream exchange.
+			if ret != nil {
+				ret.Id = originId
+			}
+			return ret, err
+		}
+		// drop out-of-order responses
+		if state.Req.Id == ret.Id {
+			break
+		}
+	}
+	// Restore the original Id after the upstream exchange.
+	ret.Id = originId
+
+	p.transport.Yield(pc)
+
+	rc, ok := dns.RcodeToString[ret.Rcode]
+	if !ok {
+		rc = strconv.Itoa(ret.Rcode)
+	}
+
+	RequestCount.WithLabelValues(p.addr).Add(1)
+	RcodeCount.WithLabelValues(rc, p.addr).Add(1)
+	RequestDuration.WithLabelValues(p.addr, rc).Observe(time.Since(start).Seconds())
+
+	return ret, nil
+}
+
+const cumulativeAvgWeight = 4
diff --git a/plugin/pkg/proxy/errors.go b/plugin/pkg/proxy/errors.go
new file mode 100644
index 000000000..461236423
--- /dev/null
+++ b/plugin/pkg/proxy/errors.go
@@ -0,0 +1,26 @@
+package proxy
+
+import (
+	"errors"
+)
+
+var (
+	// ErrNoHealthy means no healthy proxies left.
+	ErrNoHealthy = errors.New("no healthy proxies")
+	// ErrNoForward means no forwarder defined.
+	ErrNoForward = errors.New("no forwarder defined")
+	// ErrCachedClosed means cached connection was closed by peer.
+	ErrCachedClosed = errors.New("cached connection was closed by peer")
+)
+
+// Options holds various Options that can be set.
+type Options struct {
+	// ForceTCP uses the TCP protocol for the upstream DNS request. Has precedence over the PreferUDP flag.
+	ForceTCP bool
+	// PreferUDP uses the UDP protocol for the upstream DNS request.
+	PreferUDP bool
+	// HCRecursionDesired sets the recursion desired flag for Proxy healthcheck requests.
+	HCRecursionDesired bool
+	// HCDomain sets the domain for Proxy healthcheck requests.
+	HCDomain string
+}
diff --git a/plugin/pkg/proxy/health.go b/plugin/pkg/proxy/health.go
new file mode 100644
index 000000000..e87104a13
--- /dev/null
+++ b/plugin/pkg/proxy/health.go
@@ -0,0 +1,131 @@
+package proxy
+
+import (
+	"crypto/tls"
+	"sync/atomic"
+	"time"
+
+	"github.com/coredns/coredns/plugin/pkg/log"
+	"github.com/coredns/coredns/plugin/pkg/transport"
+
+	"github.com/miekg/dns"
+)
+
+// HealthChecker checks the upstream health.
+type HealthChecker interface {
+	Check(*Proxy) error
+	SetTLSConfig(*tls.Config)
+	GetTLSConfig() *tls.Config
+	SetRecursionDesired(bool)
+	GetRecursionDesired() bool
+	SetDomain(domain string)
+	GetDomain() string
+	SetTCPTransport()
+	GetReadTimeout() time.Duration
+	SetReadTimeout(time.Duration)
+	GetWriteTimeout() time.Duration
+	SetWriteTimeout(time.Duration)
+}
+
+// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
+type dnsHc struct {
+	c                *dns.Client
+	recursionDesired bool
+	domain           string
+}
+
+// NewHealthChecker returns a new HealthChecker based on transport.
+func NewHealthChecker(trans string, recursionDesired bool, domain string) HealthChecker {
+	switch trans {
+	case transport.DNS, transport.TLS:
+		c := new(dns.Client)
+		c.Net = "udp"
+		c.ReadTimeout = 1 * time.Second
+		c.WriteTimeout = 1 * time.Second
+
+		return &dnsHc{
+			c:                c,
+			recursionDesired: recursionDesired,
+			domain:           domain,
+		}
+	}
+
+	log.Warningf("No healthchecker for transport %q", trans)
+	return nil
+}
+
+func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
+	h.c.Net = "tcp-tls"
+	h.c.TLSConfig = cfg
+}
+
+func (h *dnsHc) GetTLSConfig() *tls.Config {
+	return h.c.TLSConfig
+}
+
+func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
+	h.recursionDesired = recursionDesired
+}
+
+func (h *dnsHc) GetRecursionDesired() bool {
+	return h.recursionDesired
+}
+
+func (h *dnsHc) SetDomain(domain string) {
+	h.domain = domain
+}
+
+func (h *dnsHc) GetDomain() string {
+	return h.domain
+}
+
+func (h *dnsHc) SetTCPTransport() {
+	h.c.Net = "tcp"
+}
+
+func (h *dnsHc) GetReadTimeout() time.Duration {
+	return h.c.ReadTimeout
+}
+
+func (h *dnsHc) SetReadTimeout(t time.Duration) {
+	h.c.ReadTimeout = t
+}
+
+func (h *dnsHc) GetWriteTimeout() time.Duration {
+	return h.c.WriteTimeout
+}
+
+func (h *dnsHc) SetWriteTimeout(t time.Duration) {
+	h.c.WriteTimeout = t
+}
+
+// For HC, we send a ". IN NS +[no]rec" message to the upstream. Dial timeouts and empty
+// replies are considered fails; basically anything else constitutes a healthy upstream.
+
+// Check is used as the up.Func in the up.Probe.
+func (h *dnsHc) Check(p *Proxy) error {
+	err := h.send(p.addr)
+	if err != nil {
+		HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
+		atomic.AddUint32(&p.fails, 1)
+		return err
+	}
+
+	atomic.StoreUint32(&p.fails, 0)
+	return nil
+}
+
+func (h *dnsHc) send(addr string) error {
+	ping := new(dns.Msg)
+	ping.SetQuestion(h.domain, dns.TypeNS)
+	ping.MsgHdr.RecursionDesired = h.recursionDesired
+
+	m, _, err := h.c.Exchange(ping, addr)
+	// If we got a header, we're alright; basically we only care about I/O errors 'n stuff.
+	if err != nil && m != nil {
+		// Silly check, something sane came back.
+		if m.Response || m.Opcode == dns.OpcodeQuery {
+			err = nil
+		}
+	}
+
+	return err
+}
diff --git a/plugin/pkg/proxy/health_test.go b/plugin/pkg/proxy/health_test.go
new file mode 100644
index 000000000..c1b5270ad
--- /dev/null
+++ b/plugin/pkg/proxy/health_test.go
@@ -0,0 +1,153 @@
+package proxy
+
+import (
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/coredns/coredns/plugin/pkg/dnstest"
+	"github.com/coredns/coredns/plugin/pkg/transport"
+
+	"github.com/miekg/dns"
+)
+
+func TestHealth(t *testing.T) {
+	i := uint32(0)
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		if r.Question[0].Name == "." && r.RecursionDesired == true {
+			atomic.AddUint32(&i, 1)
+		}
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	hc := NewHealthChecker(transport.DNS, true, "")
+	hc.SetReadTimeout(10 * time.Millisecond)
+	hc.SetWriteTimeout(10 * time.Millisecond)
+
+	p := NewProxy(s.Addr, transport.DNS)
+	p.readTimeout = 10 * time.Millisecond
+	err := hc.Check(p)
+	if err != nil {
+		t.Errorf("check failed: %v", err)
+	}
+
+	time.Sleep(20 * time.Millisecond)
+	i1 := atomic.LoadUint32(&i)
+	if i1 != 1 {
+		t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+	}
+}
+
+func TestHealthTCP(t *testing.T) {
+	i := uint32(0)
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		if r.Question[0].Name == "." && r.RecursionDesired == true {
+			atomic.AddUint32(&i, 1)
+		}
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	hc := NewHealthChecker(transport.DNS, true, "")
+	hc.SetTCPTransport()
+	hc.SetReadTimeout(10 * time.Millisecond)
+	hc.SetWriteTimeout(10 * time.Millisecond)
+
+	p := NewProxy(s.Addr, transport.DNS)
+	p.readTimeout = 10 * time.Millisecond
+	err := hc.Check(p)
+	if err != nil {
+		t.Errorf("check failed: %v", err)
+	}
+
+	time.Sleep(20 * time.Millisecond)
+	i1 := atomic.LoadUint32(&i)
+	if i1 != 1 {
+		t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
+	}
+}
+
+func TestHealthNoRecursion(t *testing.T) {
+	i := uint32(0)
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		if r.Question[0].Name == "." && r.RecursionDesired == false {
+			atomic.AddUint32(&i, 1)
+		}
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	hc := NewHealthChecker(transport.DNS, false, "")
+	hc.SetReadTimeout(10 * time.Millisecond)
+	hc.SetWriteTimeout(10 * time.Millisecond)
+
+	p := NewProxy(s.Addr, transport.DNS)
+	p.readTimeout = 10 * time.Millisecond
+	err := hc.Check(p)
+	if err != nil {
+		t.Errorf("check failed: %v", err)
+	}
+
+	time.Sleep(20 * time.Millisecond)
+	i1 := atomic.LoadUint32(&i)
+	if i1 != 1 {
+		t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1)
+	}
+}
+
+func TestHealthTimeout(t *testing.T) {
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		// timeout
+	})
+	defer s.Close()
+
+	hc := NewHealthChecker(transport.DNS, false, "")
+	hc.SetReadTimeout(10 * time.Millisecond)
+	hc.SetWriteTimeout(10 * time.Millisecond)
+
+	p := NewProxy(s.Addr, transport.DNS)
+	p.readTimeout = 10 * time.Millisecond
+	err := hc.Check(p)
+	if err == nil {
+		t.Errorf("expected error")
+	}
+}
+
+func TestHealthDomain(t *testing.T) {
+	hcDomain := "example.org."
+
+	i := uint32(0)
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		if r.Question[0].Name == hcDomain && r.RecursionDesired == true {
+			atomic.AddUint32(&i, 1)
+		}
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	hc := NewHealthChecker(transport.DNS, true, hcDomain)
+	hc.SetReadTimeout(10 * time.Millisecond)
+	hc.SetWriteTimeout(10 * time.Millisecond)
+
+	p := NewProxy(s.Addr, transport.DNS)
+	p.readTimeout = 10 * time.Millisecond
+	err := hc.Check(p)
+	if err != nil {
+		t.Errorf("check failed: %v", err)
+	}
+
+	time.Sleep(12 * time.Millisecond)
+	i1 := atomic.LoadUint32(&i)
+	if i1 != 1 {
+		t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1)
+	}
+}
diff --git a/plugin/pkg/proxy/metrics.go b/plugin/pkg/proxy/metrics.go
new file mode 100644
index 000000000..148bc6edd
--- /dev/null
+++ b/plugin/pkg/proxy/metrics.go
@@ -0,0 +1,49 @@
+package proxy
+
+import (
+	"github.com/coredns/coredns/plugin"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+// Variables declared for monitoring.
+var (
+	RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "proxy",
+		Name:      "requests_total",
+		Help:      "Counter of requests made per upstream.",
+	}, []string{"to"})
+	RcodeCount = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "proxy",
+		Name:      "responses_total",
+		Help:      "Counter of responses received per upstream.",
+	}, []string{"rcode", "to"})
+	RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "proxy",
+		Name:      "request_duration_seconds",
+		Buckets:   plugin.TimeBuckets,
+		Help:      "Histogram of the time each request took.",
+	}, []string{"to", "rcode"})
+	HealthcheckFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "proxy",
+		Name:      "healthcheck_failures_total",
+		Help:      "Counter of the number of failed healthchecks.",
+	}, []string{"to"})
+	ConnCacheHitsCount = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "proxy",
+		Name:      "conn_cache_hits_total",
+		Help:      "Counter of connection cache hits per upstream and protocol.",
+	}, []string{"to", "proto"})
+	ConnCacheMissesCount = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "proxy",
+		Name:      "conn_cache_misses_total",
+		Help:      "Counter of connection cache misses per upstream and protocol.",
+	}, []string{"to", "proto"})
+)
diff --git a/plugin/pkg/proxy/persistent.go b/plugin/pkg/proxy/persistent.go
new file mode 100644
index 000000000..0908ce96c
--- /dev/null
+++ b/plugin/pkg/proxy/persistent.go
@@ -0,0 +1,156 @@
+package proxy
+
+import (
+	"crypto/tls"
+	"sort"
+	"time"
+
+	"github.com/miekg/dns"
+)
+
+// a persistConn holds the dns.Conn and the last used time.
+type persistConn struct {
+	c    *dns.Conn
+	used time.Time
+}
+
+// Transport holds the persistent cache.
+type Transport struct {
+	avgDialTime int64                          // rough average of dial time
+	conns       [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
+	expire      time.Duration                  // After this duration a connection is expired.
+	addr        string
+	tlsConfig   *tls.Config
+
+	dial  chan string
+	yield chan *persistConn
+	ret   chan *persistConn
+	stop  chan bool
+}
+
+func newTransport(addr string) *Transport {
+	t := &Transport{
+		avgDialTime: int64(maxDialTimeout / 2),
+		conns:       [typeTotalCount][]*persistConn{},
+		expire:      defaultExpire,
+		addr:        addr,
+		dial:        make(chan string),
+		yield:       make(chan *persistConn),
+		ret:         make(chan *persistConn),
+		stop:        make(chan bool),
+	}
+	return t
+}
+
+// connManager manages the persistent connection cache for UDP and TCP.
+func (t *Transport) connManager() {
+	ticker := time.NewTicker(defaultExpire)
+	defer ticker.Stop()
+Wait:
+	for {
+		select {
+		case proto := <-t.dial:
+			transtype := stringToTransportType(proto)
+			// take the last used conn - complexity O(1)
+			if stack := t.conns[transtype]; len(stack) > 0 {
+				pc := stack[len(stack)-1]
+				if time.Since(pc.used) < t.expire {
+					// Found one, remove from pool and return this conn.
+					t.conns[transtype] = stack[:len(stack)-1]
+					t.ret <- pc
+					continue Wait
+				}
+				// clear entire cache if the last conn is expired
+				t.conns[transtype] = nil
+				// now, the connections being passed to closeConns() are not reachable from
+				// transport methods anymore. So, it's safe to close them in a separate goroutine
+				go closeConns(stack)
+			}
+			t.ret <- nil
+
+		case pc := <-t.yield:
+			transtype := t.transportTypeFromConn(pc)
+			t.conns[transtype] = append(t.conns[transtype], pc)
+
+		case <-ticker.C:
+			t.cleanup(false)
+
+		case <-t.stop:
+			t.cleanup(true)
+			close(t.ret)
+			return
+		}
+	}
+}
+
+// closeConns closes connections.
+func closeConns(conns []*persistConn) {
+	for _, pc := range conns {
+		pc.c.Close()
+	}
+}
+
+// cleanup removes connections from cache.
+func (t *Transport) cleanup(all bool) {
+	staleTime := time.Now().Add(-t.expire)
+	for transtype, stack := range t.conns {
+		if len(stack) == 0 {
+			continue
+		}
+		if all {
+			t.conns[transtype] = nil
+			// now, the connections being passed to closeConns() are not reachable from
+			// transport methods anymore. So, it's safe to close them in a separate goroutine
+			go closeConns(stack)
+			continue
+		}
+		if stack[0].used.After(staleTime) {
+			continue
+		}
+
+		// connections in stack are sorted by "used"
+		good := sort.Search(len(stack), func(i int) bool {
+			return stack[i].used.After(staleTime)
+		})
+		t.conns[transtype] = stack[good:]
+		// now, the connections being passed to closeConns() are not reachable from
+		// transport methods anymore. So, it's safe to close them in a separate goroutine
+		go closeConns(stack[:good])
+	}
+}
+
+// It is hard to pin a value to this; the important thing is to not block forever.
+// Losing a cached connection is not terrible.
+const yieldTimeout = 25 * time.Millisecond
+
+// Yield returns the connection to transport for reuse.
+func (t *Transport) Yield(pc *persistConn) {
+	pc.used = time.Now() // update used time
+
+	// Make this non-blocking, because in the case of a very busy forwarder we will *block* on this yield. This
+	// blocks the outer go-routine and stuff will just pile up. We time out when the send fails, as returning
+	// these connections is an optimization anyway.
+	select {
+	case t.yield <- pc:
+		return
+	case <-time.After(yieldTimeout):
+		return
+	}
+}
+
+// Start starts the transport's connection manager.
+func (t *Transport) Start() { go t.connManager() }
+
+// Stop stops the transport's connection manager.
+func (t *Transport) Stop() { close(t.stop) }
+
+// SetExpire sets the connection expire time in transport.
+func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+const (
+	defaultExpire  = 10 * time.Second
+	minDialTimeout = 1 * time.Second
+	maxDialTimeout = 30 * time.Second
+)
diff --git a/plugin/pkg/proxy/persistent_test.go b/plugin/pkg/proxy/persistent_test.go
new file mode 100644
index 000000000..c78bd7f1f
--- /dev/null
+++ b/plugin/pkg/proxy/persistent_test.go
@@ -0,0 +1,109 @@
+package proxy
+
+import (
+	"testing"
+	"time"
+
+	"github.com/coredns/coredns/plugin/pkg/dnstest"
+
+	"github.com/miekg/dns"
+)
+
+func TestCached(t *testing.T) {
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	tr := newTransport(s.Addr)
+	tr.Start()
+	defer tr.Stop()
+
+	c1, cache1, _ := tr.Dial("udp")
+	c2, cache2, _ := tr.Dial("udp")
+
+	if cache1 || cache2 {
+		t.Errorf("Expected non-cached connection")
+	}
+
+	tr.Yield(c1)
+	tr.Yield(c2)
+	c3, cached3, _ := tr.Dial("udp")
+	if !cached3 {
+		t.Error("Expected cached connection (c3)")
+	}
+	if c2 != c3 {
+		t.Error("Expected c2 == c3")
+	}
+
+	tr.Yield(c3)
+
+	// dial another protocol
+	c4, cached4, _ := tr.Dial("tcp")
+	if cached4 {
+		t.Errorf("Expected non-cached connection (c4)")
+	}
+	tr.Yield(c4)
+}
+
+func TestCleanupByTimer(t *testing.T) {
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	tr := newTransport(s.Addr)
+	tr.SetExpire(100 * time.Millisecond)
+	tr.Start()
+	defer tr.Stop()
+
+	c1, _, _ := tr.Dial("udp")
+	c2, _, _ := tr.Dial("udp")
+	tr.Yield(c1)
+	time.Sleep(10 * time.Millisecond)
+	tr.Yield(c2)
+
+	time.Sleep(120 * time.Millisecond)
+	c3, cached, _ := tr.Dial("udp")
+	if cached {
+		t.Error("Expected non-cached connection (c3)")
+	}
+	tr.Yield(c3)
+
+	time.Sleep(120 * time.Millisecond)
+	c4, cached, _ := tr.Dial("udp")
+	if cached {
+		t.Error("Expected non-cached connection (c4)")
+	}
+	tr.Yield(c4)
+}
+
+func TestCleanupAll(t *testing.T) {
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	tr := newTransport(s.Addr)
+
+	c1, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+	c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+	c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
+
+	tr.conns[typeUDP] = []*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}}
+
+	if len(tr.conns[typeUDP]) != 3 {
+		t.Error("Expected 3 connections")
+	}
+	tr.cleanup(true)
+
+	if len(tr.conns[typeUDP]) > 0 {
+		t.Error("Expected no cached connections")
+	}
+}
diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go
new file mode 100644
index 000000000..be521fe05
--- /dev/null
+++ b/plugin/pkg/proxy/proxy.go
@@ -0,0 +1,98 @@
+package proxy
+
+import (
+	"crypto/tls"
+	"runtime"
+	"sync/atomic"
+	"time"
+
+	"github.com/coredns/coredns/plugin/pkg/log"
+	"github.com/coredns/coredns/plugin/pkg/up"
+)
+
+// Proxy defines an upstream host.
+type Proxy struct {
+	fails uint32
+	addr  string
+
+	transport *Transport
+
+	readTimeout time.Duration
+
+	// health checking
+	probe  *up.Probe
+	health HealthChecker
+}
+
+// NewProxy returns a new proxy.
+func NewProxy(addr, trans string) *Proxy {
+	p := &Proxy{
+		addr:        addr,
+		fails:       0,
+		probe:       up.New(),
+		readTimeout: 2 * time.Second,
+		transport:   newTransport(addr),
+	}
+	p.health = NewHealthChecker(trans, true, ".")
+	runtime.SetFinalizer(p, (*Proxy).finalizer)
+	return p
+}
+
+func (p *Proxy) Addr() string { return p.addr }
+
+// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client.
+func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
+	p.transport.SetTLSConfig(cfg)
+	p.health.SetTLSConfig(cfg)
+}
+
+// SetExpire sets the expire duration in the lower p.transport.
+func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
+
+func (p *Proxy) GetHealthchecker() HealthChecker {
+	return p.health
+}
+
+func (p *Proxy) Fails() uint32 {
+	return atomic.LoadUint32(&p.fails)
+}
+
+// Healthcheck kicks off a round of health checks for this proxy.
+func (p *Proxy) Healthcheck() {
+	if p.health == nil {
+		log.Warning("No healthchecker")
+		return
+	}
+
+	p.probe.Do(func() error {
+		return p.health.Check(p)
+	})
+}
+
+// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
+func (p *Proxy) Down(maxfails uint32) bool {
+	if maxfails == 0 {
+		return false
+	}
+
+	fails := atomic.LoadUint32(&p.fails)
+	return fails > maxfails
+}
+
+// Stop stops the health checking goroutine.
+func (p *Proxy) Stop()      { p.probe.Stop() }
+func (p *Proxy) finalizer() { p.transport.Stop() }
+
+// Start starts the proxy's healthchecking.
+func (p *Proxy) Start(duration time.Duration) {
+	p.probe.Start(duration)
+	p.transport.Start()
+}
+
+func (p *Proxy) SetReadTimeout(duration time.Duration) {
+	p.readTimeout = duration
+}
+
+const (
+	maxTimeout = 2 * time.Second
+)
diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go
new file mode 100644
index 000000000..274e9679d
--- /dev/null
+++ b/plugin/pkg/proxy/proxy_test.go
@@ -0,0 +1,99 @@
+package proxy
+
+import (
+	"context"
+	"crypto/tls"
+	"testing"
+	"time"
+
+	"github.com/coredns/coredns/plugin/pkg/dnstest"
+	"github.com/coredns/coredns/plugin/pkg/transport"
+	"github.com/coredns/coredns/plugin/test"
+	"github.com/coredns/coredns/request"
+
+	"github.com/miekg/dns"
+)
+
+func TestProxy(t *testing.T) {
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	p := NewProxy(s.Addr, transport.DNS)
+	p.readTimeout = 10 * time.Millisecond
+	p.Start(5 * time.Second)
+	m := new(dns.Msg)
+
+	m.SetQuestion("example.org.", dns.TypeA)
+
+	rec := dnstest.NewRecorder(&test.ResponseWriter{})
+	req := request.Request{Req: m, W: rec}
+
+	resp, err := p.Connect(context.Background(), req, Options{PreferUDP: true})
+	if err != nil {
+		t.Errorf("Failed to connect to testdnsserver: %s", err)
+	}
+
+	if x := resp.Answer[0].Header().Name; x != "example.org." {
+		t.Errorf("Expected %s, got %s", "example.org.", x)
+	}
+}
+
+func TestProxyTLSFail(t *testing.T) {
+	// This is a udp/tcp test server, so we shouldn't reach it with TLS.
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	p := NewProxy(s.Addr, transport.TLS)
+	p.readTimeout = 10 * time.Millisecond
+	p.SetTLSConfig(&tls.Config{})
+	p.Start(5 * time.Second)
+	m := new(dns.Msg)
+
+	m.SetQuestion("example.org.", dns.TypeA)
+
+	rec := dnstest.NewRecorder(&test.ResponseWriter{})
+	req := request.Request{Req: m, W: rec}
+
+	_, err := p.Connect(context.Background(), req, Options{})
+	if err == nil {
+		t.Fatal("Expected *not* to receive reply, but got one")
+	}
+}
+
+func TestProtocolSelection(t *testing.T) {
+	p := NewProxy("bad_address", transport.DNS)
+	p.readTimeout = 10 * time.Millisecond
+
+	stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
+	stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}
+	ctx := context.TODO()
+
+	go func() {
+		p.Connect(ctx, stateUDP, Options{})
+		p.Connect(ctx, stateUDP, Options{ForceTCP: true})
+		p.Connect(ctx, stateUDP, Options{PreferUDP: true})
+		p.Connect(ctx, stateUDP, Options{PreferUDP: true, ForceTCP: true})
+		p.Connect(ctx, stateTCP, Options{})
+		p.Connect(ctx, stateTCP, Options{ForceTCP: true})
+		p.Connect(ctx, stateTCP, Options{PreferUDP: true})
+		p.Connect(ctx, stateTCP, Options{PreferUDP: true, ForceTCP: true})
+	}()
+
+	for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} {
+		proto := <-p.transport.dial
+		p.transport.ret <- nil
+		if proto != exp {
+			t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto)
+		}
+	}
+}
diff --git a/plugin/pkg/proxy/type.go b/plugin/pkg/proxy/type.go
new file mode 100644
index 000000000..10f3a4639
--- /dev/null
+++ b/plugin/pkg/proxy/type.go
@@ -0,0 +1,39 @@
+package proxy
+
+import (
+	"net"
+)
+
+type transportType int
+
+const (
+	typeUDP transportType = iota
+	typeTCP
+	typeTLS
+	typeTotalCount // keep this last
+)
+
+func stringToTransportType(s string) transportType {
+	switch s {
+	case "udp":
+		return typeUDP
+	case "tcp":
+		return typeTCP
+	case "tcp-tls":
+		return typeTLS
+	}
+
+	return typeUDP
+}
+
+func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
+	if _, ok := pc.c.Conn.(*net.UDPConn); ok {
+		return typeUDP
+	}
+
+	if t.tlsConfig == nil {
+		return typeTCP
+	}
+
+	return typeTLS
+}
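The health-checking surface is exported as well, so a consumer can tune and trigger probes without going through forward. A short sketch for a *proxy.Proxy p as in the earlier example; the probe domain, recursion setting, and failure threshold below are illustrative assumptions, not values this commit prescribes:

	// Tune the health checker attached to p (NewProxy defaults the probe domain to ".").
	hc := p.GetHealthchecker()
	hc.SetDomain("example.org.") // illustrative probe domain
	hc.SetRecursionDesired(false)

	if p.Down(3) { // true once more than 3 consecutive checks have failed
		p.Healthcheck() // schedule another probe round via the up.Probe
	}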