diff options
Diffstat (limited to 'plugin/forward/persistent.go')
-rw-r--r-- | plugin/forward/persistent.go | 67 |
1 files changed, 54 insertions, 13 deletions
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go index dc03002d3..e84c56ddd 100644 --- a/plugin/forward/persistent.go +++ b/plugin/forward/persistent.go @@ -3,6 +3,7 @@ package forward import ( "crypto/tls" "net" + "sort" "time" "github.com/miekg/dns" @@ -37,7 +38,6 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport { ret: make(chan *dns.Conn), stop: make(chan bool), } - go func() { t.connManager() }() return t } @@ -53,28 +53,26 @@ func (t *transport) len() int { // connManagers manages the persistent connection cache for UDP and TCP. func (t *transport) connManager() { - + ticker := time.NewTicker(t.expire) Wait: for { select { case proto := <-t.dial: - // Yes O(n), shouldn't put millions in here. We walk all connection until we find the first - // one that is usuable. - i := 0 - for i = 0; i < len(t.conns[proto]); i++ { - pc := t.conns[proto][i] + // take the last used conn - complexity O(1) + if stack := t.conns[proto]; len(stack) > 0 { + pc := stack[len(stack)-1] if time.Since(pc.used) < t.expire { // Found one, remove from pool and return this conn. - t.conns[proto] = t.conns[proto][i+1:] + t.conns[proto] = stack[:len(stack)-1] t.ret <- pc.c continue Wait } - // This conn has expired. Close it. - pc.c.Close() + // clear entire cache if the last conn is expired + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) } - - // Not conns were found. Connect to the upstream to create one. - t.conns[proto] = t.conns[proto][i:] SocketGauge.WithLabelValues(t.addr).Set(float64(t.len())) t.ret <- nil @@ -96,13 +94,53 @@ Wait: t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) + case <-ticker.C: + t.cleanup(false) + case <-t.stop: + t.cleanup(true) close(t.ret) return } } } +// closeConns closes connections. +func closeConns(conns []*persistConn) { + for _, pc := range conns { + pc.c.Close() + } +} + +// cleanup removes connections from cache. +func (t *transport) cleanup(all bool) { + staleTime := time.Now().Add(-t.expire) + for proto, stack := range t.conns { + if len(stack) == 0 { + continue + } + if all { + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + continue + } + if stack[0].used.After(staleTime) { + continue + } + + // connections in stack are sorted by "used" + good := sort.Search(len(stack), func(i int) bool { + return stack[i].used.After(staleTime) + }) + t.conns[proto] = stack[good:] + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack[:good]) + } +} + // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { // If tls has been configured; use it. @@ -128,6 +166,9 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { // Yield return the connection to transport for reuse. func (t *transport) Yield(c *dns.Conn) { t.yield <- c } +// Start starts the transport's connection manager. +func (t *transport) Start() { go t.connManager() } + // Stop stops the transport's connection manager. func (t *transport) Stop() { close(t.stop) } |