aboutsummaryrefslogtreecommitdiff
path: root/plugin/forward/persistent.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/forward/persistent.go')
-rw-r--r--plugin/forward/persistent.go43
1 files changed, 31 insertions, 12 deletions
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go
index 6a7c4464e..7bf083b49 100644
--- a/plugin/forward/persistent.go
+++ b/plugin/forward/persistent.go
@@ -1,6 +1,7 @@
package forward
import (
+ "crypto/tls"
"net"
"time"
@@ -21,8 +22,10 @@ type connErr struct {
// transport hold the persistent cache.
type transport struct {
- conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
- host *host
+ conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
+ expire time.Duration // After this duration a connection is expired.
+ addr string
+ tlsConfig *tls.Config
dial chan string
yield chan connErr
@@ -35,10 +38,11 @@ type transport struct {
stop chan bool
}
-func newTransport(h *host) *transport {
+func newTransport(addr string, tlsConfig *tls.Config) *transport {
t := &transport{
conns: make(map[string][]*persistConn),
- host: h,
+ expire: defaultExpire,
+ addr: addr,
dial: make(chan string),
yield: make(chan connErr),
ret: make(chan connErr),
@@ -51,7 +55,7 @@ func newTransport(h *host) *transport {
}
// len returns the number of connection, used for metrics. Can only be safely
-// used inside connManager() because of races.
+// used inside connManager() because of data races.
func (t *transport) len() int {
l := 0
for _, conns := range t.conns {
@@ -79,7 +83,7 @@ Wait:
i := 0
for i = 0; i < len(t.conns[proto]); i++ {
pc := t.conns[proto][i]
- if time.Since(pc.used) < t.host.expire {
+ if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
t.conns[proto] = t.conns[proto][i+1:]
t.ret <- connErr{pc.c, nil}
@@ -91,22 +95,22 @@ Wait:
// Not conns were found. Connect to the upstream to create one.
t.conns[proto] = t.conns[proto][i:]
- SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
+ SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
go func() {
if proto != "tcp-tls" {
- c, err := dns.DialTimeout(proto, t.host.addr, dialTimeout)
+ c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
t.ret <- connErr{c, err}
return
}
- c, err := dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout)
+ c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
t.ret <- connErr{c, err}
}()
case conn := <-t.yield:
- SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1))
+ SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
// no proto here, infer from config and conn
if _, ok := conn.c.Conn.(*net.UDPConn); ok {
@@ -114,7 +118,7 @@ Wait:
continue Wait
}
- if t.host.tlsConfig == nil {
+ if t.tlsConfig == nil {
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
continue Wait
}
@@ -134,15 +138,30 @@ Wait:
}
}
+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *transport) Dial(proto string) (*dns.Conn, error) {
+ // If tls has been configured; use it.
+ if t.tlsConfig != nil {
+ proto = "tcp-tls"
+ }
+
t.dial <- proto
c := <-t.ret
return c.c, c.err
}
+// Yield return the connection to transport for reuse.
func (t *transport) Yield(c *dns.Conn) {
t.yield <- connErr{c, nil}
}
-// Stop stops the transports.
+// Stop stops the transport's connection manager.
func (t *transport) Stop() { t.stop <- true }
+
+// SetExpire sets the connection expire time in transport.
+func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+const defaultExpire = 10 * time.Second