aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Ruslan Drozhdzh <30860269+rdrozhdzh@users.noreply.github.com> 2018-04-06 15:41:48 +0300
committerGravatar Miek Gieben <miek@miek.nl> 2018-04-06 13:41:48 +0100
commite46ee9d9cc197d74dac0cb6e8432f83d4f43d1a6 (patch)
tree66ac38903d67d57f0a23552f0be7cea3bd3bdd93
parent848a5d7c7909afbc381a0708dc4893c28a1df61c (diff)
downloadcoredns-e46ee9d9cc197d74dac0cb6e8432f83d4f43d1a6.tar.gz
coredns-e46ee9d9cc197d74dac0cb6e8432f83d4f43d1a6.tar.zst
coredns-e46ee9d9cc197d74dac0cb6e8432f83d4f43d1a6.zip
plugin/forward: retry on cached tcp connection closed by peer (#1655)
* plugin/forward: retry on cached tcp connection closed by peer * fix linter warnings * fixed unit test * modify error message
-rw-r--r--plugin/forward/connect.go9
-rw-r--r--plugin/forward/forward.go6
-rw-r--r--plugin/forward/persistent.go17
-rw-r--r--plugin/forward/persistent_test.go26
-rw-r--r--plugin/forward/proxy.go2
5 files changed, 40 insertions, 20 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
index 5967c396c..6f9897550 100644
--- a/plugin/forward/connect.go
+++ b/plugin/forward/connect.go
@@ -5,6 +5,7 @@
package forward
import (
+ "io"
"strconv"
"time"
@@ -22,7 +23,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
proto = "tcp"
}
- conn, err := p.Dial(proto)
+ conn, cached, err := p.Dial(proto)
if err != nil {
return nil, err
}
@@ -36,6 +37,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
conn.SetWriteDeadline(time.Now().Add(timeout))
if err := conn.WriteMsg(state.Req); err != nil {
conn.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, errCachedClosed
+ }
return nil, err
}
@@ -43,6 +47,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
ret, err := conn.ReadMsg()
if err != nil {
conn.Close() // not giving it back
+ if err == io.EOF && cached {
+ return nil, errCachedClosed
+ }
return nil, err
}
diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go
index 4c842f49e..6d06f79f2 100644
--- a/plugin/forward/forward.go
+++ b/plugin/forward/forward.go
@@ -7,7 +7,6 @@ package forward
import (
"crypto/tls"
"errors"
- "io"
"time"
"github.com/coredns/coredns/plugin"
@@ -92,11 +91,9 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
ret *dns.Msg
err error
)
- stop := false
for {
ret, err = proxy.connect(ctx, state, f.forceTCP, true)
- if err != nil && err == io.EOF && !stop { // Remote side closed conn, can only happen with TCP.
- stop = true
+ if err != nil && err == errCachedClosed { // Remote side closed conn, can only happen with TCP.
continue
}
break
@@ -176,6 +173,7 @@ var (
errInvalidDomain = errors.New("invalid domain for forward")
errNoHealthy = errors.New("no healthy proxies")
errNoForward = errors.New("no forwarder defined")
+ errCachedClosed = errors.New("cached connection was closed by peer")
)
// policy tells forward what policy for selecting upstream it uses.
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go
index 7bf083b49..6ea4d0371 100644
--- a/plugin/forward/persistent.go
+++ b/plugin/forward/persistent.go
@@ -16,8 +16,9 @@ type persistConn struct {
// connErr is used to communicate the connection manager.
type connErr struct {
- c *dns.Conn
- err error
+ c *dns.Conn
+ err error
+ cached bool
}
// transport hold the persistent cache.
@@ -86,7 +87,7 @@ Wait:
if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
t.conns[proto] = t.conns[proto][i+1:]
- t.ret <- connErr{pc.c, nil}
+ t.ret <- connErr{pc.c, nil, true}
continue Wait
}
// This conn has expired. Close it.
@@ -100,12 +101,12 @@ Wait:
go func() {
if proto != "tcp-tls" {
c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
- t.ret <- connErr{c, err}
+ t.ret <- connErr{c, err, false}
return
}
c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
- t.ret <- connErr{c, err}
+ t.ret <- connErr{c, err, false}
}()
case conn := <-t.yield:
@@ -139,7 +140,7 @@ Wait:
}
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
-func (t *transport) Dial(proto string) (*dns.Conn, error) {
+func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
// If tls has been configured; use it.
if t.tlsConfig != nil {
proto = "tcp-tls"
@@ -147,12 +148,12 @@ func (t *transport) Dial(proto string) (*dns.Conn, error) {
t.dial <- proto
c := <-t.ret
- return c.c, c.err
+ return c.c, c.cached, c.err
}
// Yield return the connection to transport for reuse.
func (t *transport) Yield(c *dns.Conn) {
- t.yield <- connErr{c, nil}
+ t.yield <- connErr{c, nil, false}
}
// Stop stops the transport's connection manager.
diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go
index 5fa491a01..f4f476afa 100644
--- a/plugin/forward/persistent_test.go
+++ b/plugin/forward/persistent_test.go
@@ -19,9 +19,13 @@ func TestPersistent(t *testing.T) {
tr := newTransport(s.Addr, nil /* no TLS */)
defer tr.Stop()
- c1, _ := tr.Dial("udp")
- c2, _ := tr.Dial("udp")
- c3, _ := tr.Dial("udp")
+ c1, cache1, _ := tr.Dial("udp")
+ c2, cache2, _ := tr.Dial("udp")
+ c3, cache3, _ := tr.Dial("udp")
+
+ if cache1 || cache2 || cache3 {
+ t.Errorf("Expected non-cached connection")
+ }
tr.Yield(c1)
tr.Yield(c2)
@@ -31,13 +35,23 @@ func TestPersistent(t *testing.T) {
t.Errorf("Expected cache size to be 3, got %d", x)
}
- tr.Dial("udp")
+ c4, cache4, _ := tr.Dial("udp")
if x := tr.Len(); x != 2 {
t.Errorf("Expected cache size to be 2, got %d", x)
}
- tr.Dial("udp")
+ c5, cache5, _ := tr.Dial("udp")
if x := tr.Len(); x != 1 {
- t.Errorf("Expected cache size to be 2, got %d", x)
+ t.Errorf("Expected cache size to be 1, got %d", x)
+ }
+
+ if cache4 == false || cache5 == false {
+ t.Errorf("Expected cached connection")
+ }
+ tr.Yield(c4)
+ tr.Yield(c5)
+
+ if x := tr.Len(); x != 3 {
+ t.Errorf("Expected cache size to be 3, got %d", x)
}
}
diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go
index 30bab52d1..02d3512cb 100644
--- a/plugin/forward/proxy.go
+++ b/plugin/forward/proxy.go
@@ -58,7 +58,7 @@ func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) }
func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
// Dial connects to the host in p with the configured transport.
-func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }
+func (p *Proxy) Dial(proto string) (*dns.Conn, bool, error) { return p.transport.Dial(proto) }
// Yield returns the connection to the pool.
func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }