diff options
Diffstat (limited to 'plugin/forward/connect.go')
-rw-r--r-- | plugin/forward/connect.go | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 8fde2224b..9ac1afe16 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -44,17 +44,17 @@ func (t *Transport) updateDialTimeout(newDialTime time.Duration) { } // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. -func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) { +func (t *Transport) Dial(proto string) (*persistConn, bool, error) { // If tls has been configured; use it. if t.tlsConfig != nil { proto = "tcp-tls" } t.dial <- proto - c := <-t.ret + pc := <-t.ret - if c != nil { - return c, true, nil + if pc != nil { + return pc, true, nil } reqTime := time.Now() @@ -62,11 +62,11 @@ func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) { if proto == "tcp-tls" { conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) t.updateDialTimeout(time.Since(reqTime)) - return conn, false, err + return &persistConn{c: conn}, false, err } conn, err := dns.DialTimeout(proto, t.addr, timeout) t.updateDialTimeout(time.Since(reqTime)) - return conn, false, err + return &persistConn{c: conn}, false, err } // Connect selects an upstream, sends the request and waits for a response. @@ -83,20 +83,20 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options proto = state.Proto() } - conn, cached, err := p.transport.Dial(proto) + pc, cached, err := p.transport.Dial(proto) if err != nil { return nil, err } // Set buffer size correctly for this client. - conn.UDPSize = uint16(state.Size()) - if conn.UDPSize < 512 { - conn.UDPSize = 512 + pc.c.UDPSize = uint16(state.Size()) + if pc.c.UDPSize < 512 { + pc.c.UDPSize = 512 } - conn.SetWriteDeadline(time.Now().Add(maxTimeout)) - if err := conn.WriteMsg(state.Req); err != nil { - conn.Close() // not giving it back + pc.c.SetWriteDeadline(time.Now().Add(maxTimeout)) + if err := pc.c.WriteMsg(state.Req); err != nil { + pc.c.Close() // not giving it back if err == io.EOF && cached { return nil, ErrCachedClosed } @@ -104,11 +104,11 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options } var ret *dns.Msg - conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.c.SetReadDeadline(time.Now().Add(readTimeout)) for { - ret, err = conn.ReadMsg() + ret, err = pc.c.ReadMsg() if err != nil { - conn.Close() // not giving it back + pc.c.Close() // not giving it back if err == io.EOF && cached { return nil, ErrCachedClosed } @@ -120,7 +120,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options } } - p.transport.Yield(conn) + p.transport.Yield(pc) rc, ok := dns.RcodeToString[ret.Rcode] if !ok { |