aboutsummaryrefslogtreecommitdiff
path: root/plugin/forward/setup.go
diff options
context:
space:
mode:
authorGravatar Miek Gieben <miek@miek.nl> 2018-02-05 22:00:47 +0000
committerGravatar GitHub <noreply@github.com> 2018-02-05 22:00:47 +0000
commit5b844b5017f004fffa83157041e8ffd3ac085c92 (patch)
treecbf86bb06cd42f720037a0e473ce2d1cba4036af /plugin/forward/setup.go
parentfb1cafe5fa54935361a5cc9a7e3308a738225126 (diff)
downloadcoredns-5b844b5017f004fffa83157041e8ffd3ac085c92.tar.gz
coredns-5b844b5017f004fffa83157041e8ffd3ac085c92.tar.zst
coredns-5b844b5017f004fffa83157041e8ffd3ac085c92.zip
plugin/forward: add it (#1447)
* plugin/forward: add it This moves coredns/forward into CoreDNS. Fixes as a few bugs, adds a policy option and more tests to the plugin. Update the documentation, test IPv6 address and add persistent tests. * Always use random policy when spraying * include scrub fix here as well * use correct var name * Code review * go vet * Move logging to metrcs * Small readme updates * Fix readme
Diffstat (limited to 'plugin/forward/setup.go')
-rw-r--r--plugin/forward/setup.go262
1 files changed, 262 insertions, 0 deletions
diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go
new file mode 100644
index 000000000..bed20f0c7
--- /dev/null
+++ b/plugin/forward/setup.go
@@ -0,0 +1,262 @@
+package forward
+
+import (
+ "fmt"
+ "net"
+ "strconv"
+ "time"
+
+ "github.com/coredns/coredns/core/dnsserver"
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/metrics"
+ "github.com/coredns/coredns/plugin/pkg/dnsutil"
+ pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
+
+ "github.com/mholt/caddy"
+)
+
+func init() {
+ caddy.RegisterPlugin("forward", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ f, err := parseForward(c)
+ if err != nil {
+ return plugin.Error("foward", err)
+ }
+ if f.Len() > max {
+ return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len()))
+ }
+
+ dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
+ f.Next = next
+ return f
+ })
+
+ c.OnStartup(func() error {
+ once.Do(func() {
+ m := dnsserver.GetConfig(c).Handler("prometheus")
+ if m == nil {
+ return
+ }
+ if x, ok := m.(*metrics.Metrics); ok {
+ x.MustRegister(RequestCount)
+ x.MustRegister(RcodeCount)
+ x.MustRegister(RequestDuration)
+ x.MustRegister(HealthcheckFailureCount)
+ x.MustRegister(SocketGauge)
+ }
+ })
+ return f.OnStartup()
+ })
+
+ c.OnShutdown(func() error {
+ return f.OnShutdown()
+ })
+
+ return nil
+}
+
+// OnStartup starts a goroutines for all proxies.
+func (f *Forward) OnStartup() (err error) {
+ if f.hcInterval == 0 {
+ for _, p := range f.proxies {
+ p.host.fails = 0
+ }
+ return nil
+ }
+
+ for _, p := range f.proxies {
+ go p.healthCheck()
+ }
+ return nil
+}
+
+// OnShutdown stops all configured proxies.
+func (f *Forward) OnShutdown() error {
+ if f.hcInterval == 0 {
+ return nil
+ }
+
+ for _, p := range f.proxies {
+ p.close()
+ }
+ return nil
+}
+
+// Close is a synonym for OnShutdown().
+func (f *Forward) Close() {
+ f.OnShutdown()
+}
+
+func parseForward(c *caddy.Controller) (*Forward, error) {
+ f := New()
+
+ protocols := map[int]int{}
+
+ for c.Next() {
+ if !c.Args(&f.from) {
+ return f, c.ArgErr()
+ }
+ f.from = plugin.Host(f.from).Normalize()
+
+ to := c.RemainingArgs()
+ if len(to) == 0 {
+ return f, c.ArgErr()
+ }
+
+ // A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
+ protocols = make(map[int]int)
+ for i := range to {
+ protocols[i], to[i] = protocol(to[i])
+ }
+
+ // If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make
+ // any sense anymore... For now: lets don't care.
+ toHosts, err := dnsutil.ParseHostPortOrFile(to...)
+ if err != nil {
+ return f, err
+ }
+
+ for i, h := range toHosts {
+ // Double check the port, if e.g. is 53 and the transport is TLS make it 853.
+ // This can be somewhat annoying because you *can't* have TLS on port 53 then.
+ switch protocols[i] {
+ case TLS:
+ h1, p, err := net.SplitHostPort(h)
+ if err != nil {
+ break
+ }
+
+ // This is more of a bug in // dnsutil.ParseHostPortOrFile that defaults to
+ // 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence
+ // Fix the port number here, back to what the user intended.
+ if p == "53" {
+ h = net.JoinHostPort(h1, "853")
+ }
+ }
+
+ // We can't set tlsConfig here, because we haven't parsed it yet.
+ // We set it below at the end of parseBlock.
+ p := NewProxy(h)
+ f.proxies = append(f.proxies, p)
+ }
+
+ for c.NextBlock() {
+ if err := parseBlock(c, f); err != nil {
+ return f, err
+ }
+ }
+ }
+
+ if f.tlsServerName != "" {
+ f.tlsConfig.ServerName = f.tlsServerName
+ }
+ for i := range f.proxies {
+ // Only set this for proxies that need it.
+ if protocols[i] == TLS {
+ f.proxies[i].SetTLSConfig(f.tlsConfig)
+ }
+ f.proxies[i].SetExpire(f.expire)
+ }
+ return f, nil
+}
+
+func parseBlock(c *caddy.Controller, f *Forward) error {
+ switch c.Val() {
+ case "except":
+ ignore := c.RemainingArgs()
+ if len(ignore) == 0 {
+ return c.ArgErr()
+ }
+ for i := 0; i < len(ignore); i++ {
+ ignore[i] = plugin.Host(ignore[i]).Normalize()
+ }
+ f.ignored = ignore
+ case "max_fails":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ n, err := strconv.Atoi(c.Val())
+ if err != nil {
+ return err
+ }
+ if n < 0 {
+ return fmt.Errorf("max_fails can't be negative: %d", n)
+ }
+ f.maxfails = uint32(n)
+ case "health_check":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ dur, err := time.ParseDuration(c.Val())
+ if err != nil {
+ return err
+ }
+ if dur < 0 {
+ return fmt.Errorf("health_check can't be negative: %d", dur)
+ }
+ f.hcInterval = dur
+ for i := range f.proxies {
+ f.proxies[i].hcInterval = dur
+ }
+ case "force_tcp":
+ if c.NextArg() {
+ return c.ArgErr()
+ }
+ f.forceTCP = true
+ for i := range f.proxies {
+ f.proxies[i].forceTCP = true
+ }
+ case "tls":
+ args := c.RemainingArgs()
+ if len(args) != 3 {
+ return c.ArgErr()
+ }
+
+ tlsConfig, err := pkgtls.NewTLSConfig(args[0], args[1], args[2])
+ if err != nil {
+ return err
+ }
+ f.tlsConfig = tlsConfig
+ case "tls_servername":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ f.tlsServerName = c.Val()
+ case "expire":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ dur, err := time.ParseDuration(c.Val())
+ if err != nil {
+ return err
+ }
+ if dur < 0 {
+ return fmt.Errorf("expire can't be negative: %s", dur)
+ }
+ f.expire = dur
+ case "policy":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ switch x := c.Val(); x {
+ case "random":
+ f.p = &random{}
+ case "round_robin":
+ f.p = &roundRobin{}
+ default:
+ return c.Errf("unknown policy '%s'", x)
+ }
+
+ default:
+ return c.Errf("unknown property '%s'", c.Val())
+ }
+
+ return nil
+}
+
+const max = 15 // Maximum number of upstreams.