aboutsummaryrefslogtreecommitdiff
path: root/plugin/forward/connect.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/forward/connect.go')
-rw-r--r--plugin/forward/connect.go34
1 files changed, 17 insertions, 17 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
index 8fde2224b..9ac1afe16 100644
--- a/plugin/forward/connect.go
+++ b/plugin/forward/connect.go
@@ -44,17 +44,17 @@ func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
}
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
-func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
+func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
// If tls has been configured; use it.
if t.tlsConfig != nil {
proto = "tcp-tls"
}
t.dial <- proto
- c := <-t.ret
+ pc := <-t.ret
- if c != nil {
- return c, true, nil
+ if pc != nil {
+ return pc, true, nil
}
reqTime := time.Now()
@@ -62,11 +62,11 @@ func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
if proto == "tcp-tls" {
conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
t.updateDialTimeout(time.Since(reqTime))
- return conn, false, err
+ return &persistConn{c: conn}, false, err
}
conn, err := dns.DialTimeout(proto, t.addr, timeout)
t.updateDialTimeout(time.Since(reqTime))
- return conn, false, err
+ return &persistConn{c: conn}, false, err
}
// Connect selects an upstream, sends the request and waits for a response.
@@ -83,20 +83,20 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
proto = state.Proto()
}
- conn, cached, err := p.transport.Dial(proto)
+ pc, cached, err := p.transport.Dial(proto)
if err != nil {
return nil, err
}
// Set buffer size correctly for this client.
- conn.UDPSize = uint16(state.Size())
- if conn.UDPSize < 512 {
- conn.UDPSize = 512
+ pc.c.UDPSize = uint16(state.Size())
+ if pc.c.UDPSize < 512 {
+ pc.c.UDPSize = 512
}
- conn.SetWriteDeadline(time.Now().Add(maxTimeout))
- if err := conn.WriteMsg(state.Req); err != nil {
- conn.Close() // not giving it back
+ pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
+ if err := pc.c.WriteMsg(state.Req); err != nil {
+ pc.c.Close() // not giving it back
if err == io.EOF && cached {
return nil, ErrCachedClosed
}
@@ -104,11 +104,11 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
}
var ret *dns.Msg
- conn.SetReadDeadline(time.Now().Add(readTimeout))
+ pc.c.SetReadDeadline(time.Now().Add(readTimeout))
for {
- ret, err = conn.ReadMsg()
+ ret, err = pc.c.ReadMsg()
if err != nil {
- conn.Close() // not giving it back
+ pc.c.Close() // not giving it back
if err == io.EOF && cached {
return nil, ErrCachedClosed
}
@@ -120,7 +120,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
}
}
- p.transport.Yield(conn)
+ p.transport.Yield(pc)
rc, ok := dns.RcodeToString[ret.Rcode]
if !ok {