aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--plugin/forward/connect.go11
-rw-r--r--plugin/forward/forward.go3
-rw-r--r--plugin/forward/persistent.go64
-rw-r--r--plugin/forward/proxy.go26
-rw-r--r--plugin/forward/proxy_test.go51
5 files changed, 37 insertions, 118 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
index 40c9d62ca..5bd55f2ab 100644
--- a/plugin/forward/connect.go
+++ b/plugin/forward/connect.go
@@ -34,16 +34,6 @@ func (p *Proxy) updateRtt(newRtt time.Duration) {
}
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
- atomic.AddInt32(&p.inProgress, 1)
- defer func() {
- if atomic.AddInt32(&p.inProgress, -1) == 0 {
- p.checkStopTransport()
- }
- }()
- if atomic.LoadUint32(&p.state) != running {
- return nil, errStopped
- }
-
start := time.Now()
proto := state.Proto()
@@ -55,6 +45,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
if err != nil {
return nil, err
}
+
// Set buffer size correctly for this client.
conn.UDPSize = uint16(state.Size())
if conn.UDPSize < 512 {
diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go
index 84339d4bd..20d995710 100644
--- a/plugin/forward/forward.go
+++ b/plugin/forward/forward.go
@@ -119,7 +119,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
if err != nil {
// Kick off health check to see if *our* upstream is broken.
- if f.maxfails != 0 && err != errStopped {
+ if f.maxfails != 0 {
proxy.Healthcheck()
}
@@ -185,7 +185,6 @@ var (
errNoHealthy = errors.New("no healthy proxies")
errNoForward = errors.New("no forwarder defined")
errCachedClosed = errors.New("cached connection was closed by peer")
- errStopped = errors.New("proxy has been stopped")
)
// policy tells forward what policy for selecting upstream it uses.
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go
index decac412c..dc03002d3 100644
--- a/plugin/forward/persistent.go
+++ b/plugin/forward/persistent.go
@@ -14,13 +14,6 @@ type persistConn struct {
used time.Time
}
-// connErr is used to communicate the connection manager.
-type connErr struct {
- c *dns.Conn
- err error
- cached bool
-}
-
// transport hold the persistent cache.
type transport struct {
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
@@ -29,8 +22,8 @@ type transport struct {
tlsConfig *tls.Config
dial chan string
- yield chan connErr
- ret chan connErr
+ yield chan *dns.Conn
+ ret chan *dns.Conn
stop chan bool
}
@@ -40,18 +33,11 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport {
expire: defaultExpire,
addr: addr,
dial: make(chan string),
- yield: make(chan connErr),
- ret: make(chan connErr),
+ yield: make(chan *dns.Conn),
+ ret: make(chan *dns.Conn),
stop: make(chan bool),
}
- go func() {
- t.connManager()
- // if connManager returns it has been stopped.
- close(t.stop)
- close(t.yield)
- close(t.dial)
- // close(t.ret) // we can still be dialing and wanting to send back the socket on t.ret
- }()
+ go func() { t.connManager() }()
return t
}
@@ -80,7 +66,7 @@ Wait:
if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
t.conns[proto] = t.conns[proto][i+1:]
- t.ret <- connErr{pc.c, nil, true}
+ t.ret <- pc.c
continue Wait
}
// This conn has expired. Close it.
@@ -91,35 +77,27 @@ Wait:
t.conns[proto] = t.conns[proto][i:]
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
- go func() {
- if proto != "tcp-tls" {
- c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
- t.ret <- connErr{c, err, false}
- return
- }
-
- c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
- t.ret <- connErr{c, err, false}
- }()
+ t.ret <- nil
case conn := <-t.yield:
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
// no proto here, infer from config and conn
- if _, ok := conn.c.Conn.(*net.UDPConn); ok {
- t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()})
+ if _, ok := conn.Conn.(*net.UDPConn); ok {
+ t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
continue Wait
}
if t.tlsConfig == nil {
- t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
+ t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
continue Wait
}
- t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()})
+ t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
case <-t.stop:
+ close(t.ret)
return
}
}
@@ -134,16 +112,24 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
t.dial <- proto
c := <-t.ret
- return c.c, c.cached, c.err
+
+ if c != nil {
+ return c, true, nil
+ }
+
+ if proto == "tcp-tls" {
+ conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
+ return conn, false, err
+ }
+ conn, err := dns.DialTimeout(proto, t.addr, dialTimeout)
+ return conn, false, err
}
// Yield return the connection to transport for reuse.
-func (t *transport) Yield(c *dns.Conn) {
- t.yield <- connErr{c, nil, false}
-}
+func (t *transport) Yield(c *dns.Conn) { t.yield <- c }
// Stop stops the transport's connection manager.
-func (t *transport) Stop() { t.stop <- true }
+func (t *transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport.
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go
index c788f98cc..b6b570149 100644
--- a/plugin/forward/proxy.go
+++ b/plugin/forward/proxy.go
@@ -24,9 +24,6 @@ type Proxy struct {
fails uint32
avgRtt int64
-
- state uint32
- inProgress int32
}
// NewProxy returns a new proxy.
@@ -85,26 +82,15 @@ func (p *Proxy) Down(maxfails uint32) bool {
return fails > maxfails
}
-// close stops the health checking goroutine and connection manager.
+// close stops the health checking goroutine.
func (p *Proxy) close() {
- if atomic.CompareAndSwapUint32(&p.state, running, stopping) {
- p.probe.Stop()
- }
- if atomic.LoadInt32(&p.inProgress) == 0 {
- p.checkStopTransport()
- }
+ p.probe.Stop()
+ p.transport.Stop()
}
// start starts the proxy's healthchecking.
func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
-// checkStopTransport checks if stop was requested and stops connection manager
-func (p *Proxy) checkStopTransport() {
- if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) {
- p.transport.Stop()
- }
-}
-
const (
dialTimeout = 4 * time.Second
timeout = 2 * time.Second
@@ -112,9 +98,3 @@ const (
minTimeout = 10 * time.Millisecond
hcDuration = 500 * time.Millisecond
)
-
-const (
- running = iota
- stopping
- stopped
-)
diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go
index acd3d240c..d473d6881 100644
--- a/plugin/forward/proxy_test.go
+++ b/plugin/forward/proxy_test.go
@@ -2,9 +2,7 @@ package forward
import (
"context"
- "runtime"
"testing"
- "time"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
@@ -28,50 +26,15 @@ func TestProxyClose(t *testing.T) {
ctx := context.TODO()
for i := 0; i < 100; i++ {
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, nil)
p.start(hcDuration)
- doneCnt := 0
- doneCh := make(chan bool)
- timeCh := time.After(10 * time.Second)
- go func() {
- p.connect(ctx, state, false, false)
- doneCh <- true
- }()
- go func() {
- p.connect(ctx, state, true, false)
- doneCh <- true
- }()
- go func() {
- p.close()
- doneCh <- true
- }()
- go func() {
- p.connect(ctx, state, false, false)
- doneCh <- true
- }()
- go func() {
- p.connect(ctx, state, true, false)
- doneCh <- true
- }()
-
- for doneCnt < 5 {
- select {
- case <-doneCh:
- doneCnt++
- case <-timeCh:
- t.Error("TestProxyClose is running too long, dumping goroutines:")
- buf := make([]byte, 100000)
- stackSize := runtime.Stack(buf, true)
- t.Fatal(string(buf[:stackSize]))
- }
- }
- if p.inProgress != 0 {
- t.Errorf("unexpected query in progress")
- }
- if p.state != stopped {
- t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state)
- }
+ go func() { p.connect(ctx, state, false, false) }()
+ go func() { p.connect(ctx, state, true, false) }()
+ go func() { p.connect(ctx, state, false, false) }()
+ go func() { p.connect(ctx, state, true, false) }()
+
+ p.close()
}
}