aboutsummaryrefslogtreecommitdiff
path: root/plugin
diff options
context:
space:
mode:
Diffstat (limited to 'plugin')
-rw-r--r--plugin/forward/connect.go11
-rw-r--r--plugin/forward/forward.go3
-rw-r--r--plugin/forward/proxy.go26
-rw-r--r--plugin/forward/proxy_test.go65
4 files changed, 100 insertions, 5 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
index 0a66f2752..6ea7913e5 100644
--- a/plugin/forward/connect.go
+++ b/plugin/forward/connect.go
@@ -35,6 +35,16 @@ func (p *Proxy) updateRtt(newRtt time.Duration) {
}
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
+ atomic.AddInt32(&p.inProgress, 1)
+ defer func() {
+ if atomic.AddInt32(&p.inProgress, -1) == 0 {
+ p.checkStopTransport()
+ }
+ }()
+ if atomic.LoadUint32(&p.state) != running {
+ return nil, errStopped
+ }
+
start := time.Now()
proto := state.Proto()
@@ -46,7 +56,6 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
if err != nil {
return nil, err
}
-
// Set buffer size correctly for this client.
conn.UDPSize = uint16(state.Size())
if conn.UDPSize < 512 {
diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go
index 153c5ab38..213b30f8b 100644
--- a/plugin/forward/forward.go
+++ b/plugin/forward/forward.go
@@ -120,7 +120,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
if err != nil {
// Kick off health check to see if *our* upstream is broken.
- if f.maxfails != 0 {
+ if f.maxfails != 0 && err != errStopped {
proxy.Healthcheck()
}
@@ -186,6 +186,7 @@ var (
errNoHealthy = errors.New("no healthy proxies")
errNoForward = errors.New("no forwarder defined")
errCachedClosed = errors.New("cached connection was closed by peer")
+ errStopped = errors.New("proxy has been stopped")
)
// policy tells forward what policy for selecting upstream it uses.
diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go
index 3271e7dd9..8454b296d 100644
--- a/plugin/forward/proxy.go
+++ b/plugin/forward/proxy.go
@@ -24,6 +24,9 @@ type Proxy struct {
fails uint32
avgRtt int64
+
+ state uint32
+ inProgress int32
}
// NewProxy returns a new proxy.
@@ -79,15 +82,26 @@ func (p *Proxy) Down(maxfails uint32) bool {
return fails > maxfails
}
-// close stops the health checking goroutine.
+// close stops the health checking goroutine and connection manager.
func (p *Proxy) close() {
- p.probe.Stop()
- p.transport.Stop()
+ if atomic.CompareAndSwapUint32(&p.state, running, stopping) {
+ p.probe.Stop()
+ }
+ if atomic.LoadInt32(&p.inProgress) == 0 {
+ p.checkStopTransport()
+ }
}
// start starts the proxy's healthchecking.
func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
+// checkStopTransport checks if stop was requested and stops connection manager
+func (p *Proxy) checkStopTransport() {
+ if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) {
+ p.transport.Stop()
+ }
+}
+
const (
dialTimeout = 4 * time.Second
timeout = 2 * time.Second
@@ -95,3 +109,9 @@ const (
minTimeout = 10 * time.Millisecond
hcDuration = 500 * time.Millisecond
)
+
+const (
+ running = iota
+ stopping
+ stopped
+)
diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go
new file mode 100644
index 000000000..8c53f3150
--- /dev/null
+++ b/plugin/forward/proxy_test.go
@@ -0,0 +1,65 @@
+package forward
+
+import (
+ "context"
+ "sync"
+ "testing"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/test"
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+func TestProxyClose(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ req := new(dns.Msg)
+ req.SetQuestion("example.org.", dns.TypeA)
+ state := request.Request{W: &test.ResponseWriter{}, Req: req}
+ ctx := context.TODO()
+
+ repeatCnt := 1000
+ for repeatCnt > 0 {
+ repeatCnt--
+ p := NewProxy(s.Addr, nil /* no TLS */)
+ p.start(hcDuration)
+
+ var wg sync.WaitGroup
+ wg.Add(5)
+ go func() {
+ p.connect(ctx, state, false, false)
+ wg.Done()
+ }()
+ go func() {
+ p.connect(ctx, state, true, false)
+ wg.Done()
+ }()
+ go func() {
+ p.close()
+ wg.Done()
+ }()
+ go func() {
+ p.connect(ctx, state, false, false)
+ wg.Done()
+ }()
+ go func() {
+ p.connect(ctx, state, true, false)
+ wg.Done()
+ }()
+ wg.Wait()
+
+ if p.inProgress != 0 {
+ t.Errorf("unexpected query in progress")
+ }
+ if p.state != stopped {
+ t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state)
+ }
+ }
+}