diff options
-rw-r--r-- | plugin/forward/persistent.go | 67 | ||||
-rw-r--r-- | plugin/forward/persistent_test.go | 134 | ||||
-rw-r--r-- | plugin/forward/proxy.go | 5 |
3 files changed, 189 insertions, 17 deletions
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go index dc03002d3..e84c56ddd 100644 --- a/plugin/forward/persistent.go +++ b/plugin/forward/persistent.go @@ -3,6 +3,7 @@ package forward import ( "crypto/tls" "net" + "sort" "time" "github.com/miekg/dns" @@ -37,7 +38,6 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport { ret: make(chan *dns.Conn), stop: make(chan bool), } - go func() { t.connManager() }() return t } @@ -53,28 +53,26 @@ func (t *transport) len() int { // connManagers manages the persistent connection cache for UDP and TCP. func (t *transport) connManager() { - + ticker := time.NewTicker(t.expire) Wait: for { select { case proto := <-t.dial: - // Yes O(n), shouldn't put millions in here. We walk all connection until we find the first - // one that is usuable. - i := 0 - for i = 0; i < len(t.conns[proto]); i++ { - pc := t.conns[proto][i] + // take the last used conn - complexity O(1) + if stack := t.conns[proto]; len(stack) > 0 { + pc := stack[len(stack)-1] if time.Since(pc.used) < t.expire { // Found one, remove from pool and return this conn. - t.conns[proto] = t.conns[proto][i+1:] + t.conns[proto] = stack[:len(stack)-1] t.ret <- pc.c continue Wait } - // This conn has expired. Close it. - pc.c.Close() + // clear entire cache if the last conn is expired + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) } - - // Not conns were found. Connect to the upstream to create one. - t.conns[proto] = t.conns[proto][i:] SocketGauge.WithLabelValues(t.addr).Set(float64(t.len())) t.ret <- nil @@ -96,13 +94,53 @@ Wait: t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) + case <-ticker.C: + t.cleanup(false) + case <-t.stop: + t.cleanup(true) close(t.ret) return } } } +// closeConns closes connections. +func closeConns(conns []*persistConn) { + for _, pc := range conns { + pc.c.Close() + } +} + +// cleanup removes connections from cache. +func (t *transport) cleanup(all bool) { + staleTime := time.Now().Add(-t.expire) + for proto, stack := range t.conns { + if len(stack) == 0 { + continue + } + if all { + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + continue + } + if stack[0].used.After(staleTime) { + continue + } + + // connections in stack are sorted by "used" + good := sort.Search(len(stack), func(i int) bool { + return stack[i].used.After(staleTime) + }) + t.conns[proto] = stack[good:] + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack[:good]) + } +} + // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { // If tls has been configured; use it. @@ -128,6 +166,9 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { // Yield return the connection to transport for reuse. func (t *transport) Yield(c *dns.Conn) { t.yield <- c } +// Start starts the transport's connection manager. +func (t *transport) Start() { go t.connManager() } + // Stop stops the transport's connection manager. func (t *transport) Stop() { close(t.stop) } diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go index 4b60c7c07..6aa8999f7 100644 --- a/plugin/forward/persistent_test.go +++ b/plugin/forward/persistent_test.go @@ -2,13 +2,14 @@ package forward import ( "testing" + "time" "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/miekg/dns" ) -func TestPersistent(t *testing.T) { +func TestCached(t *testing.T) { s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) @@ -17,17 +18,144 @@ func TestPersistent(t *testing.T) { defer s.Close() tr := newTransport(s.Addr, nil /* no TLS */) + tr.Start() defer tr.Stop() c1, cache1, _ := tr.Dial("udp") c2, cache2, _ := tr.Dial("udp") - c3, cache3, _ := tr.Dial("udp") - if cache1 || cache2 || cache3 { + if cache1 || cache2 { t.Errorf("Expected non-cached connection") } tr.Yield(c1) tr.Yield(c2) + c3, cached3, _ := tr.Dial("udp") + if !cached3 { + t.Error("Expected cached connection (c3)") + } + if c2 != c3 { + t.Error("Expected c2 == c3") + } + + tr.Yield(c3) + + // dial another protocol + c4, cached4, _ := tr.Dial("tcp") + if cached4 { + t.Errorf("Expected non-cached connection (c4)") + } + tr.Yield(c4) +} + +func TestCleanupByTimer(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport(s.Addr, nil /* no TLS */) + tr.SetExpire(100 * time.Millisecond) + tr.Start() + defer tr.Stop() + + c1, _, _ := tr.Dial("udp") + c2, _, _ := tr.Dial("udp") + tr.Yield(c1) + time.Sleep(10 * time.Millisecond) + tr.Yield(c2) + + time.Sleep(120 * time.Millisecond) + c3, cached, _ := tr.Dial("udp") + if cached { + t.Error("Expected non-cached connection (c3)") + } + tr.Yield(c3) + + time.Sleep(120 * time.Millisecond) + c4, cached, _ := tr.Dial("udp") + if cached { + t.Error("Expected non-cached connection (c4)") + } + tr.Yield(c4) +} + +func TestPartialCleanup(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport(s.Addr, nil /* no TLS */) + tr.SetExpire(100 * time.Millisecond) + tr.Start() + defer tr.Stop() + + c1, _, _ := tr.Dial("udp") + c2, _, _ := tr.Dial("udp") + c3, _, _ := tr.Dial("udp") + c4, _, _ := tr.Dial("udp") + c5, _, _ := tr.Dial("udp") + + tr.Yield(c1) + time.Sleep(10 * time.Millisecond) + tr.Yield(c2) + time.Sleep(10 * time.Millisecond) tr.Yield(c3) + time.Sleep(50 * time.Millisecond) + tr.Yield(c4) + time.Sleep(10 * time.Millisecond) + tr.Yield(c5) + time.Sleep(40 * time.Millisecond) + + c6, _, _ := tr.Dial("udp") + if c6 != c5 { + t.Errorf("Expected c6 == c5") + } + c7, _, _ := tr.Dial("udp") + if c7 != c4 { + t.Errorf("Expected c7 == c4") + } + c8, cached, _ := tr.Dial("udp") + if cached { + t.Error("Expected non-cached connection (c8)") + } + + tr.Yield(c6) + tr.Yield(c7) + tr.Yield(c8) +} + +func TestCleanupAll(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport(s.Addr, nil /* no TLS */) + + c1, _ := dns.DialTimeout("udp", tr.addr, dialTimeout) + c2, _ := dns.DialTimeout("udp", tr.addr, dialTimeout) + c3, _ := dns.DialTimeout("udp", tr.addr, dialTimeout) + + tr.conns["udp"] = []*persistConn{ + {c1, time.Now()}, + {c2, time.Now()}, + {c3, time.Now()}, + } + + if tr.len() != 3 { + t.Error("Expected 3 connections") + } + tr.cleanup(true) + + if tr.len() > 0 { + t.Error("Expected no cached connections") + } } diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index a162ace1b..5e85d4969 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -100,7 +100,10 @@ func (p *Proxy) finalizer() { } // start starts the proxy's healthchecking. -func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) } +func (p *Proxy) start(duration time.Duration) { + p.probe.Start(duration) + p.transport.Start() +} const ( dialTimeout = 4 * time.Second |