diff options
Diffstat (limited to 'plugin/forward/connect.go')
-rw-r--r-- | plugin/forward/connect.go | 64 |
1 files changed, 54 insertions, 10 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index e4bf64f2b..fe6313e0e 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -16,21 +16,65 @@ import ( "github.com/miekg/dns" ) -func (p *Proxy) readTimeout() time.Duration { - rtt := time.Duration(atomic.LoadInt64(&p.avgRtt)) +// limitTimeout is a utility function to auto-tune timeout values +// average observed time is moved towards the last observed delay moderated by a weight +// next timeout to use will be the double of the computed average, limited by min and max frame. +func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { + rt := time.Duration(atomic.LoadInt64(currentAvg)) + if rt < minValue { + return minValue + } + if rt < maxValue/2 { + return 2 * rt + } + return maxValue +} + +func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { + dt := time.Duration(atomic.LoadInt64(currentAvg)) + atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) +} + +func (t *transport) dialTimeout() time.Duration { + return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) +} + +func (t *transport) updateDialTimeout(newDialTime time.Duration) { + averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) +} - if rtt < minTimeout { - return minTimeout +// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. +func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { + // If tls has been configured; use it. + if t.tlsConfig != nil { + proto = "tcp-tls" } - if rtt < maxTimeout/2 { - return 2 * rtt + + t.dial <- proto + c := <-t.ret + + if c != nil { + return c, true, nil } - return maxTimeout + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, false, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, false, err +} + +func (p *Proxy) readTimeout() time.Duration { + return limitTimeout(&p.avgRtt, minTimeout, maxTimeout) } func (p *Proxy) updateRtt(newRtt time.Duration) { - rtt := time.Duration(atomic.LoadInt64(&p.avgRtt)) - atomic.AddInt64(&p.avgRtt, int64((newRtt-rtt)/rttCount)) + averageTimeout(&p.avgRtt, newRtt, cumulativeAvgWeight) } // Connect selects an upstream, sends the request and waits for a response. @@ -92,4 +136,4 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, me return ret, nil } -const rttCount = 4 +const cumulativeAvgWeight = 4 |