aboutsummaryrefslogtreecommitdiff
path: root/plugin/forward/connect.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/forward/connect.go')
-rw-r--r--plugin/forward/connect.go64
1 files changed, 54 insertions, 10 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
index e4bf64f2b..fe6313e0e 100644
--- a/plugin/forward/connect.go
+++ b/plugin/forward/connect.go
@@ -16,21 +16,65 @@ import (
"github.com/miekg/dns"
)
-func (p *Proxy) readTimeout() time.Duration {
- rtt := time.Duration(atomic.LoadInt64(&p.avgRtt))
+// limitTimeout is a utility function to auto-tune timeout values
+// average observed time is moved towards the last observed delay moderated by a weight
+// next timeout to use will be the double of the computed average, limited by min and max frame.
+func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
+ rt := time.Duration(atomic.LoadInt64(currentAvg))
+ if rt < minValue {
+ return minValue
+ }
+ if rt < maxValue/2 {
+ return 2 * rt
+ }
+ return maxValue
+}
+
+func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
+ dt := time.Duration(atomic.LoadInt64(currentAvg))
+ atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
+}
+
+func (t *transport) dialTimeout() time.Duration {
+ return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
+}
+
+func (t *transport) updateDialTimeout(newDialTime time.Duration) {
+ averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
+}
- if rtt < minTimeout {
- return minTimeout
+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
+func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
+ // If tls has been configured; use it.
+ if t.tlsConfig != nil {
+ proto = "tcp-tls"
}
- if rtt < maxTimeout/2 {
- return 2 * rtt
+
+ t.dial <- proto
+ c := <-t.ret
+
+ if c != nil {
+ return c, true, nil
}
- return maxTimeout
+
+ reqTime := time.Now()
+ timeout := t.dialTimeout()
+ if proto == "tcp-tls" {
+ conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return conn, false, err
+ }
+ conn, err := dns.DialTimeout(proto, t.addr, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return conn, false, err
+}
+
+func (p *Proxy) readTimeout() time.Duration {
+ return limitTimeout(&p.avgRtt, minTimeout, maxTimeout)
}
func (p *Proxy) updateRtt(newRtt time.Duration) {
- rtt := time.Duration(atomic.LoadInt64(&p.avgRtt))
- atomic.AddInt64(&p.avgRtt, int64((newRtt-rtt)/rttCount))
+ averageTimeout(&p.avgRtt, newRtt, cumulativeAvgWeight)
}
// Connect selects an upstream, sends the request and waits for a response.
@@ -92,4 +136,4 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, me
return ret, nil
}
-const rttCount = 4
+const cumulativeAvgWeight = 4