diff options
author | 2023-03-24 12:55:51 +0000 | |
---|---|---|
committer | 2023-03-24 08:55:51 -0400 | |
commit | f823825f8a34edb85d5d18cd5d2f6f850adf408e (patch) | |
tree | 79d241ab9b4c7c343d806f4041c8efccbe3f9ca0 /plugin/pkg/proxy/connect.go | |
parent | 47dceabfc6465ba6c5d41472d6602d4ad5c9fb1b (diff) | |
download | coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.tar.gz coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.tar.zst coredns-f823825f8a34edb85d5d18cd5d2f6f850adf408e.zip |
plugin/forward: Allow Proxy to be used outside of forward plugin. (#5951)
* plugin/forward: Move Proxy into pkg/plugin/proxy, to allow forward.Proxy to be used outside of forward plugin.
Signed-off-by: Patrick Downey <patrick.downey@dioadconsulting.com>
Diffstat (limited to 'plugin/pkg/proxy/connect.go')
-rw-r--r-- | plugin/pkg/proxy/connect.go | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go new file mode 100644 index 000000000..29274d92d --- /dev/null +++ b/plugin/pkg/proxy/connect.go @@ -0,0 +1,152 @@ +// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. +package proxy + +import ( + "context" + "io" + "strconv" + "sync/atomic" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// limitTimeout is a utility function to auto-tune timeout values +// average observed time is moved towards the last observed delay moderated by a weight +// next timeout to use will be the double of the computed average, limited by min and max frame. +func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { + rt := time.Duration(atomic.LoadInt64(currentAvg)) + if rt < minValue { + return minValue + } + if rt < maxValue/2 { + return 2 * rt + } + return maxValue +} + +func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { + dt := time.Duration(atomic.LoadInt64(currentAvg)) + atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) +} + +func (t *Transport) dialTimeout() time.Duration { + return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) +} + +func (t *Transport) updateDialTimeout(newDialTime time.Duration) { + averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) +} + +// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. +func (t *Transport) Dial(proto string) (*persistConn, bool, error) { + // If tls has been configured; use it. + if t.tlsConfig != nil { + proto = "tcp-tls" + } + + t.dial <- proto + pc := <-t.ret + + if pc != nil { + ConnCacheHitsCount.WithLabelValues(t.addr, proto).Add(1) + return pc, true, nil + } + ConnCacheMissesCount.WithLabelValues(t.addr, proto).Add(1) + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err +} + +// Connect selects an upstream, sends the request and waits for a response. +func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) { + start := time.Now() + + proto := "" + switch { + case opts.ForceTCP: // TCP flag has precedence over UDP flag + proto = "tcp" + case opts.PreferUDP: + proto = "udp" + default: + proto = state.Proto() + } + + pc, cached, err := p.transport.Dial(proto) + if err != nil { + return nil, err + } + + // Set buffer size correctly for this client. + pc.c.UDPSize = uint16(state.Size()) + if pc.c.UDPSize < 512 { + pc.c.UDPSize = 512 + } + + pc.c.SetWriteDeadline(time.Now().Add(maxTimeout)) + // records the origin Id before upstream. + originId := state.Req.Id + state.Req.Id = dns.Id() + defer func() { + state.Req.Id = originId + }() + + if err := pc.c.WriteMsg(state.Req); err != nil { + pc.c.Close() // not giving it back + if err == io.EOF && cached { + return nil, ErrCachedClosed + } + return nil, err + } + + var ret *dns.Msg + pc.c.SetReadDeadline(time.Now().Add(p.readTimeout)) + for { + ret, err = pc.c.ReadMsg() + if err != nil { + pc.c.Close() // not giving it back + if err == io.EOF && cached { + return nil, ErrCachedClosed + } + // recovery the origin Id after upstream. + if ret != nil { + ret.Id = originId + } + return ret, err + } + // drop out-of-order responses + if state.Req.Id == ret.Id { + break + } + } + // recovery the origin Id after upstream. + ret.Id = originId + + p.transport.Yield(pc) + + rc, ok := dns.RcodeToString[ret.Rcode] + if !ok { + rc = strconv.Itoa(ret.Rcode) + } + + RequestCount.WithLabelValues(p.addr).Add(1) + RcodeCount.WithLabelValues(rc, p.addr).Add(1) + RequestDuration.WithLabelValues(p.addr, rc).Observe(time.Since(start).Seconds()) + + return ret, nil +} + +const cumulativeAvgWeight = 4 |