aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--plugin/pkg/proxy/health.go2
-rw-r--r--plugin/pkg/proxy/proxy.go10
-rw-r--r--plugin/pkg/proxy/proxy_test.go31
3 files changed, 42 insertions, 1 deletions
diff --git a/plugin/pkg/proxy/health.go b/plugin/pkg/proxy/health.go
index e87104a13..a7e99560d 100644
--- a/plugin/pkg/proxy/health.go
+++ b/plugin/pkg/proxy/health.go
@@ -105,7 +105,7 @@ func (h *dnsHc) Check(p *Proxy) error {
err := h.send(p.addr)
if err != nil {
HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
- atomic.AddUint32(&p.fails, 1)
+ p.incrementFails()
return err
}
diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go
index be521fe05..414c34240 100644
--- a/plugin/pkg/proxy/proxy.go
+++ b/plugin/pkg/proxy/proxy.go
@@ -93,6 +93,16 @@ func (p *Proxy) SetReadTimeout(duration time.Duration) {
p.readTimeout = duration
}
+// incrementFails increments the number of fails safely.
+func (p *Proxy) incrementFails() {
+ curVal := atomic.LoadUint32(&p.fails)
+ if curVal > curVal+1 {
+ // overflow occurred, do not update the counter again
+ return
+ }
+ atomic.AddUint32(&p.fails, 1)
+}
+
const (
maxTimeout = 2 * time.Second
)
diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go
index 274e9679d..17125ea68 100644
--- a/plugin/pkg/proxy/proxy_test.go
+++ b/plugin/pkg/proxy/proxy_test.go
@@ -3,6 +3,7 @@ package proxy
import (
"context"
"crypto/tls"
+ "math"
"testing"
"time"
@@ -97,3 +98,33 @@ func TestProtocolSelection(t *testing.T) {
}
}
}
+
+func TestProxyIncrementFails(t *testing.T) {
+ var testCases = []struct {
+ name string
+ fails uint32
+ expectFails uint32
+ }{
+ {
+ name: "increment fails counter overflows",
+ fails: math.MaxUint32,
+ expectFails: math.MaxUint32,
+ },
+ {
+ name: "increment fails counter",
+ fails: 0,
+ expectFails: 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ p := NewProxy("bad_address", transport.DNS)
+ p.fails = tc.fails
+ p.incrementFails()
+ if p.fails != tc.expectFails {
+ t.Errorf("Expected fails to be %d, got %d", tc.expectFails, p.fails)
+ }
+ })
+ }
+}