aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--plugin/forward/connect.go4
-rw-r--r--plugin/forward/forward_test.go4
-rw-r--r--plugin/forward/health.go47
-rw-r--r--plugin/forward/health_test.go10
-rw-r--r--plugin/forward/lookup.go2
-rw-r--r--plugin/forward/lookup_test.go2
-rw-r--r--plugin/forward/persistent.go2
-rw-r--r--plugin/forward/persistent_test.go8
-rw-r--r--plugin/forward/proxy.go58
-rw-r--r--plugin/forward/proxy_test.go4
-rw-r--r--plugin/forward/setup.go2
-rw-r--r--plugin/forward/setup_test.go4
-rw-r--r--plugin/forward/truncated_test.go6
13 files changed, 77 insertions, 76 deletions
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
index 4a0a7141e..439caf932 100644
--- a/plugin/forward/connect.go
+++ b/plugin/forward/connect.go
@@ -91,7 +91,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
proto = state.Proto()
}
- conn, cached, err := p.Dial(proto)
+ conn, cached, err := p.transport.Dial(proto)
if err != nil {
return nil, err
}
@@ -125,7 +125,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
p.updateRtt(time.Since(reqTime))
- p.Yield(conn)
+ p.transport.Yield(conn)
rc, ok := dns.RcodeToString[ret.Rcode]
if !ok {
diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go
index 96f5fa0ce..82844f811 100644
--- a/plugin/forward/forward_test.go
+++ b/plugin/forward/forward_test.go
@@ -19,7 +19,7 @@ func TestForward(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* not TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@@ -51,7 +51,7 @@ func TestForwardRefused(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.SetProxy(p)
defer f.Close()
diff --git a/plugin/forward/health.go b/plugin/forward/health.go
index 03322e92e..4d3278f6d 100644
--- a/plugin/forward/health.go
+++ b/plugin/forward/health.go
@@ -1,17 +1,48 @@
package forward
import (
+ "crypto/tls"
"sync/atomic"
+ "time"
"github.com/miekg/dns"
)
+// HealthChecker checks the upstream health.
+type HealthChecker interface {
+ Check(*Proxy) error
+ SetTLSConfig(*tls.Config)
+}
+
+// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
+type dnsHc struct{ c *dns.Client }
+
+// NewHealthChecker returns a new HealthChecker based on protocol.
+func NewHealthChecker(protocol int) HealthChecker {
+ switch protocol {
+ case DNS, TLS:
+ c := new(dns.Client)
+ c.Net = "udp"
+ c.ReadTimeout = 1 * time.Second
+ c.WriteTimeout = 1 * time.Second
+
+ return &dnsHc{c: c}
+ }
+
+ return nil
+}
+
+func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
+ h.c.Net = "tcp-tls"
+ h.c.TLSConfig = cfg
+}
+
// For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty
// replies are considered fails, basically anything else constitutes a healthy upstream.
// Check is used as the up.Func in the up.Probe.
-func (p *Proxy) Check() error {
- err := p.send()
+func (h *dnsHc) Check(p *Proxy) error {
+ err := h.send(p.addr)
if err != nil {
HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
atomic.AddUint32(&p.fails, 1)
@@ -22,14 +53,14 @@ func (p *Proxy) Check() error {
return nil
}
-func (p *Proxy) send() error {
- hcping := new(dns.Msg)
- hcping.SetQuestion(".", dns.TypeNS)
+func (h *dnsHc) send(addr string) error {
+ ping := new(dns.Msg)
+ ping.SetQuestion(".", dns.TypeNS)
- m, _, err := p.client.Exchange(hcping, p.addr)
- // If we got a header, we're alright, basically only care about I/O errors 'n stuff
+ m, _, err := h.c.Exchange(ping, addr)
+ // If we got a header, we're alright, basically only care about I/O errors 'n stuff.
if err != nil && m != nil {
- // Silly check, something sane came back
+ // Silly check, something sane came back.
if m.Response || m.Opcode == dns.OpcodeQuery {
err = nil
}
diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go
index 0588f1454..75d57f285 100644
--- a/plugin/forward/health_test.go
+++ b/plugin/forward/health_test.go
@@ -25,7 +25,7 @@ func TestHealth(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@@ -65,7 +65,7 @@ func TestHealthTimeout(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@@ -109,7 +109,7 @@ func TestHealthFailTwice(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@@ -132,7 +132,7 @@ func TestHealthMaxFails(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.maxfails = 2
f.SetProxy(p)
@@ -163,7 +163,7 @@ func TestHealthNoMaxFails(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.maxfails = 0
f.SetProxy(p)
diff --git a/plugin/forward/lookup.go b/plugin/forward/lookup.go
index 96eceab84..94114647c 100644
--- a/plugin/forward/lookup.go
+++ b/plugin/forward/lookup.go
@@ -81,7 +81,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M
func NewLookup(addr []string) *Forward {
f := New()
for i := range addr {
- p := NewProxy(addr[i], nil)
+ p := NewProxy(addr[i], DNS)
f.SetProxy(p)
}
return f
diff --git a/plugin/forward/lookup_test.go b/plugin/forward/lookup_test.go
index e37a0c5d7..1968ef979 100644
--- a/plugin/forward/lookup_test.go
+++ b/plugin/forward/lookup_test.go
@@ -19,7 +19,7 @@ func TestLookup(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.SetProxy(p)
defer f.Close()
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go
index 52bd24918..4da1514fe 100644
--- a/plugin/forward/persistent.go
+++ b/plugin/forward/persistent.go
@@ -29,7 +29,7 @@ type transport struct {
stop chan bool
}
-func newTransport(addr string, tlsConfig *tls.Config) *transport {
+func newTransport(addr string) *transport {
t := &transport{
avgDialTime: int64(defaultDialTimeout / 2),
conns: make(map[string][]*persistConn),
diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go
index e046cf4de..271a80c0b 100644
--- a/plugin/forward/persistent_test.go
+++ b/plugin/forward/persistent_test.go
@@ -17,7 +17,7 @@ func TestCached(t *testing.T) {
})
defer s.Close()
- tr := newTransport(s.Addr, nil /* no TLS */)
+ tr := newTransport(s.Addr)
tr.Start()
defer tr.Stop()
@@ -56,7 +56,7 @@ func TestCleanupByTimer(t *testing.T) {
})
defer s.Close()
- tr := newTransport(s.Addr, nil /* no TLS */)
+ tr := newTransport(s.Addr)
tr.SetExpire(100 * time.Millisecond)
tr.Start()
defer tr.Stop()
@@ -90,7 +90,7 @@ func TestPartialCleanup(t *testing.T) {
})
defer s.Close()
- tr := newTransport(s.Addr, nil /* no TLS */)
+ tr := newTransport(s.Addr)
tr.SetExpire(100 * time.Millisecond)
tr.Start()
defer tr.Stop()
@@ -138,7 +138,7 @@ func TestCleanupAll(t *testing.T) {
})
defer s.Close()
- tr := newTransport(s.Addr, nil /* no TLS */)
+ tr := newTransport(s.Addr)
c1, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout)
c2, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout)
diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go
index 91d7c38b1..ac74bf0f8 100644
--- a/plugin/forward/proxy.go
+++ b/plugin/forward/proxy.go
@@ -7,8 +7,6 @@ import (
"time"
"github.com/coredns/coredns/plugin/pkg/up"
-
- "github.com/miekg/dns"
)
// Proxy defines an upstream host.
@@ -16,69 +14,46 @@ type Proxy struct {
avgRtt int64
fails uint32
- addr string
- client *dns.Client
+ addr string
// Connection caching
expire time.Duration
transport *transport
// health checking
- probe *up.Probe
+ probe *up.Probe
+ health HealthChecker
}
// NewProxy returns a new proxy.
-func NewProxy(addr string, tlsConfig *tls.Config) *Proxy {
+func NewProxy(addr string, protocol int) *Proxy {
p := &Proxy{
addr: addr,
fails: 0,
probe: up.New(),
- transport: newTransport(addr, tlsConfig),
+ transport: newTransport(addr),
avgRtt: int64(maxTimeout / 2),
}
- p.client = dnsClient(tlsConfig)
+ p.health = NewHealthChecker(protocol)
runtime.SetFinalizer(p, (*Proxy).finalizer)
return p
}
-// Addr returns the address to forward to.
-func (p *Proxy) Addr() (addr string) { return p.addr }
-
-// dnsClient returns a client used for health checking.
-func dnsClient(tlsConfig *tls.Config) *dns.Client {
- c := new(dns.Client)
- c.Net = "udp"
- // TODO(miek): this should be half of hcDuration?
- c.ReadTimeout = 1 * time.Second
- c.WriteTimeout = 1 * time.Second
-
- if tlsConfig != nil {
- c.Net = "tcp-tls"
- c.TLSConfig = tlsConfig
- }
- return c
-}
-
// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client.
func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
p.transport.SetTLSConfig(cfg)
- p.client = dnsClient(cfg)
+ p.health.SetTLSConfig(cfg)
}
-// IsTLS returns true if proxy uses tls.
-func (p *Proxy) IsTLS() bool { return p.transport.tlsConfig != nil }
-
// SetExpire sets the expire duration in the lower p.transport.
func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
-// Dial connects to the host in p with the configured transport.
-func (p *Proxy) Dial(proto string) (*dns.Conn, bool, error) { return p.transport.Dial(proto) }
-
-// Yield returns the connection to the pool.
-func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }
-
// Healthcheck kicks of a round of health checks for this proxy.
-func (p *Proxy) Healthcheck() { p.probe.Do(p.Check) }
+func (p *Proxy) Healthcheck() {
+ p.probe.Do(func() error {
+ return p.health.Check(p)
+ })
+}
// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
func (p *Proxy) Down(maxfails uint32) bool {
@@ -91,13 +66,8 @@ func (p *Proxy) Down(maxfails uint32) bool {
}
// close stops the health checking goroutine.
-func (p *Proxy) close() {
- p.probe.Stop()
-}
-
-func (p *Proxy) finalizer() {
- p.transport.Stop()
-}
+func (p *Proxy) close() { p.probe.Stop() }
+func (p *Proxy) finalizer() { p.transport.Stop() }
// start starts the proxy's healthchecking.
func (p *Proxy) start(duration time.Duration) {
diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go
index d68d1def2..d7af25aa0 100644
--- a/plugin/forward/proxy_test.go
+++ b/plugin/forward/proxy_test.go
@@ -26,7 +26,7 @@ func TestProxyClose(t *testing.T) {
ctx := context.TODO()
for i := 0; i < 100; i++ {
- p := NewProxy(s.Addr, nil)
+ p := NewProxy(s.Addr, DNS)
p.start(hcInterval)
go func() { p.Connect(ctx, state, options{}) }()
@@ -95,7 +95,7 @@ func TestProxyTLSFail(t *testing.T) {
}
func TestProtocolSelection(t *testing.T) {
- p := NewProxy("bad_address", nil)
+ p := NewProxy("bad_address", DNS)
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}
diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go
index 152ba36c5..ee48fdaf6 100644
--- a/plugin/forward/setup.go
+++ b/plugin/forward/setup.go
@@ -124,7 +124,7 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
// We can't set tlsConfig here, because we haven't parsed it yet.
// We set it below at the end of parseBlock, use nil now.
- p := NewProxy(h, nil /* no TLS */)
+ p := NewProxy(h, protocols[i])
f.proxies = append(f.proxies, p)
}
diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go
index a8140d410..c72cd1106 100644
--- a/plugin/forward/setup_test.go
+++ b/plugin/forward/setup_test.go
@@ -113,8 +113,8 @@ func TestSetupTLS(t *testing.T) {
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName)
}
- if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].client.TLSConfig.ServerName {
- t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].client.TLSConfig.ServerName)
+ if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName {
+ t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName)
}
}
}
diff --git a/plugin/forward/truncated_test.go b/plugin/forward/truncated_test.go
index 1c9e92a07..b7ff47c14 100644
--- a/plugin/forward/truncated_test.go
+++ b/plugin/forward/truncated_test.go
@@ -34,7 +34,7 @@ func TestLookupTruncated(t *testing.T) {
})
defer s.Close()
- p := NewProxy(s.Addr, nil /* no TLS */)
+ p := NewProxy(s.Addr, DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@@ -88,9 +88,9 @@ func TestForwardTruncated(t *testing.T) {
f := New()
- p1 := NewProxy(s.Addr, nil /* no TLS */)
+ p1 := NewProxy(s.Addr, DNS)
f.SetProxy(p1)
- p2 := NewProxy(s.Addr, nil /* no TLS */)
+ p2 := NewProxy(s.Addr, DNS)
f.SetProxy(p2)
defer f.Close()