diff options
Diffstat (limited to 'middleware/proxy/upstream.go')
-rw-r--r-- | middleware/proxy/upstream.go | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go new file mode 100644 index 000000000..092e2351d --- /dev/null +++ b/middleware/proxy/upstream.go @@ -0,0 +1,235 @@ +package proxy + +import ( + "io" + "io/ioutil" + "net/http" + "path" + "strconv" + "time" + + "github.com/miekg/coredns/core/parse" + "github.com/miekg/coredns/middleware" +) + +var ( + supportedPolicies = make(map[string]func() Policy) +) + +type staticUpstream struct { + from string + // TODO(miek): allows use to added headers + proxyHeaders http.Header // TODO(miek): kill + Hosts HostPool + Policy Policy + + FailTimeout time.Duration + MaxFails int32 + HealthCheck struct { + Path string + Interval time.Duration + } + WithoutPathPrefix string + IgnoredSubPaths []string +} + +// NewStaticUpstreams parses the configuration input and sets up +// static upstreams for the proxy middleware. +func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { + var upstreams []Upstream + for c.Next() { + upstream := &staticUpstream{ + from: "", + proxyHeaders: make(http.Header), + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + + if !c.Args(&upstream.from) { + return upstreams, c.ArgErr() + } + to := c.RemainingArgs() + if len(to) == 0 { + return upstreams, c.ArgErr() + } + + for c.NextBlock() { + if err := parseBlock(&c, upstream); err != nil { + return upstreams, err + } + } + + upstream.Hosts = make([]*UpstreamHost, len(to)) + for i, host := range to { + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: upstream.FailTimeout, + Unhealthy: false, + ExtraHeaders: upstream.proxyHeaders, + CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= upstream.MaxFails && + upstream.MaxFails != 0 { + return true + } + return false + } + }(upstream), + WithoutPathPrefix: upstream.WithoutPathPrefix, + } + upstream.Hosts[i] = uh + } + + if upstream.HealthCheck.Path != "" { + go upstream.HealthCheckWorker(nil) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy +} + +func (u *staticUpstream) From() string { + return u.from +} + +func parseBlock(c *parse.Dispenser, u *staticUpstream) error { + switch c.Val() { + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + policyCreateFunc, ok := supportedPolicies[c.Val()] + if !ok { + return c.ArgErr() + } + u.Policy = policyCreateFunc() + case "fail_timeout": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.FailTimeout = dur + case "max_fails": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + u.MaxFails = int32(n) + case "health_check": + if !c.NextArg() { + return c.ArgErr() + } + u.HealthCheck.Path = c.Val() + u.HealthCheck.Interval = 30 * time.Second + if c.NextArg() { + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.HealthCheck.Interval = dur + } + case "proxy_header": + var header, value string + if !c.Args(&header, &value) { + return c.ArgErr() + } + u.proxyHeaders.Add(header, value) + case "websocket": + u.proxyHeaders.Add("Connection", "{>Connection}") + u.proxyHeaders.Add("Upgrade", "{>Upgrade}") + case "without": + if !c.NextArg() { + return c.ArgErr() + } + u.WithoutPathPrefix = c.Val() + case "except": + ignoredPaths := c.RemainingArgs() + if len(ignoredPaths) == 0 { + return c.ArgErr() + } + u.IgnoredSubPaths = ignoredPaths + default: + return c.Errf("unknown property '%s'", c.Val()) + } + return nil +} + +func (u *staticUpstream) healthCheck() { + for _, host := range u.Hosts { + hostURL := host.Name + u.HealthCheck.Path + if r, err := http.Get(hostURL); err == nil { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 + } else { + host.Unhealthy = true + } + } +} + +func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { + ticker := time.NewTicker(u.HealthCheck.Interval) + u.healthCheck() + for { + select { + case <-ticker.C: + u.healthCheck() + case <-stop: + // TODO: the library should provide a stop channel and global + // waitgroup to allow goroutines started by plugins a chance + // to clean themselves up. + } + } +} + +func (u *staticUpstream) Select() *UpstreamHost { + pool := u.Hosts + if len(pool) == 1 { + if pool[0].Down() { + return nil + } + return pool[0] + } + allDown := true + for _, host := range pool { + if !host.Down() { + allDown = false + break + } + } + if allDown { + return nil + } + + if u.Policy == nil { + return (&Random{}).Select(pool) + } + return u.Policy.Select(pool) +} + +func (u *staticUpstream) IsAllowedPath(requestPath string) bool { + for _, ignoredSubPath := range u.IgnoredSubPaths { + if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) { + return false + } + } + return true +} |