diff options
-rw-r--r-- | plugin/forward/health.go | 5 | ||||
-rw-r--r-- | plugin/forward/health_test.go | 40 | ||||
-rw-r--r-- | plugin/forward/setup.go | 4 |
3 files changed, 49 insertions, 0 deletions
diff --git a/plugin/forward/health.go b/plugin/forward/health.go index fcd3df200..f4cfab834 100644 --- a/plugin/forward/health.go +++ b/plugin/forward/health.go @@ -16,6 +16,7 @@ type HealthChecker interface { SetTLSConfig(*tls.Config) SetRecursionDesired(bool) GetRecursionDesired() bool + SetTCPTransport() } // dnsHc is a health checker for a DNS endpoint (DNS, and DoT). @@ -57,6 +58,10 @@ func (h *dnsHc) GetRecursionDesired() bool { return h.recursionDesired } +func (h *dnsHc) SetTCPTransport() { + h.c.Net = "tcp" +} + // For HC we send to . IN NS +[no]rec message to the upstream. Dial timeouts and empty // replies are considered fails, basically anything else constitutes a healthy upstream. diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go index 88a96e803..2d65f4353 100644 --- a/plugin/forward/health_test.go +++ b/plugin/forward/health_test.go @@ -52,6 +52,46 @@ func TestHealth(t *testing.T) { } } +func TestHealthTCP(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == "." && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + p.health.SetTCPTransport() + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{TCP: true}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + func TestHealthNoRecursion(t *testing.T) { hcReadTimeout = 10 * time.Millisecond readTimeout = 10 * time.Millisecond diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 8bd1e1ff4..010dfa754 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -144,6 +144,10 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { } f.proxies[i].SetExpire(f.expire) f.proxies[i].health.SetRecursionDesired(f.opts.hcRecursionDesired) + // when TLS is used, checks are set to tcp-tls + if f.opts.forceTCP && transports[i] != transport.TLS { + f.proxies[i].health.SetTCPTransport() + } } return f, nil |