diff options
author | 2018-02-05 22:00:47 +0000 | |
---|---|---|
committer | 2018-02-05 22:00:47 +0000 | |
commit | 5b844b5017f004fffa83157041e8ffd3ac085c92 (patch) | |
tree | cbf86bb06cd42f720037a0e473ce2d1cba4036af /plugin/forward/setup.go | |
parent | fb1cafe5fa54935361a5cc9a7e3308a738225126 (diff) | |
download | coredns-5b844b5017f004fffa83157041e8ffd3ac085c92.tar.gz coredns-5b844b5017f004fffa83157041e8ffd3ac085c92.tar.zst coredns-5b844b5017f004fffa83157041e8ffd3ac085c92.zip |
plugin/forward: add it (#1447)
* plugin/forward: add it
This moves coredns/forward into CoreDNS. Fixes as a few bugs, adds a
policy option and more tests to the plugin.
Update the documentation, test IPv6 address and add persistent tests.
* Always use random policy when spraying
* include scrub fix here as well
* use correct var name
* Code review
* go vet
* Move logging to metrcs
* Small readme updates
* Fix readme
Diffstat (limited to 'plugin/forward/setup.go')
-rw-r--r-- | plugin/forward/setup.go | 262 |
1 files changed, 262 insertions, 0 deletions
diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go new file mode 100644 index 000000000..bed20f0c7 --- /dev/null +++ b/plugin/forward/setup.go @@ -0,0 +1,262 @@ +package forward + +import ( + "fmt" + "net" + "strconv" + "time" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + pkgtls "github.com/coredns/coredns/plugin/pkg/tls" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("forward", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + f, err := parseForward(c) + if err != nil { + return plugin.Error("foward", err) + } + if f.Len() > max { + return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len())) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + f.Next = next + return f + }) + + c.OnStartup(func() error { + once.Do(func() { + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return + } + if x, ok := m.(*metrics.Metrics); ok { + x.MustRegister(RequestCount) + x.MustRegister(RcodeCount) + x.MustRegister(RequestDuration) + x.MustRegister(HealthcheckFailureCount) + x.MustRegister(SocketGauge) + } + }) + return f.OnStartup() + }) + + c.OnShutdown(func() error { + return f.OnShutdown() + }) + + return nil +} + +// OnStartup starts a goroutines for all proxies. +func (f *Forward) OnStartup() (err error) { + if f.hcInterval == 0 { + for _, p := range f.proxies { + p.host.fails = 0 + } + return nil + } + + for _, p := range f.proxies { + go p.healthCheck() + } + return nil +} + +// OnShutdown stops all configured proxies. +func (f *Forward) OnShutdown() error { + if f.hcInterval == 0 { + return nil + } + + for _, p := range f.proxies { + p.close() + } + return nil +} + +// Close is a synonym for OnShutdown(). +func (f *Forward) Close() { + f.OnShutdown() +} + +func parseForward(c *caddy.Controller) (*Forward, error) { + f := New() + + protocols := map[int]int{} + + for c.Next() { + if !c.Args(&f.from) { + return f, c.ArgErr() + } + f.from = plugin.Host(f.from).Normalize() + + to := c.RemainingArgs() + if len(to) == 0 { + return f, c.ArgErr() + } + + // A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies. + protocols = make(map[int]int) + for i := range to { + protocols[i], to[i] = protocol(to[i]) + } + + // If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make + // any sense anymore... For now: lets don't care. + toHosts, err := dnsutil.ParseHostPortOrFile(to...) + if err != nil { + return f, err + } + + for i, h := range toHosts { + // Double check the port, if e.g. is 53 and the transport is TLS make it 853. + // This can be somewhat annoying because you *can't* have TLS on port 53 then. + switch protocols[i] { + case TLS: + h1, p, err := net.SplitHostPort(h) + if err != nil { + break + } + + // This is more of a bug in // dnsutil.ParseHostPortOrFile that defaults to + // 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence + // Fix the port number here, back to what the user intended. + if p == "53" { + h = net.JoinHostPort(h1, "853") + } + } + + // We can't set tlsConfig here, because we haven't parsed it yet. + // We set it below at the end of parseBlock. + p := NewProxy(h) + f.proxies = append(f.proxies, p) + } + + for c.NextBlock() { + if err := parseBlock(c, f); err != nil { + return f, err + } + } + } + + if f.tlsServerName != "" { + f.tlsConfig.ServerName = f.tlsServerName + } + for i := range f.proxies { + // Only set this for proxies that need it. + if protocols[i] == TLS { + f.proxies[i].SetTLSConfig(f.tlsConfig) + } + f.proxies[i].SetExpire(f.expire) + } + return f, nil +} + +func parseBlock(c *caddy.Controller, f *Forward) error { + switch c.Val() { + case "except": + ignore := c.RemainingArgs() + if len(ignore) == 0 { + return c.ArgErr() + } + for i := 0; i < len(ignore); i++ { + ignore[i] = plugin.Host(ignore[i]).Normalize() + } + f.ignored = ignore + case "max_fails": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + if n < 0 { + return fmt.Errorf("max_fails can't be negative: %d", n) + } + f.maxfails = uint32(n) + case "health_check": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + if dur < 0 { + return fmt.Errorf("health_check can't be negative: %d", dur) + } + f.hcInterval = dur + for i := range f.proxies { + f.proxies[i].hcInterval = dur + } + case "force_tcp": + if c.NextArg() { + return c.ArgErr() + } + f.forceTCP = true + for i := range f.proxies { + f.proxies[i].forceTCP = true + } + case "tls": + args := c.RemainingArgs() + if len(args) != 3 { + return c.ArgErr() + } + + tlsConfig, err := pkgtls.NewTLSConfig(args[0], args[1], args[2]) + if err != nil { + return err + } + f.tlsConfig = tlsConfig + case "tls_servername": + if !c.NextArg() { + return c.ArgErr() + } + f.tlsServerName = c.Val() + case "expire": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + if dur < 0 { + return fmt.Errorf("expire can't be negative: %s", dur) + } + f.expire = dur + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + switch x := c.Val(); x { + case "random": + f.p = &random{} + case "round_robin": + f.p = &roundRobin{} + default: + return c.Errf("unknown policy '%s'", x) + } + + default: + return c.Errf("unknown property '%s'", c.Val()) + } + + return nil +} + +const max = 15 // Maximum number of upstreams. |