aboutsummaryrefslogtreecommitdiff
path: root/plugin/pkg/proxy/connect.go
diff options
context:
space:
mode:
authorGravatar Pat Downey <patdowney@users.noreply.github.com> 2023-03-24 12:55:51 +0000
committerGravatar GitHub <noreply@github.com> 2023-03-24 08:55:51 -0400
commitf823825f8a34edb85d5d18cd5d2f6f850adf408e (patch)
tree79d241ab9b4c7c343d806f4041c8efccbe3f9ca0 /plugin/pkg/proxy/connect.go
parent47dceabfc6465ba6c5d41472d6602d4ad5c9fb1b (diff)
downloadcoredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.tar.gz
coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.tar.zst
coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.zip
plugin/forward: Allow Proxy to be used outside of forward plugin. (#5951)
* plugin/forward: Move Proxy into pkg/plugin/proxy, to allow forward.Proxy to be used outside of forward plugin. Signed-off-by: Patrick Downey <patrick.downey@dioadconsulting.com>
Diffstat (limited to 'plugin/pkg/proxy/connect.go')
-rw-r--r--plugin/pkg/proxy/connect.go152
1 files changed, 152 insertions, 0 deletions
diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go
new file mode 100644
index 000000000..29274d92d
--- /dev/null
+++ b/plugin/pkg/proxy/connect.go
@@ -0,0 +1,152 @@
+// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
+// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
+// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
+// inband healthchecking.
+package proxy
+
+import (
+ "context"
+ "io"
+ "strconv"
+ "sync/atomic"
+ "time"
+
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+// limitTimeout is a utility function to auto-tune timeout values
+// average observed time is moved towards the last observed delay moderated by a weight
+// next timeout to use will be the double of the computed average, limited by min and max frame.
+func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
+ rt := time.Duration(atomic.LoadInt64(currentAvg))
+ if rt < minValue {
+ return minValue
+ }
+ if rt < maxValue/2 {
+ return 2 * rt
+ }
+ return maxValue
+}
+
+func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
+ dt := time.Duration(atomic.LoadInt64(currentAvg))
+ atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
+}
+
+func (t *Transport) dialTimeout() time.Duration {
+ return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
+}
+
+func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
+ averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
+}
+
+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
+func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
+ // If tls has been configured; use it.
+ if t.tlsConfig != nil {
+ proto = "tcp-tls"
+ }
+
+ t.dial <- proto
+ pc := <-t.ret
+
+ if pc != nil {
+ ConnCacheHitsCount.WithLabelValues(t.addr, proto).Add(1)
+ return pc, true, nil
+ }
+ ConnCacheMissesCount.WithLabelValues(t.addr, proto).Add(1)
+
+ reqTime := time.Now()
+ timeout := t.dialTimeout()
+ if proto == "tcp-tls" {
+ conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+ }
+ conn, err := dns.DialTimeout(proto, t.addr, timeout)
+ t.updateDialTimeout(time.Since(reqTime))
+ return &persistConn{c: conn}, false, err
+}
+
+// Connect selects an upstream, sends the request and waits for a response.
+func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
+ start := time.Now()
+
+ proto := ""
+ switch {
+ case opts.ForceTCP: // TCP flag has precedence over UDP flag
+ proto = "tcp"
+ case opts.PreferUDP:
+ proto = "udp"
+ default:
+ proto = state.Proto()
+ }
+
+ pc, cached, err := p.transport.Dial(proto)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set buffer size correctly for this client.
+ pc.c.UDPSize = uint16(state.Size())
+ if pc.c.UDPSize < 512 {
+ pc.c.UDPSize = 512
+ }
+
+ pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
+ // records the origin Id before upstream.
+ originId := state.Req.Id
+ state.Req.Id = dns.Id()
+ defer func() {
+ state.Req.Id = originId
+ }()
+
+ if err := pc.c.WriteMsg(state.Req); err != nil {
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ return nil, err
+ }
+
+ var ret *dns.Msg
+ pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
+ for {
+ ret, err = pc.c.ReadMsg()
+ if err != nil {
+ pc.c.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, ErrCachedClosed
+ }
+ // recovery the origin Id after upstream.
+ if ret != nil {
+ ret.Id = originId
+ }
+ return ret, err
+ }
+ // drop out-of-order responses
+ if state.Req.Id == ret.Id {
+ break
+ }
+ }
+ // recovery the origin Id after upstream.
+ ret.Id = originId
+
+ p.transport.Yield(pc)
+
+ rc, ok := dns.RcodeToString[ret.Rcode]
+ if !ok {
+ rc = strconv.Itoa(ret.Rcode)
+ }
+
+ RequestCount.WithLabelValues(p.addr).Add(1)
+ RcodeCount.WithLabelValues(rc, p.addr).Add(1)
+ RequestDuration.WithLabelValues(p.addr, rc).Observe(time.Since(start).Seconds())
+
+ return ret, nil
+}
+
+const cumulativeAvgWeight = 4