aboutsummaryrefslogtreecommitdiff
path: root/plugin/forward/persistent.go
blob: 6a7c4464e41eb9641c2dcfc97c478742b6b541ee (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package forward

import (
	"net"
	"time"

	"github.com/miekg/dns"
)

// persistConn holds a cached dns.Conn together with the time it was put back
// into the cache; connManager compares that time against the host's expire
// duration to decide whether the connection may still be reused.
type persistConn struct {
	c    *dns.Conn // the cached connection
	used time.Time // set to time.Now() when the connection is yielded back
}

// connErr pairs a connection with an error. It is the message type used to
// communicate with connManager: Yield sends one on the yield channel, and
// connManager replies with one on the ret channel.
type connErr struct {
	c   *dns.Conn
	err error
}

// transport holds the persistent connection cache for a single upstream host.
// conns must only be accessed from the connManager goroutine; other goroutines
// communicate with it exclusively through the channels below.
type transport struct {
	conns map[string][]*persistConn // Buckets keyed by protocol: "udp", "tcp" and "tcp-tls".
	host  *host

	// dial receives the protocol requested by Dial.
	dial chan string
	// yield receives connections handed back by Yield.
	yield chan connErr
	// ret carries connManager's reply to a dial request.
	ret chan connErr

	// Aid in testing, gets length of cache in data-race safe manner.
	lenc    chan bool
	lencOut chan int

	// stop tells connManager to return.
	stop chan bool
}

// newTransport returns a transport for h with an empty connection cache and
// unbuffered control channels, and starts the connManager goroutine that owns
// the cache.
func newTransport(h *host) *transport {
	tr := new(transport)
	tr.host = h
	tr.conns = map[string][]*persistConn{}

	tr.dial = make(chan string)
	tr.yield = make(chan connErr)
	tr.ret = make(chan connErr)

	tr.stop = make(chan bool)
	tr.lenc = make(chan bool)
	tr.lencOut = make(chan int)

	go tr.connManager()
	return tr
}

// len reports how many connections are cached across all protocol buckets;
// it feeds the socket metrics. It reads t.conns without synchronization, so
// it may only be called from inside connManager.
func (t *transport) len() int {
	var total int
	for proto := range t.conns {
		total += len(t.conns[proto])
	}
	return total
}

// Len returns the number of connections in the cache. It may be called from
// any goroutine: the request goes over lenc and connManager sends the count
// back on lencOut.
func (t *transport) Len() int {
	t.lenc <- true
	return <-t.lencOut
}

// connManager manages the persistent connection cache for UDP, TCP and
// TCP-TLS. It is the only goroutine that touches t.conns, so the cache needs
// no locking; Dial, Yield, Len and Stop talk to it over channels.
//
// Every reply on t.ret is sent synchronously, before the next request is
// accepted. As a result, the reply always reaches the caller whose dial
// request was just received. On a cache miss the reply carries a nil
// connection and Dial connects to the upstream itself. A background dialer
// sending on the shared t.ret could hand its result to a different caller.
func (t *transport) connManager() {
Wait:
	for {
		select {
		case proto := <-t.dial:
			// Yes O(n), shouldn't put millions in here. Hand out the first
			// connection that has not expired, closing expired ones on the way.
			pool := t.conns[proto]
			for i, pc := range pool {
				// Clear the slot so the backing array no longer references a
				// connection that is about to be handed out or closed.
				pool[i] = nil
				if time.Since(pc.used) < t.host.expire {
					t.conns[proto] = pool[i+1:]
					SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
					t.ret <- connErr{pc.c, nil}
					continue Wait
				}
				// This conn has expired. Close it.
				pc.c.Close()
			}

			// No usable connection: everything in the bucket was expired and
			// has been closed, so reset the bucket and report a miss.
			t.conns[proto] = pool[:0]
			SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
			t.ret <- connErr{nil, nil}

		case conn := <-t.yield:
			SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1))

			// Yield carries no protocol; infer the bucket from the connection
			// type and the host's TLS configuration.
			_, isUDP := conn.c.Conn.(*net.UDPConn)
			proto := "tcp-tls"
			switch {
			case isUDP:
				proto = "udp"
			case t.host.tlsConfig == nil:
				proto = "tcp"
			}
			t.conns[proto] = append(t.conns[proto], &persistConn{conn.c, time.Now()})

		case <-t.stop:
			// Close the idle cached connections; nothing can reach them once
			// this goroutine is gone.
			for proto, pool := range t.conns {
				for _, pc := range pool {
					pc.c.Close()
				}
				delete(t.conns, proto)
			}
			SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
			return

		case <-t.lenc:
			t.lencOut <- t.len()
		}
	}
}

// Dial returns a connection to t.host.addr for proto ("udp", "tcp" or
// "tcp-tls"). It reuses a cached, non-expired connection when connManager has
// one. Otherwise it dials with dialTimeout, using t.host.tlsConfig for
// "tcp-tls". Hand the connection back with Yield to make it reusable.
func (t *transport) Dial(proto string) (*dns.Conn, error) {
	t.dial <- proto
	if cached := <-t.ret; cached.c != nil {
		return cached.c, cached.err
	}

	// Cache miss: dial here, in the caller's goroutine, so a slow connect or
	// TLS handshake never stalls connManager and the result cannot be
	// delivered to another caller.
	if proto == "tcp-tls" {
		return dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout)
	}
	return dns.DialTimeout(proto, t.host.addr, dialTimeout)
}

// Yield hands c back to connManager, which stores it in the cache so a later
// Dial for the same protocol can reuse it.
func (t *transport) Yield(c *dns.Conn) {
	msg := connErr{c: c}
	t.yield <- msg
}

// Stop tells the connManager goroutine to exit. The stop channel is
// unbuffered, so the call blocks until connManager receives the signal.
func (t *transport) Stop() {
	t.stop <- true
}