diff options
Diffstat (limited to 'plugin/forward/setup.go')
-rw-r--r-- | plugin/forward/setup.go | 116 |
1 files changed, 25 insertions, 91 deletions
diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 1f88daba5..baf80f12b 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -1,8 +1,6 @@ package forward import ( - "crypto/tls" - "errors" "fmt" "strconv" "time" @@ -11,9 +9,7 @@ import ( "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/dnstap" - "github.com/coredns/coredns/plugin/pkg/parse" pkgtls "github.com/coredns/coredns/plugin/pkg/tls" - "github.com/coredns/coredns/plugin/pkg/transport" ) func init() { plugin.Register("forward", setup) } @@ -87,90 +83,46 @@ func parseForward(c *caddy.Controller) (*Forward, error) { } func parseStanza(c *caddy.Controller) (*Forward, error) { - f := New() + cfg := ForwardConfig{} - if !c.Args(&f.from) { - return f, c.ArgErr() + if !c.Args(&cfg.From) { + return nil, c.ArgErr() } - origFrom := f.from - zones := plugin.Host(f.from).NormalizeExact() - f.from = zones[0] // there can only be one here, won't work with non-octet reverse - if len(zones) > 1 { - log.Warningf("Unsupported CIDR notation: '%s' expands to multiple zones. Using only '%s'.", origFrom, f.from) - } - - to := c.RemainingArgs() - if len(to) == 0 { - return f, c.ArgErr() - } - - toHosts, err := parse.HostPortOrFile(to...) - if err != nil { - return f, err - } - - transports := make([]string, len(toHosts)) - allowedTrans := map[string]bool{"dns": true, "tls": true} - for i, host := range toHosts { - trans, h := parse.Transport(host) - - if !allowedTrans[trans] { - return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) - } - p := NewProxy(h, trans) - f.proxies = append(f.proxies, p) - transports[i] = trans + cfg.To = c.RemainingArgs() + if len(cfg.To) == 0 { + return nil, c.ArgErr() } for c.NextBlock() { - if err := parseBlock(c, f); err != nil { - return f, err - } - } - - if f.tlsServerName != "" { - f.tlsConfig.ServerName = f.tlsServerName - } - - // Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake - // in upcoming connections to the same TLS server. - f.tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(len(f.proxies)) - - for i := range f.proxies { - // Only set this for proxies that need it. - if transports[i] == transport.TLS { - f.proxies[i].SetTLSConfig(f.tlsConfig) + if err := parseBlock(c, &cfg); err != nil { + return nil, err } - f.proxies[i].SetExpire(f.expire) - f.proxies[i].health.SetRecursionDesired(f.opts.hcRecursionDesired) } - return f, nil + return NewWithConfig(cfg) } -func parseBlock(c *caddy.Controller, f *Forward) error { +func parseBlock(c *caddy.Controller, cfg *ForwardConfig) error { switch c.Val() { case "except": - ignore := c.RemainingArgs() - if len(ignore) == 0 { + cfg.Except = c.RemainingArgs() + if len(cfg.Except) == 0 { return c.ArgErr() } - for i := 0; i < len(ignore); i++ { - f.ignored = append(f.ignored, plugin.Host(ignore[i]).NormalizeExact()...) - } case "max_fails": if !c.NextArg() { return c.ArgErr() } - n, err := strconv.Atoi(c.Val()) + n, err := strconv.ParseInt(c.Val(), 10, 32) if err != nil { return err } if n < 0 { return fmt.Errorf("max_fails can't be negative: %d", n) } - f.maxfails = uint32(n) + maxFails := uint32(n) + cfg.MaxFails = &maxFails case "health_check": if !c.NextArg() { return c.ArgErr() @@ -179,15 +131,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - if dur < 0 { - return fmt.Errorf("health_check can't be negative: %d", dur) - } - f.hcInterval = dur + cfg.HealthCheck = &dur for c.NextArg() { switch hcOpts := c.Val(); hcOpts { case "no_rec": - f.opts.hcRecursionDesired = false + cfg.HealthCheckNoRec = true default: return fmt.Errorf("health_check: unknown option %s", hcOpts) } @@ -197,12 +146,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if c.NextArg() { return c.ArgErr() } - f.opts.forceTCP = true + cfg.ForceTCP = true case "prefer_udp": if c.NextArg() { return c.ArgErr() } - f.opts.preferUDP = true + cfg.PreferUDP = true case "tls": args := c.RemainingArgs() if len(args) > 3 { @@ -213,12 +162,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - f.tlsConfig = tlsConfig + cfg.TLSConfig = tlsConfig case "tls_servername": if !c.NextArg() { return c.ArgErr() } - f.tlsServerName = c.Val() + cfg.TLSServerName = c.Val() case "expire": if !c.NextArg() { return c.ArgErr() @@ -227,24 +176,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - if dur < 0 { - return fmt.Errorf("expire can't be negative: %s", dur) - } - f.expire = dur + cfg.Expire = &dur case "policy": if !c.NextArg() { return c.ArgErr() } - switch x := c.Val(); x { - case "random": - f.p = &random{} - case "round_robin": - f.p = &roundRobin{} - case "sequential": - f.p = &sequential{} - default: - return c.Errf("unknown policy '%s'", x) - } + cfg.Policy = c.Val() case "max_concurrent": if !c.NextArg() { return c.ArgErr() @@ -253,11 +190,8 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - if n < 0 { - return fmt.Errorf("max_concurrent can't be negative: %d", n) - } - f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val()) - f.maxConcurrent = int64(n) + maxConcurrent := int64(n) + cfg.MaxConcurrent = &maxConcurrent default: return c.Errf("unknown property '%s'", c.Val()) |