aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--plugin/forward/persistent.go67
-rw-r--r--plugin/forward/persistent_test.go134
-rw-r--r--plugin/forward/proxy.go5
3 files changed, 189 insertions, 17 deletions
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go
index dc03002d3..e84c56ddd 100644
--- a/plugin/forward/persistent.go
+++ b/plugin/forward/persistent.go
@@ -3,6 +3,7 @@ package forward
import (
"crypto/tls"
"net"
+ "sort"
"time"
"github.com/miekg/dns"
@@ -37,7 +38,6 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport {
ret: make(chan *dns.Conn),
stop: make(chan bool),
}
- go func() { t.connManager() }()
return t
}
@@ -53,28 +53,26 @@ func (t *transport) len() int {
// connManagers manages the persistent connection cache for UDP and TCP.
func (t *transport) connManager() {
-
+ ticker := time.NewTicker(t.expire)
Wait:
for {
select {
case proto := <-t.dial:
- // Yes O(n), shouldn't put millions in here. We walk all connection until we find the first
- // one that is usuable.
- i := 0
- for i = 0; i < len(t.conns[proto]); i++ {
- pc := t.conns[proto][i]
+ // take the last used conn - complexity O(1)
+ if stack := t.conns[proto]; len(stack) > 0 {
+ pc := stack[len(stack)-1]
if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
- t.conns[proto] = t.conns[proto][i+1:]
+ t.conns[proto] = stack[:len(stack)-1]
t.ret <- pc.c
continue Wait
}
- // This conn has expired. Close it.
- pc.c.Close()
+ // clear entire cache if the last conn is expired
+ t.conns[proto] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
}
-
- // Not conns were found. Connect to the upstream to create one.
- t.conns[proto] = t.conns[proto][i:]
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
t.ret <- nil
@@ -96,13 +94,53 @@ Wait:
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
+ case <-ticker.C:
+ t.cleanup(false)
+
case <-t.stop:
+ t.cleanup(true)
close(t.ret)
return
}
}
}
+// closeConns closes connections.
+func closeConns(conns []*persistConn) {
+ for _, pc := range conns {
+ pc.c.Close()
+ }
+}
+
+// cleanup removes connections from cache.
+func (t *transport) cleanup(all bool) {
+ staleTime := time.Now().Add(-t.expire)
+ for proto, stack := range t.conns {
+ if len(stack) == 0 {
+ continue
+ }
+ if all {
+ t.conns[proto] = nil
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack)
+ continue
+ }
+ if stack[0].used.After(staleTime) {
+ continue
+ }
+
+ // connections in stack are sorted by "used"
+ good := sort.Search(len(stack), func(i int) bool {
+ return stack[i].used.After(staleTime)
+ })
+ t.conns[proto] = stack[good:]
+ // now, the connections being passed to closeConns() are not reachable from
+ // transport methods anymore. So, it's safe to close them in a separate goroutine
+ go closeConns(stack[:good])
+ }
+}
+
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
// If tls has been configured; use it.
@@ -128,6 +166,9 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
// Yield return the connection to transport for reuse.
func (t *transport) Yield(c *dns.Conn) { t.yield <- c }
+// Start starts the transport's connection manager.
+func (t *transport) Start() { go t.connManager() }
+
// Stop stops the transport's connection manager.
func (t *transport) Stop() { close(t.stop) }
diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go
index 4b60c7c07..6aa8999f7 100644
--- a/plugin/forward/persistent_test.go
+++ b/plugin/forward/persistent_test.go
@@ -2,13 +2,14 @@ package forward
import (
"testing"
+ "time"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/miekg/dns"
)
-func TestPersistent(t *testing.T) {
+func TestCached(t *testing.T) {
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
@@ -17,17 +18,144 @@ func TestPersistent(t *testing.T) {
defer s.Close()
tr := newTransport(s.Addr, nil /* no TLS */)
+ tr.Start()
defer tr.Stop()
c1, cache1, _ := tr.Dial("udp")
c2, cache2, _ := tr.Dial("udp")
- c3, cache3, _ := tr.Dial("udp")
- if cache1 || cache2 || cache3 {
+ if cache1 || cache2 {
t.Errorf("Expected non-cached connection")
}
tr.Yield(c1)
tr.Yield(c2)
+ c3, cached3, _ := tr.Dial("udp")
+ if !cached3 {
+ t.Error("Expected cached connection (c3)")
+ }
+ if c2 != c3 {
+ t.Error("Expected c2 == c3")
+ }
+
+ tr.Yield(c3)
+
+ // dial another protocol
+ c4, cached4, _ := tr.Dial("tcp")
+ if cached4 {
+ t.Errorf("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestCleanupByTimer(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr, nil /* no TLS */)
+ tr.SetExpire(100 * time.Millisecond)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, _, _ := tr.Dial("udp")
+ c2, _, _ := tr.Dial("udp")
+ tr.Yield(c1)
+ time.Sleep(10 * time.Millisecond)
+ tr.Yield(c2)
+
+ time.Sleep(120 * time.Millisecond)
+ c3, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c3)")
+ }
+ tr.Yield(c3)
+
+ time.Sleep(120 * time.Millisecond)
+ c4, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c4)")
+ }
+ tr.Yield(c4)
+}
+
+func TestPartialCleanup(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr, nil /* no TLS */)
+ tr.SetExpire(100 * time.Millisecond)
+ tr.Start()
+ defer tr.Stop()
+
+ c1, _, _ := tr.Dial("udp")
+ c2, _, _ := tr.Dial("udp")
+ c3, _, _ := tr.Dial("udp")
+ c4, _, _ := tr.Dial("udp")
+ c5, _, _ := tr.Dial("udp")
+
+ tr.Yield(c1)
+ time.Sleep(10 * time.Millisecond)
+ tr.Yield(c2)
+ time.Sleep(10 * time.Millisecond)
tr.Yield(c3)
+ time.Sleep(50 * time.Millisecond)
+ tr.Yield(c4)
+ time.Sleep(10 * time.Millisecond)
+ tr.Yield(c5)
+ time.Sleep(40 * time.Millisecond)
+
+ c6, _, _ := tr.Dial("udp")
+ if c6 != c5 {
+ t.Errorf("Expected c6 == c5")
+ }
+ c7, _, _ := tr.Dial("udp")
+ if c7 != c4 {
+ t.Errorf("Expected c7 == c4")
+ }
+ c8, cached, _ := tr.Dial("udp")
+ if cached {
+ t.Error("Expected non-cached connection (c8)")
+ }
+
+ tr.Yield(c6)
+ tr.Yield(c7)
+ tr.Yield(c8)
+}
+
+func TestCleanupAll(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ tr := newTransport(s.Addr, nil /* no TLS */)
+
+ c1, _ := dns.DialTimeout("udp", tr.addr, dialTimeout)
+ c2, _ := dns.DialTimeout("udp", tr.addr, dialTimeout)
+ c3, _ := dns.DialTimeout("udp", tr.addr, dialTimeout)
+
+ tr.conns["udp"] = []*persistConn{
+ {c1, time.Now()},
+ {c2, time.Now()},
+ {c3, time.Now()},
+ }
+
+ if tr.len() != 3 {
+ t.Error("Expected 3 connections")
+ }
+ tr.cleanup(true)
+
+ if tr.len() > 0 {
+ t.Error("Expected no cached connections")
+ }
}
diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go
index a162ace1b..5e85d4969 100644
--- a/plugin/forward/proxy.go
+++ b/plugin/forward/proxy.go
@@ -100,7 +100,10 @@ func (p *Proxy) finalizer() {
}
// start starts the proxy's healthchecking.
-func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
+func (p *Proxy) start(duration time.Duration) {
+ p.probe.Start(duration)
+ p.transport.Start()
+}
const (
dialTimeout = 4 * time.Second