aboutsummaryrefslogtreecommitdiff
path: root/plugin
diff options
context:
space:
mode:
authorGravatar Miek Gieben <miek@miek.nl> 2018-02-05 22:00:47 +0000
committerGravatar GitHub <noreply@github.com> 2018-02-05 22:00:47 +0000
commit5b844b5017f004fffa83157041e8ffd3ac085c92 (patch)
treecbf86bb06cd42f720037a0e473ce2d1cba4036af /plugin
parentfb1cafe5fa54935361a5cc9a7e3308a738225126 (diff)
downloadcoredns-5b844b5017f004fffa83157041e8ffd3ac085c92.tar.gz
coredns-5b844b5017f004fffa83157041e8ffd3ac085c92.tar.zst
coredns-5b844b5017f004fffa83157041e8ffd3ac085c92.zip
plugin/forward: add it (#1447)
* plugin/forward: add it This moves coredns/forward into CoreDNS. Fixes as a few bugs, adds a policy option and more tests to the plugin. Update the documentation, test IPv6 address and add persistent tests. * Always use random policy when spraying * include scrub fix here as well * use correct var name * Code review * go vet * Move logging to metrcs * Small readme updates * Fix readme
Diffstat (limited to 'plugin')
-rw-r--r--plugin/forward/README.md156
-rw-r--r--plugin/forward/connect.go66
-rw-r--r--plugin/forward/forward.go154
-rw-r--r--plugin/forward/forward_test.go42
-rw-r--r--plugin/forward/health.go67
-rw-r--r--plugin/forward/host.go44
-rw-r--r--plugin/forward/lookup.go78
-rw-r--r--plugin/forward/lookup_test.go41
-rw-r--r--plugin/forward/metrics.go52
-rw-r--r--plugin/forward/persistent.go148
-rw-r--r--plugin/forward/persistent_test.go44
-rw-r--r--plugin/forward/policy.go55
-rw-r--r--plugin/forward/protocol.go30
-rw-r--r--plugin/forward/proxy.go77
-rw-r--r--plugin/forward/setup.go262
-rw-r--r--plugin/forward/setup_policy_test.go46
-rw-r--r--plugin/forward/setup_test.go68
17 files changed, 1430 insertions, 0 deletions
diff --git a/plugin/forward/README.md b/plugin/forward/README.md
new file mode 100644
index 000000000..bbef305db
--- /dev/null
+++ b/plugin/forward/README.md
@@ -0,0 +1,156 @@
+# forward
+
+## Name
+
+*forward* facilitates proxying DNS messages to upstream resolvers.
+
+## Description
+
+The *forward* plugin is generally faster (~30+%) than *proxy* as it re-uses already opened sockets
+to the upstreams. It supports UDP, TCP and DNS-over-TLS and uses inband health checking that is
+enabled by default.
+When *all* upstreams are down it assumes health checking as a mechanism has failed and will try to
+connect to a random upstream (which may or may not work).
+
+## Syntax
+
+In its most basic form, a simple forwarder uses this syntax:
+
+~~~
+forward FROM TO...
+~~~
+
+* **FROM** is the base domain to match for the request to be forwarded.
+* **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify
+ a protocol, `tls://9.9.9.9` or `dns://` for plain DNS. The number of upstreams is limited to 15.
+
+The health checks are done every *0.5s*. After *two* failed checks the upstream is considered
+unhealthy. The health checks use a recursive DNS query (`. IN NS`) to get upstream health. Any
+response that is not an error (REFUSED, NOTIMPL, SERVFAIL, etc) is taken as a healthy upstream. The
+health check uses the same protocol as specified in the **TO**. On startup each upstream is marked
+unhealthy until it passes a health check. A 0 duration will disable any health checks.
+
+Multiple upstreams are randomized (default policy) on first use. When a healthy proxy returns an
+error during the exchange the next upstream in the list is tried.
+
+Extra knobs are available with an expanded syntax:
+
+~~~
+forward FROM TO... {
+ except IGNORED_NAMES...
+ force_tcp
+ health_check DURATION
+ expire DURATION
+ max_fails INTEGER
+ tls CERT KEY CA
+ tls_servername NAME
+ policy random|round_robin
+}
+~~~
+
+* **FROM** and **TO...** as above.
+* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding.
+ Requests that match none of these names will be passed through.
+* `force_tcp`, use TCP even when the request comes in over UDP.
+* `health_check`, use a different **DURATION** for health checking, the default duration is 0.5s.
+ A value of 0 disables the health checks completely.
+* `max_fails` is the number of subsequent failed health checks that are needed before considering
+ a backend to be down. If 0, the backend will never be marked as down. Default is 2.
+* `expire` **DURATION**, expire (cached) connections after this time, the default is 10s.
+* `tls` **CERT** **KEY** **CA** define the TLS properties for TLS; if you leave this out the
+ system's configuration will be used.
+* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
+ needs this to be set to `dns.quad9.net`.
+* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
+
+The upstream selection is done via random (default policy) selection. If the socket for this client
+isn't known *forward* will randomly choose one. If this turns out to be unhealthy, the next one is
+tried. If *all* hosts are down, we assume health checking is broken and select a *random* upstream to
+try.
+
+Also note the TLS config is "global" for the whole forwarding proxy; if you need a different
+`tls-name` for different upstreams you're out of luck.
+
+## Metrics
+
+If monitoring is enabled (via the *prometheus* directive) then the following metrics are exported:
+
+* `coredns_forward_request_duration_seconds{to}` - duration per upstream interaction.
+* `coredns_forward_request_count_total{to}` - query count per upstream.
+* `coredns_forward_response_rcode_total{to, rcode}` - count of RCODEs per upstream.
+* `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream.
+* `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy,
+ and we are randomly spraying to a target.
+* `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream.
+
+Where `to` is one of the upstream servers (**TO** from the config), and `rcode` is the DNS
+return code of the response received from that upstream (e.g. "NOERROR", "NXDOMAIN", or a
+numeric value for codes without a textual name).
+
+## Examples
+
+Proxy all requests within example.org. to a nameserver running on a different port:
+
+~~~ corefile
+example.org {
+ forward . 127.0.0.1:9005
+}
+~~~
+
+Load balance all requests between three resolvers, one of which has an IPv6 address.
+
+~~~ corefile
+. {
+ forward . 10.0.0.10:53 10.0.0.11:1053 [2003::1]:53
+}
+~~~
+
+Forward everything except requests to `example.org`
+
+~~~ corefile
+. {
+ forward . 10.0.0.10:1234 {
+ except example.org
+ }
+}
+~~~
+
+Proxy everything except `example.org` using the host's `resolv.conf`'s nameservers:
+
+~~~ corefile
+. {
+ forward . /etc/resolv.conf {
+ except example.org
+ }
+}
+~~~
+
+Forward to an IPv6 host:
+
+~~~ corefile
+. {
+ forward . [::1]:1053
+}
+~~~
+
+Proxy all requests to 9.9.9.9 using the DNS-over-TLS protocol, and cache every answer for up to 30
+seconds.
+
+~~~ corefile
+. {
+ forward . tls://9.9.9.9 {
+ tls_servername dns.quad9.net
+ health_check 5s
+ }
+ cache 30
+}
+~~~
+
+## Bugs
+
+The TLS config is global for the whole forwarding proxy; if you need a different `tls-name` for
+different upstreams you're out of luck.
+
+## Also See
+
+[RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS.
diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
new file mode 100644
index 000000000..cdad29ed1
--- /dev/null
+++ b/plugin/forward/connect.go
@@ -0,0 +1,66 @@
+// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
+// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
+// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
+// inband healthchecking.
+package forward
+
+import (
+ "strconv"
+ "time"
+
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
+ start := time.Now()
+
+ proto := state.Proto()
+ if forceTCP {
+ proto = "tcp"
+ }
+ if p.host.tlsConfig != nil {
+ proto = "tcp-tls"
+ }
+
+ conn, err := p.Dial(proto)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set buffer size correctly for this client.
+ conn.UDPSize = uint16(state.Size())
+ if conn.UDPSize < 512 {
+ conn.UDPSize = 512
+ }
+
+ conn.SetWriteDeadline(time.Now().Add(timeout))
+ if err := conn.WriteMsg(state.Req); err != nil {
+ conn.Close() // not giving it back
+ return nil, err
+ }
+
+ conn.SetReadDeadline(time.Now().Add(timeout))
+ ret, err := conn.ReadMsg()
+ if err != nil {
+ conn.Close() // not giving it back
+ return nil, err
+ }
+
+ p.Yield(conn)
+
+ if metric {
+ rc, ok := dns.RcodeToString[ret.Rcode]
+ if !ok {
+ rc = strconv.Itoa(ret.Rcode)
+ }
+
+ RequestCount.WithLabelValues(p.host.addr).Add(1)
+ RcodeCount.WithLabelValues(rc, p.host.addr).Add(1)
+ RequestDuration.WithLabelValues(p.host.addr).Observe(time.Since(start).Seconds())
+ }
+
+ return ret, nil
+}
diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go
new file mode 100644
index 000000000..35885008e
--- /dev/null
+++ b/plugin/forward/forward.go
@@ -0,0 +1,154 @@
+// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
+// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
+// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
+// inband healthchecking.
+package forward
+
+import (
+ "crypto/tls"
+ "errors"
+ "time"
+
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+ ot "github.com/opentracing/opentracing-go"
+ "golang.org/x/net/context"
+)
+
+// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list
+// of proxies each representing one upstream proxy.
+type Forward struct {
+ proxies []*Proxy
+ p Policy
+
+ from string
+ ignored []string
+
+ tlsConfig *tls.Config
+ tlsServerName string
+ maxfails uint32
+ expire time.Duration
+
+ forceTCP bool // also here for testing
+ hcInterval time.Duration // also here for testing
+
+ Next plugin.Handler
+}
+
+// New returns a new Forward.
+func New() *Forward {
+ f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: hcDuration, p: new(random)}
+ return f
+}
+
+// SetProxy appends p to the proxy list and starts healthchecking.
+func (f *Forward) SetProxy(p *Proxy) {
+ f.proxies = append(f.proxies, p)
+ go p.healthCheck()
+}
+
+// Len returns the number of configured proxies.
+func (f *Forward) Len() int { return len(f.proxies) }
+
+// Name implements plugin.Handler.
+func (f *Forward) Name() string { return "forward" }
+
+// ServeDNS implements plugin.Handler.
+func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+
+ state := request.Request{W: w, Req: r}
+ if !f.match(state) {
+ return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
+ }
+
+ fails := 0
+ var span, child ot.Span
+ span = ot.SpanFromContext(ctx)
+
+ for _, proxy := range f.list() {
+ if proxy.Down(f.maxfails) {
+ fails++
+ if fails < len(f.proxies) {
+ continue
+ }
+ // All upstream proxies are dead, assume healtcheck is completely broken and randomly
+ // select an upstream to connect to.
+ r := new(random)
+ proxy = r.List(f.proxies)[0]
+
+ HealthcheckBrokenCount.Add(1)
+ }
+
+ if span != nil {
+ child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context()))
+ ctx = ot.ContextWithSpan(ctx, child)
+ }
+
+ ret, err := proxy.connect(ctx, state, f.forceTCP, true)
+
+ if child != nil {
+ child.Finish()
+ }
+
+ if err != nil {
+ if fails < len(f.proxies) {
+ continue
+ }
+ break
+ }
+
+ ret.Compress = true
+ // When using force_tcp the upstream can send a message that is too big for
+ // the udp buffer, hence we need to truncate the message to at least make it
+ // fit the udp buffer.
+ ret, _ = state.Scrub(ret)
+
+ w.WriteMsg(ret)
+
+ return 0, nil
+ }
+
+ return dns.RcodeServerFailure, errNoHealthy
+}
+
+func (f *Forward) match(state request.Request) bool {
+ from := f.from
+
+ if !plugin.Name(from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) {
+ return false
+ }
+
+ return true
+}
+
+func (f *Forward) isAllowedDomain(name string) bool {
+ if dns.Name(name) == dns.Name(f.from) {
+ return true
+ }
+
+ for _, ignore := range f.ignored {
+ if plugin.Name(ignore).Matches(name) {
+ return false
+ }
+ }
+ return true
+}
+
+// list returns a set of proxies to be used for this client depending on the policy in f.
+func (f *Forward) list() []*Proxy { return f.p.List(f.proxies) }
+
+var (
+ errInvalidDomain = errors.New("invalid domain for proxy")
+ errNoHealthy = errors.New("no healthy proxies")
+ errNoForward = errors.New("no forwarder defined")
+)
+
+// policy tells forward what policy for selecting upstream it uses.
+type policy int
+
+const (
+ randomPolicy policy = iota
+ roundRobinPolicy
+)
diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go
new file mode 100644
index 000000000..d467a0efa
--- /dev/null
+++ b/plugin/forward/forward_test.go
@@ -0,0 +1,42 @@
+package forward
+
+import (
+ "testing"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/test"
+ "github.com/coredns/coredns/request"
+ "github.com/miekg/dns"
+)
+
+func TestForward(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr)
+ f := New()
+ f.SetProxy(p)
+ defer f.Close()
+
+ state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
+ state.Req.SetQuestion("example.org.", dns.TypeA)
+ resp, err := f.Forward(state)
+ if err != nil {
+ t.Fatal("Expected to receive reply, but didn't")
+ }
+ // expect answer section with A record in it
+ if len(resp.Answer) == 0 {
+ t.Fatalf("Expected to at least one RR in the answer section, got none: %s", resp)
+ }
+ if resp.Answer[0].Header().Rrtype != dns.TypeA {
+ t.Errorf("Expected RR to A, got: %d", resp.Answer[0].Header().Rrtype)
+ }
+ if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" {
+ t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String())
+ }
+}
diff --git a/plugin/forward/health.go b/plugin/forward/health.go
new file mode 100644
index 000000000..e277f30a6
--- /dev/null
+++ b/plugin/forward/health.go
@@ -0,0 +1,67 @@
+package forward
+
+import (
+ "log"
+ "sync/atomic"
+
+ "github.com/miekg/dns"
+)
+
+// For HC we send a ". IN NS" +norec message to the upstream. Dial timeouts and empty
+// replies are considered fails, basically anything else constitutes a healthy upstream.
+
+func (h *host) Check() {
+ h.Lock()
+
+ if h.checking {
+ h.Unlock()
+ return
+ }
+
+ h.checking = true
+ h.Unlock()
+
+ err := h.send()
+ if err != nil {
+ log.Printf("[INFO] healtheck of %s failed with %s", h.addr, err)
+
+ HealthcheckFailureCount.WithLabelValues(h.addr).Add(1)
+
+ atomic.AddUint32(&h.fails, 1)
+ } else {
+ atomic.StoreUint32(&h.fails, 0)
+ }
+
+ h.Lock()
+ h.checking = false
+ h.Unlock()
+
+ return
+}
+
+func (h *host) send() error {
+ hcping := new(dns.Msg)
+ hcping.SetQuestion(".", dns.TypeNS)
+ hcping.RecursionDesired = false
+
+ m, _, err := h.client.Exchange(hcping, h.addr)
+ // If we got a header, we're alright, basically only care about I/O errors 'n stuff
+ if err != nil && m != nil {
+ // Silly check, something sane came back
+ if m.Response || m.Opcode == dns.OpcodeQuery {
+ err = nil
+ }
+ }
+
+ return err
+}
+
+// down returns true if this host has more than maxfails fails.
+func (h *host) down(maxfails uint32) bool {
+ if maxfails == 0 {
+ return false
+ }
+
+ fails := atomic.LoadUint32(&h.fails)
+ return fails > maxfails
+}
diff --git a/plugin/forward/host.go b/plugin/forward/host.go
new file mode 100644
index 000000000..48d6c7d6e
--- /dev/null
+++ b/plugin/forward/host.go
@@ -0,0 +1,44 @@
+package forward
+
+import (
+ "crypto/tls"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+type host struct {
+ addr string
+ client *dns.Client
+
+ tlsConfig *tls.Config
+ expire time.Duration
+
+ fails uint32
+ sync.RWMutex
+ checking bool
+}
+
+// newHost returns a new host, the fails are set to 1, i.e.
+// the first healthcheck must succeed before we use this host.
+func newHost(addr string) *host {
+ return &host{addr: addr, fails: 1, expire: defaultExpire}
+}
+
+// SetClient sets and configures the dns.Client in host.
+func (h *host) SetClient() {
+ c := new(dns.Client)
+ c.Net = "udp"
+ c.ReadTimeout = 2 * time.Second
+ c.WriteTimeout = 2 * time.Second
+
+ if h.tlsConfig != nil {
+ c.Net = "tcp-tls"
+ c.TLSConfig = h.tlsConfig
+ }
+
+ h.client = c
+}
+
+const defaultExpire = 10 * time.Second
diff --git a/plugin/forward/lookup.go b/plugin/forward/lookup.go
new file mode 100644
index 000000000..47c4319cf
--- /dev/null
+++ b/plugin/forward/lookup.go
@@ -0,0 +1,78 @@
+// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
+// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
+// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
+// inband healthchecking.
+package forward
+
+import (
+ "crypto/tls"
+ "log"
+ "time"
+
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+// Forward forwards the request in state as-is. Unlike Lookup, it does not add EDNS0 options to the message.
+// Forward may be called with a nil f, an error is returned in that case.
+func (f *Forward) Forward(state request.Request) (*dns.Msg, error) {
+ if f == nil {
+ return nil, errNoForward
+ }
+
+ fails := 0
+ for _, proxy := range f.list() {
+ if proxy.Down(f.maxfails) {
+ fails++
+ if fails < len(f.proxies) {
+ continue
+ }
+ // All upstream proxies are dead, assume healtcheck is complete broken and randomly
+ // select an upstream to connect to.
+ proxy = f.list()[0]
+ log.Printf("[WARNING] All upstreams down, picking random one to connect to %s", proxy.host.addr)
+ }
+
+ ret, err := proxy.connect(context.Background(), state, f.forceTCP, true)
+ if err != nil {
+ log.Printf("[WARNING] Failed to connect to %s: %s", proxy.host.addr, err)
+ if fails < len(f.proxies) {
+ continue
+ }
+ break
+
+ }
+
+ return ret, nil
+ }
+ return nil, errNoHealthy
+}
+
+// Lookup will use name and type to forge a new message and will send that upstream. It will
+// set any EDNS0 options correctly so that downstream will be able to process the reply.
+// Lookup may be called with a nil f, an error is returned in that case.
+func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.Msg, error) {
+ if f == nil {
+ return nil, errNoForward
+ }
+
+ req := new(dns.Msg)
+ req.SetQuestion(name, typ)
+ state.SizeAndDo(req)
+
+ state2 := request.Request{W: state.W, Req: req}
+
+ return f.Forward(state2)
+}
+
+// NewLookup returns a Forward that can be used for plugins that need an upstream to resolve external names.
+func NewLookup(addr []string) *Forward {
+ f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: 2 * time.Second}
+ for i := range addr {
+ p := NewProxy(addr[i])
+ f.SetProxy(p)
+ }
+ return f
+}
diff --git a/plugin/forward/lookup_test.go b/plugin/forward/lookup_test.go
new file mode 100644
index 000000000..69c7a1949
--- /dev/null
+++ b/plugin/forward/lookup_test.go
@@ -0,0 +1,41 @@
+package forward
+
+import (
+ "testing"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/test"
+ "github.com/coredns/coredns/request"
+ "github.com/miekg/dns"
+)
+
+func TestLookup(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ p := NewProxy(s.Addr)
+ f := New()
+ f.SetProxy(p)
+ defer f.Close()
+
+ state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
+ resp, err := f.Lookup(state, "example.org.", dns.TypeA)
+ if err != nil {
+ t.Fatal("Expected to receive reply, but didn't")
+ }
+ // expect answer section with A record in it
+ if len(resp.Answer) == 0 {
+ t.Fatalf("Expected to at least one RR in the answer section, got none: %s", resp)
+ }
+ if resp.Answer[0].Header().Rrtype != dns.TypeA {
+ t.Errorf("Expected RR to A, got: %d", resp.Answer[0].Header().Rrtype)
+ }
+ if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" {
+ t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String())
+ }
+}
diff --git a/plugin/forward/metrics.go b/plugin/forward/metrics.go
new file mode 100644
index 000000000..1e72454e0
--- /dev/null
+++ b/plugin/forward/metrics.go
@@ -0,0 +1,52 @@
+package forward
+
+import (
+ "sync"
+
+ "github.com/coredns/coredns/plugin"
+
+ "github.com/prometheus/client_golang/prometheus"
+)
+
+// Variables declared for monitoring.
+var (
+ RequestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "forward",
+ Name: "request_count_total",
+ Help: "Counter of requests made per upstream.",
+ }, []string{"to"})
+ RcodeCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "forward",
+ Name: "response_rcode_count_total",
+ Help: "Counter of requests made per upstream.",
+ }, []string{"rcode", "to"})
+ RequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "forward",
+ Name: "request_duration_seconds",
+ Buckets: plugin.TimeBuckets,
+ Help: "Histogram of the time each request took.",
+ }, []string{"to"})
+ HealthcheckFailureCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "forward",
+ Name: "healthcheck_failure_count_total",
+ Help: "Counter of the number of failed healtchecks.",
+ }, []string{"to"})
+ HealthcheckBrokenCount = prometheus.NewCounter(prometheus.CounterOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "forward",
+ Name: "healthcheck_broken_count_total",
+ Help: "Counter of the number of complete failures of the healtchecks.",
+ })
+ SocketGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
+ Namespace: plugin.Namespace,
+ Subsystem: "forward",
+ Name: "socket_count_total",
+ Help: "Guage of open sockets per upstream.",
+ }, []string{"to"})
+)
+
+var once sync.Once
diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go
new file mode 100644
index 000000000..6a7c4464e
--- /dev/null
+++ b/plugin/forward/persistent.go
@@ -0,0 +1,148 @@
+package forward
+
+import (
+ "net"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// a persistConn holds the dns.Conn and the last used time.
+type persistConn struct {
+ c *dns.Conn
+ used time.Time
+}
+
+// connErr is used to communicate with the connection manager.
+type connErr struct {
+ c *dns.Conn
+ err error
+}
+
+// transport holds the persistent cache.
+type transport struct {
+ conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
+ host *host
+
+ dial chan string
+ yield chan connErr
+ ret chan connErr
+
+ // Aid in testing, gets length of cache in data-race safe manner.
+ lenc chan bool
+ lencOut chan int
+
+ stop chan bool
+}
+
+func newTransport(h *host) *transport {
+ t := &transport{
+ conns: make(map[string][]*persistConn),
+ host: h,
+ dial: make(chan string),
+ yield: make(chan connErr),
+ ret: make(chan connErr),
+ stop: make(chan bool),
+ lenc: make(chan bool),
+ lencOut: make(chan int),
+ }
+ go t.connManager()
+ return t
+}
+
+// len returns the number of connections, used for metrics. Can only be safely
+// used inside connManager() because of races.
+func (t *transport) len() int {
+ l := 0
+ for _, conns := range t.conns {
+ l += len(conns)
+ }
+ return l
+}
+
+// Len returns the number of connections in the cache.
+func (t *transport) Len() int {
+ t.lenc <- true
+ l := <-t.lencOut
+ return l
+}
+
+// connManager manages the persistent connection cache for UDP and TCP.
+func (t *transport) connManager() {
+
+Wait:
+ for {
+ select {
+ case proto := <-t.dial:
+ // Yes O(n), shouldn't put millions in here. We walk all connection until we find the first
+ // one that is usuable.
+ i := 0
+ for i = 0; i < len(t.conns[proto]); i++ {
+ pc := t.conns[proto][i]
+ if time.Since(pc.used) < t.host.expire {
+ // Found one, remove from pool and return this conn.
+ t.conns[proto] = t.conns[proto][i+1:]
+ t.ret <- connErr{pc.c, nil}
+ continue Wait
+ }
+ // This conn has expired. Close it.
+ pc.c.Close()
+ }
+
+ // Not conns were found. Connect to the upstream to create one.
+ t.conns[proto] = t.conns[proto][i:]
+ SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
+
+ go func() {
+ if proto != "tcp-tls" {
+ c, err := dns.DialTimeout(proto, t.host.addr, dialTimeout)
+ t.ret <- connErr{c, err}
+ return
+ }
+
+ c, err := dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout)
+ t.ret <- connErr{c, err}
+ }()
+
+ case conn := <-t.yield:
+
+ SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1))
+
+ // no proto here, infer from config and conn
+ if _, ok := conn.c.Conn.(*net.UDPConn); ok {
+ t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()})
+ continue Wait
+ }
+
+ if t.host.tlsConfig == nil {
+ t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
+ continue Wait
+ }
+
+ t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()})
+
+ case <-t.stop:
+ return
+
+ case <-t.lenc:
+ l := 0
+ for _, conns := range t.conns {
+ l += len(conns)
+ }
+ t.lencOut <- l
+ }
+ }
+}
+
+func (t *transport) Dial(proto string) (*dns.Conn, error) {
+ t.dial <- proto
+ c := <-t.ret
+ return c.c, c.err
+}
+
+func (t *transport) Yield(c *dns.Conn) {
+ t.yield <- connErr{c, nil}
+}
+
+// Stop stops the transports.
+func (t *transport) Stop() { t.stop <- true }
diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go
new file mode 100644
index 000000000..5674658e6
--- /dev/null
+++ b/plugin/forward/persistent_test.go
@@ -0,0 +1,44 @@
+package forward
+
+import (
+ "testing"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+
+ "github.com/miekg/dns"
+)
+
+func TestPersistent(t *testing.T) {
+ s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+ ret := new(dns.Msg)
+ ret.SetReply(r)
+ w.WriteMsg(ret)
+ })
+ defer s.Close()
+
+ h := newHost(s.Addr)
+ tr := newTransport(h)
+ defer tr.Stop()
+
+ c1, _ := tr.Dial("udp")
+ c2, _ := tr.Dial("udp")
+ c3, _ := tr.Dial("udp")
+
+ tr.Yield(c1)
+ tr.Yield(c2)
+ tr.Yield(c3)
+
+ if x := tr.Len(); x != 3 {
+ t.Errorf("Expected cache size to be 3, got %d", x)
+ }
+
+ tr.Dial("udp")
+ if x := tr.Len(); x != 2 {
+ t.Errorf("Expected cache size to be 2, got %d", x)
+ }
+
+ tr.Dial("udp")
+ if x := tr.Len(); x != 1 {
+ t.Errorf("Expected cache size to be 2, got %d", x)
+ }
+}
diff --git a/plugin/forward/policy.go b/plugin/forward/policy.go
new file mode 100644
index 000000000..f39a14105
--- /dev/null
+++ b/plugin/forward/policy.go
@@ -0,0 +1,55 @@
+package forward
+
+import (
+ "math/rand"
+ "sync/atomic"
+)
+
+// Policy defines a policy we use for selecting upstreams.
+type Policy interface {
+ List([]*Proxy) []*Proxy
+ String() string
+}
+
+// random is a policy that implements random upstream selection.
+type random struct{}
+
+func (r *random) String() string { return "random" }
+
+func (r *random) List(p []*Proxy) []*Proxy {
+ switch len(p) {
+ case 1:
+ return p
+ case 2:
+ if rand.Int()%2 == 0 {
+ return []*Proxy{p[1], p[0]} // swap
+ }
+ return p
+ }
+
+ perms := rand.Perm(len(p))
+ rnd := make([]*Proxy, len(p))
+
+ for i, p1 := range perms {
+ rnd[i] = p[p1]
+ }
+ return rnd
+}
+
+// roundRobin is a policy that selects hosts based on round robin ordering.
+type roundRobin struct {
+ robin uint32
+}
+
+func (r *roundRobin) String() string { return "round_robin" }
+
+func (r *roundRobin) List(p []*Proxy) []*Proxy {
+ poolLen := uint32(len(p))
+ i := atomic.AddUint32(&r.robin, 1) % poolLen
+
+ robin := []*Proxy{p[i]}
+ robin = append(robin, p[:i]...)
+ robin = append(robin, p[i+1:]...)
+
+ return robin
+}
diff --git a/plugin/forward/protocol.go b/plugin/forward/protocol.go
new file mode 100644
index 000000000..338b60116
--- /dev/null
+++ b/plugin/forward/protocol.go
@@ -0,0 +1,30 @@
+package forward
+
+// Copied from coredns/core/dnsserver/address.go
+
+import (
+ "strings"
+)
+
+// protocol returns the protocol of the string s. The second return value is s
+// with the prefix chopped off.
+func protocol(s string) (int, string) {
+ switch {
+ case strings.HasPrefix(s, _tls+"://"):
+ return TLS, s[len(_tls)+3:]
+ case strings.HasPrefix(s, _dns+"://"):
+ return DNS, s[len(_dns)+3:]
+ }
+ return DNS, s
+}
+
+// Supported protocols.
+const (
+ DNS = iota + 1
+ TLS
+)
+
+const (
+ _dns = "dns"
+ _tls = "tls"
+)
diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go
new file mode 100644
index 000000000..c89490374
--- /dev/null
+++ b/plugin/forward/proxy.go
@@ -0,0 +1,77 @@
+package forward
+
+import (
+ "crypto/tls"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// Proxy defines an upstream host.
+type Proxy struct {
+ host *host
+
+ transport *transport
+
+ // copied from Forward.
+ hcInterval time.Duration
+ forceTCP bool
+
+ stop chan bool
+
+ sync.RWMutex
+}
+
+// NewProxy returns a new proxy.
+func NewProxy(addr string) *Proxy {
+ host := newHost(addr)
+
+ p := &Proxy{
+ host: host,
+ hcInterval: hcDuration,
+ stop: make(chan bool),
+ transport: newTransport(host),
+ }
+ return p
+}
+
+// SetTLSConfig sets the TLS config in the lower p.host.
+func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.host.tlsConfig = cfg }
+
+// SetExpire sets the expire duration in the lower p.host.
+func (p *Proxy) SetExpire(expire time.Duration) { p.host.expire = expire }
+
+func (p *Proxy) close() { p.stop <- true }
+
+// Dial connects to the host in p with the configured transport.
+func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }
+
+// Yield returns the connection to the pool.
+func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }
+
+// Down returns true if this proxy is down, i.e. has more than maxfails failed healthchecks.
+func (p *Proxy) Down(maxfails uint32) bool { return p.host.down(maxfails) }
+
+func (p *Proxy) healthCheck() {
+
+ // stop channel
+ p.host.SetClient()
+
+ p.host.Check()
+ tick := time.NewTicker(p.hcInterval)
+ for {
+ select {
+ case <-tick.C:
+ p.host.Check()
+ case <-p.stop:
+ return
+ }
+ }
+}
+
+const (
+ dialTimeout = 4 * time.Second
+ timeout = 2 * time.Second
+ hcDuration = 500 * time.Millisecond
+)
diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go
new file mode 100644
index 000000000..bed20f0c7
--- /dev/null
+++ b/plugin/forward/setup.go
@@ -0,0 +1,262 @@
+package forward
+
+import (
+ "fmt"
+ "net"
+ "strconv"
+ "time"
+
+ "github.com/coredns/coredns/core/dnsserver"
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/metrics"
+ "github.com/coredns/coredns/plugin/pkg/dnsutil"
+ pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
+
+ "github.com/mholt/caddy"
+)
+
// init registers the forward plugin with caddy for the "dns" server type.
func init() {
	caddy.RegisterPlugin("forward", caddy.Plugin{
		ServerType: "dns",
		Action: setup,
	})
}
+
+func setup(c *caddy.Controller) error {
+ f, err := parseForward(c)
+ if err != nil {
+ return plugin.Error("foward", err)
+ }
+ if f.Len() > max {
+ return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len()))
+ }
+
+ dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
+ f.Next = next
+ return f
+ })
+
+ c.OnStartup(func() error {
+ once.Do(func() {
+ m := dnsserver.GetConfig(c).Handler("prometheus")
+ if m == nil {
+ return
+ }
+ if x, ok := m.(*metrics.Metrics); ok {
+ x.MustRegister(RequestCount)
+ x.MustRegister(RcodeCount)
+ x.MustRegister(RequestDuration)
+ x.MustRegister(HealthcheckFailureCount)
+ x.MustRegister(SocketGauge)
+ }
+ })
+ return f.OnStartup()
+ })
+
+ c.OnShutdown(func() error {
+ return f.OnShutdown()
+ })
+
+ return nil
+}
+
+// OnStartup starts a goroutines for all proxies.
+func (f *Forward) OnStartup() (err error) {
+ if f.hcInterval == 0 {
+ for _, p := range f.proxies {
+ p.host.fails = 0
+ }
+ return nil
+ }
+
+ for _, p := range f.proxies {
+ go p.healthCheck()
+ }
+ return nil
+}
+
+// OnShutdown stops all configured proxies.
+func (f *Forward) OnShutdown() error {
+ if f.hcInterval == 0 {
+ return nil
+ }
+
+ for _, p := range f.proxies {
+ p.close()
+ }
+ return nil
+}
+
+// Close is a synonym for OnShutdown().
+func (f *Forward) Close() {
+ f.OnShutdown()
+}
+
// parseForward parses the forward directive (FROM zone, upstream TO
// addresses and the option block) from the Corefile and returns the
// configured Forward. Returned proxies get their TLS config and expire
// duration applied at the end, after the block has been parsed.
func parseForward(c *caddy.Controller) (*Forward, error) {
	f := New()

	// protocols[i] records the transport (DNS or TLS) parsed off upstream i.
	protocols := map[int]int{}

	for c.Next() {
		if !c.Args(&f.from) {
			return f, c.ArgErr()
		}
		f.from = plugin.Host(f.from).Normalize()

		to := c.RemainingArgs()
		if len(to) == 0 {
			return f, c.ArgErr()
		}

		// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
		protocols = make(map[int]int)
		for i := range to {
			protocols[i], to[i] = protocol(to[i])
		}

		// If parseHostPortOrFile expands a file with a lot of nameservers our index-based
		// accounting in protocols doesn't make any sense anymore... For now: let's not care.
		toHosts, err := dnsutil.ParseHostPortOrFile(to...)
		if err != nil {
			return f, err
		}

		for i, h := range toHosts {
			// Double check the port, if e.g. is 53 and the transport is TLS make it 853.
			// This can be somewhat annoying because you *can't* have TLS on port 53 then.
			switch protocols[i] {
			case TLS:
				h1, p, err := net.SplitHostPort(h)
				if err != nil {
					break
				}

				// This is more of a bug in dnsutil.ParseHostPortOrFile, which defaults
				// to port 53 because it doesn't know about the tls:// scheme and friends
				// (that should be fixed). Restore the port here to what the user intended.
				if p == "53" {
					h = net.JoinHostPort(h1, "853")
				}
			}

			// We can't set tlsConfig here, because we haven't parsed it yet.
			// We set it below at the end of parseBlock.
			p := NewProxy(h)
			f.proxies = append(f.proxies, p)
		}

		for c.NextBlock() {
			if err := parseBlock(c, f); err != nil {
				return f, err
			}
		}
	}

	if f.tlsServerName != "" {
		f.tlsConfig.ServerName = f.tlsServerName
	}
	for i := range f.proxies {
		// Only set this for proxies that need it.
		if protocols[i] == TLS {
			f.proxies[i].SetTLSConfig(f.tlsConfig)
		}
		f.proxies[i].SetExpire(f.expire)
	}
	return f, nil
}
+
+func parseBlock(c *caddy.Controller, f *Forward) error {
+ switch c.Val() {
+ case "except":
+ ignore := c.RemainingArgs()
+ if len(ignore) == 0 {
+ return c.ArgErr()
+ }
+ for i := 0; i < len(ignore); i++ {
+ ignore[i] = plugin.Host(ignore[i]).Normalize()
+ }
+ f.ignored = ignore
+ case "max_fails":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ n, err := strconv.Atoi(c.Val())
+ if err != nil {
+ return err
+ }
+ if n < 0 {
+ return fmt.Errorf("max_fails can't be negative: %d", n)
+ }
+ f.maxfails = uint32(n)
+ case "health_check":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ dur, err := time.ParseDuration(c.Val())
+ if err != nil {
+ return err
+ }
+ if dur < 0 {
+ return fmt.Errorf("health_check can't be negative: %d", dur)
+ }
+ f.hcInterval = dur
+ for i := range f.proxies {
+ f.proxies[i].hcInterval = dur
+ }
+ case "force_tcp":
+ if c.NextArg() {
+ return c.ArgErr()
+ }
+ f.forceTCP = true
+ for i := range f.proxies {
+ f.proxies[i].forceTCP = true
+ }
+ case "tls":
+ args := c.RemainingArgs()
+ if len(args) != 3 {
+ return c.ArgErr()
+ }
+
+ tlsConfig, err := pkgtls.NewTLSConfig(args[0], args[1], args[2])
+ if err != nil {
+ return err
+ }
+ f.tlsConfig = tlsConfig
+ case "tls_servername":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ f.tlsServerName = c.Val()
+ case "expire":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ dur, err := time.ParseDuration(c.Val())
+ if err != nil {
+ return err
+ }
+ if dur < 0 {
+ return fmt.Errorf("expire can't be negative: %s", dur)
+ }
+ f.expire = dur
+ case "policy":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ switch x := c.Val(); x {
+ case "random":
+ f.p = &random{}
+ case "round_robin":
+ f.p = &roundRobin{}
+ default:
+ return c.Errf("unknown policy '%s'", x)
+ }
+
+ default:
+ return c.Errf("unknown property '%s'", c.Val())
+ }
+
+ return nil
+}
+
+const max = 15 // Maximum number of upstreams.
diff --git a/plugin/forward/setup_policy_test.go b/plugin/forward/setup_policy_test.go
new file mode 100644
index 000000000..8c40b9fdd
--- /dev/null
+++ b/plugin/forward/setup_policy_test.go
@@ -0,0 +1,46 @@
+package forward
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/mholt/caddy"
+)
+
+func TestSetupPolicy(t *testing.T) {
+ tests := []struct {
+ input string
+ shouldErr bool
+ expectedPolicy string
+ expectedErr string
+ }{
+ // positive
+ {"forward . 127.0.0.1 {\npolicy random\n}\n", false, "random", ""},
+ {"forward . 127.0.0.1 {\npolicy round_robin\n}\n", false, "round_robin", ""},
+ // negative
+ {"forward . 127.0.0.1 {\npolicy random2\n}\n", true, "random", "unknown policy"},
+ }
+
+ for i, test := range tests {
+ c := caddy.NewTestController("dns", test.input)
+ f, err := parseForward(c)
+
+ if test.shouldErr && err == nil {
+ t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input)
+ }
+
+ if err != nil {
+ if !test.shouldErr {
+ t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+ }
+
+ if !strings.Contains(err.Error(), test.expectedErr) {
+ t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+ }
+ }
+
+ if !test.shouldErr && f.p.String() != test.expectedPolicy {
+ t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedPolicy, f.p.String())
+ }
+ }
+}
diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go
new file mode 100644
index 000000000..f1776222f
--- /dev/null
+++ b/plugin/forward/setup_test.go
@@ -0,0 +1,68 @@
+package forward
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/mholt/caddy"
+)
+
+func TestSetup(t *testing.T) {
+ tests := []struct {
+ input string
+ shouldErr bool
+ expectedFrom string
+ expectedIgnored []string
+ expectedFails uint32
+ expectedForceTCP bool
+ expectedErr string
+ }{
+ // positive
+ {"forward . 127.0.0.1", false, ".", nil, 2, false, ""},
+ {"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, false, ""},
+ {"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, false, ""},
+ {"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, true, ""},
+ {"forward . 127.0.0.1:53", false, ".", nil, 2, false, ""},
+ {"forward . 127.0.0.1:8080", false, ".", nil, 2, false, ""},
+ {"forward . [::1]:53", false, ".", nil, 2, false, ""},
+ {"forward . [2003::1]:53", false, ".", nil, 2, false, ""},
+ // negative
+ {"forward . a27.0.0.1", true, "", nil, 0, false, "not an IP"},
+ {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, false, "unknown property"},
+ }
+
+ for i, test := range tests {
+ c := caddy.NewTestController("dns", test.input)
+ f, err := parseForward(c)
+
+ if test.shouldErr && err == nil {
+ t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input)
+ }
+
+ if err != nil {
+ if !test.shouldErr {
+ t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+ }
+
+ if !strings.Contains(err.Error(), test.expectedErr) {
+ t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+ }
+ }
+
+ if !test.shouldErr && f.from != test.expectedFrom {
+ t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedFrom, f.from)
+ }
+ if !test.shouldErr && test.expectedIgnored != nil {
+ if !reflect.DeepEqual(f.ignored, test.expectedIgnored) {
+ t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, f.ignored)
+ }
+ }
+ if !test.shouldErr && f.maxfails != test.expectedFails {
+ t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails)
+ }
+ if !test.shouldErr && f.forceTCP != test.expectedForceTCP {
+ t.Errorf("Test %d: expected: %t, got: %t", i, test.expectedForceTCP, f.forceTCP)
+ }
+ }
+}