diff options
author | 2016-03-18 20:57:35 +0000 | |
---|---|---|
committer | 2016-03-18 20:57:35 +0000 | |
commit | 3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d (patch) | |
tree | fae74c33cfed05de603785294593275f1901c861 /middleware/proxy | |
download | coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.gz coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.zst coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.zip |
First commit
Diffstat (limited to 'middleware/proxy')
-rw-r--r-- | middleware/proxy/policy.go | 101 | ||||
-rw-r--r-- | middleware/proxy/policy_test.go | 87 | ||||
-rw-r--r-- | middleware/proxy/proxy.go | 120 | ||||
-rw-r--r-- | middleware/proxy/proxy_test.go | 317 | ||||
-rw-r--r-- | middleware/proxy/reverseproxy.go | 36 | ||||
-rw-r--r-- | middleware/proxy/upstream.go | 235 | ||||
-rw-r--r-- | middleware/proxy/upstream_test.go | 83 |
7 files changed, 979 insertions, 0 deletions
diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go new file mode 100644 index 000000000..a2522bcb1 --- /dev/null +++ b/middleware/proxy/policy.go @@ -0,0 +1,101 @@ +package proxy + +import ( + "math/rand" + "sync/atomic" +) + +// HostPool is a collection of UpstreamHosts. +type HostPool []*UpstreamHost + +// Policy decides how a host will be selected from a pool. +type Policy interface { + Select(pool HostPool) *UpstreamHost +} + +func init() { + RegisterPolicy("random", func() Policy { return &Random{} }) + RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) +} + +// Random is a policy that selects up hosts from a pool at random. +type Random struct{} + +// Select selects an up host at random from the specified pool. +func (r *Random) Select(pool HostPool) *UpstreamHost { + // instead of just generating a random index + // this is done to prevent selecting a down host + var randHost *UpstreamHost + count := 0 + for _, host := range pool { + if host.Down() { + continue + } + count++ + if count == 1 { + randHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + randHost = host + } + } + } + return randHost +} + +// LeastConn is a policy that selects the host with the least connections. +type LeastConn struct{} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (r *LeastConn) Select(pool HostPool) *UpstreamHost { + var bestHost *UpstreamHost + count := 0 + leastConn := int64(1<<63 - 1) + for _, host := range pool { + if host.Down() { + continue + } + hostConns := host.Conns + if hostConns < leastConn { + bestHost = host + leastConn = hostConns + count = 1 + } else if hostConns == leastConn { + // randomly select host among hosts with least connections + count++ + if count == 1 { + bestHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + bestHost = host + } + } + } + } + return bestHost +} + +// RoundRobin is a policy that selects hosts based on round robin ordering. +type RoundRobin struct { + Robin uint32 +} + +// Select selects an up host from the pool using a round robin ordering scheme. +func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { + poolLen := uint32(len(pool)) + selection := atomic.AddUint32(&r.Robin, 1) % poolLen + host := pool[selection] + // if the currently selected host is down, just ffwd to up host + for i := uint32(1); host.Down() && i < poolLen; i++ { + host = pool[(selection+i)%poolLen] + } + if host.Down() { + return nil + } + return host +} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go new file mode 100644 index 000000000..8f4f1f792 --- /dev/null +++ b/middleware/proxy/policy_test.go @@ -0,0 +1,87 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +var workableServer *httptest.Server + +func TestMain(m *testing.M) { + workableServer = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // do nothing + })) + r := m.Run() + workableServer.Close() + os.Exit(r) +} + +type customPolicy struct{} + +func (r *customPolicy) Select(pool HostPool) *UpstreamHost { + return pool[0] +} + +func testPool() HostPool { + pool := []*UpstreamHost{ + { + Name: workableServer.URL, // this should resolve (healthcheck test) + }, + { + Name: "http://shouldnot.resolve", // this shouldn't + }, + { + Name: "http://C", + }, + } + return HostPool(pool) +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := &RoundRobin{} + h := rrPolicy.Select(pool) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + // mark host as down + pool[0].Unhealthy = true + h = rrPolicy.Select(pool) + if h != pool[1] { + t.Error("Expected third round robin host to be first host in the pool.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := &LeastConn{} + pool[0].Conns = 10 + pool[1].Conns = 10 + h := lcPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].Conns = 100 + h = lcPolicy.Select(pool) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} + +func TestCustomPolicy(t *testing.T) { + pool := testPool() + customPolicy := &customPolicy{} + h := customPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected custom policy host to be the first host.") + } +} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go new file mode 100644 index 000000000..169e41b61 --- /dev/null +++ b/middleware/proxy/proxy.go @@ -0,0 +1,120 @@ +// Package proxy is middleware that proxies requests. +package proxy + +import ( + "errors" + "net/http" + "sync/atomic" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +var errUnreachable = errors.New("unreachable backend") + +// Proxy represents a middleware instance that can proxy requests. +type Proxy struct { + Next middleware.Handler + Client Client + Upstreams []Upstream +} + +type Client struct { + UDP *dns.Client + TCP *dns.Client +} + +// Upstream manages a pool of proxy upstream hosts. Select should return a +// suitable upstream host, or nil if no such hosts are available. +type Upstream interface { + // The domain name this upstream host should be routed on. + From() string + // Selects an upstream host to be routed to. + Select() *UpstreamHost + // Checks if subpdomain is not an ignored. + IsAllowedPath(string) bool +} + +// UpstreamHostDownFunc can be used to customize how Down behaves. +type UpstreamHostDownFunc func(*UpstreamHost) bool + +// UpstreamHost represents a single proxy upstream +type UpstreamHost struct { + Conns int64 // must be first field to be 64-bit aligned on 32-bit systems + Name string // IP address (and port) of this upstream host + Fails int32 + FailTimeout time.Duration + Unhealthy bool + ExtraHeaders http.Header + CheckDown UpstreamHostDownFunc + WithoutPathPrefix string +} + +// Down checks whether the upstream host is down or not. +// Down will try to use uh.CheckDown first, and will fall +// back to some default criteria if necessary. +func (uh *UpstreamHost) Down() bool { + if uh.CheckDown == nil { + // Default settings + return uh.Unhealthy || uh.Fails > 0 + } + return uh.CheckDown(uh) +} + +// tryDuration is how long to try upstream hosts; failures result in +// immediate retries until this duration ends or we get a nil host. +var tryDuration = 60 * time.Second + +// ServeDNS satisfies the middleware.Handler interface. +func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + for _, upstream := range p.Upstreams { + // allowed bla bla bla TODO(miek): fix full proxy spec from caddy + start := time.Now() + + // Since Select() should give us "up" hosts, keep retrying + // hosts until timeout (or until we get a nil host). + for time.Now().Sub(start) < tryDuration { + host := upstream.Select() + if host == nil { + return dns.RcodeServerFailure, errUnreachable + } + // TODO(miek): PORT! + reverseproxy := ReverseProxy{Host: host.Name, Client: p.Client} + + atomic.AddInt64(&host.Conns, 1) + backendErr := reverseproxy.ServeDNS(w, r, nil) + atomic.AddInt64(&host.Conns, -1) + if backendErr == nil { + return 0, nil + } + timeout := host.FailTimeout + if timeout == 0 { + timeout = 10 * time.Second + } + atomic.AddInt32(&host.Fails, 1) + go func(host *UpstreamHost, timeout time.Duration) { + time.Sleep(timeout) + atomic.AddInt32(&host.Fails, -1) + }(host, timeout) + } + return dns.RcodeServerFailure, errUnreachable + } + return p.Next.ServeDNS(w, r) +} + +func Clients() Client { + udp := newClient("udp", defaultTimeout) + tcp := newClient("tcp", defaultTimeout) + return Client{UDP: udp, TCP: tcp} +} + +// newClient returns a new client for proxy requests. +func newClient(net string, timeout time.Duration) *dns.Client { + if timeout == 0 { + timeout = defaultTimeout + } + return &dns.Client{Net: net, ReadTimeout: timeout, WriteTimeout: timeout, SingleInflight: true} +} + +const defaultTimeout = 5 * time.Second diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go new file mode 100644 index 000000000..8066874d2 --- /dev/null +++ b/middleware/proxy/proxy_test.go @@ -0,0 +1,317 @@ +package proxy + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +func init() { + tryDuration = 50 * time.Millisecond // prevent tests from hanging +} + +func TestReverseProxy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Expected backend to receive request, but it didn't") + } +} + +func TestReverseProxyInsecureSkipVerify(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, true)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't") + } +} + +func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { + // No-op websocket backend simply allows the WS connection to be + // accepted then it will be immediately closed. Perfect for testing. + wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {})) + defer wsNop.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsNop.URL) + + // Create client request + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + r.Header = http.Header{ + "Connection": {"Upgrade"}, + "Upgrade": {"websocket"}, + "Origin": {wsNop.URL}, + "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="}, + "Sec-WebSocket-Version": {"13"}, + } + + // Capture the request + w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)} + + // Booya! Do the test. + p.ServeHTTP(w, r) + + // Make sure the backend accepted the WS connection. + // Mostly interested in the Upgrade and Connection response headers + // and the 101 status code. + expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n") + actual := w.fakeConn.writeBuf.Bytes() + if !bytes.Equal(actual, expected) { + t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual) + } +} + +func TestWebSocketReverseProxyFromWSClient(t *testing.T) { + // Echo server allows us to test that socket bytes are properly + // being proxied. + wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsEcho.URL) + + // This is a full end-end test, so the proxy handler + // has to be part of a server listening on a port. Our + // WS client will connect to this test server, not + // the echo client directly. + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "http://", "ws://", 1) + ws, err := websocket.Dial(url, "", echoProxy.URL) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + websocket.Message.Send(ws, trialMsg) + + // It should be echoed back to us + var actualMsg string + websocket.Message.Receive(ws, &actualMsg) + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func TestUnixSocketProxy(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + + trialMsg := "Is it working?" + + var proxySuccess bool + + // This is our fake "application" we want to proxy to + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Request was proxied when this is called + proxySuccess = true + + fmt.Fprint(w, trialMsg) + })) + + // Get absolute path for unix: socket + socketPath, err := filepath.Abs("./test_socket") + if err != nil { + t.Fatalf("Unable to get absolute path: %v", err) + } + + // Change httptest.Server listener to listen to unix: socket + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Unable to listen: %v", err) + } + ts.Listener = ln + + ts.Start() + defer ts.Close() + + url := strings.Replace(ts.URL, "http://", "unix:", 1) + p := newWebSocketTestProxy(url) + + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + res, err := http.Get(echoProxy.URL) + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + actualMsg := fmt.Sprintf("%s", greeting) + + if !proxySuccess { + t.Errorf("Expected request to be proxied, but it wasn't") + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func newFakeUpstream(name string, insecure bool) *fakeUpstream { + uri, _ := url.Parse(name) + u := &fakeUpstream{ + name: name, + host: &UpstreamHost{ + Name: name, + ReverseProxy: NewSingleHostReverseProxy(uri, ""), + }, + } + if insecure { + u.host.ReverseProxy.Transport = InsecureTransport + } + return u +} + +type fakeUpstream struct { + name string + host *UpstreamHost +} + +func (u *fakeUpstream) From() string { + return "/" +} + +func (u *fakeUpstream) Select() *UpstreamHost { + return u.host +} + +func (u *fakeUpstream) IsAllowedPath(requestPath string) bool { + return true +} + +// newWebSocketTestProxy returns a test proxy that will +// redirect to the specified backendAddr. The function +// also sets up the rules/environment for testing WebSocket +// proxy. +func newWebSocketTestProxy(backendAddr string) *Proxy { + return &Proxy{ + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}}, + } +} + +type fakeWsUpstream struct { + name string +} + +func (u *fakeWsUpstream) From() string { + return "/" +} + +func (u *fakeWsUpstream) Select() *UpstreamHost { + uri, _ := url.Parse(u.name) + return &UpstreamHost{ + Name: u.name, + ReverseProxy: NewSingleHostReverseProxy(uri, ""), + ExtraHeaders: http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}}, + } +} + +func (u *fakeWsUpstream) IsAllowedPath(requestPath string) bool { + return true +} + +// recorderHijacker is a ResponseRecorder that can +// be hijacked. +type recorderHijacker struct { + *httptest.ResponseRecorder + fakeConn *fakeConn +} + +func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rh.fakeConn, nil, nil +} + +type fakeConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +func (c *fakeConn) LocalAddr() net.Addr { return nil } +func (c *fakeConn) RemoteAddr() net.Addr { return nil } +func (c *fakeConn) SetDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } +func (c *fakeConn) Close() error { return nil } +func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } +func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go new file mode 100644 index 000000000..6d27da042 --- /dev/null +++ b/middleware/proxy/reverseproxy.go @@ -0,0 +1,36 @@ +// Package proxy is middleware that proxies requests. +package proxy + +import ( + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +type ReverseProxy struct { + Host string + Client Client +} + +func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error { + // TODO(miek): use extra! + var ( + reply *dns.Msg + err error + ) + context := middleware.Context{W: w, Req: r} + + // tls+tcp ? + if context.Proto() == "tcp" { + reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) + } else { + reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) + } + + if err != nil { + return err + } + reply.Compress = true + reply.Id = r.Id + w.WriteMsg(reply) + return nil +} diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go new file mode 100644 index 000000000..092e2351d --- /dev/null +++ b/middleware/proxy/upstream.go @@ -0,0 +1,235 @@ +package proxy + +import ( + "io" + "io/ioutil" + "net/http" + "path" + "strconv" + "time" + + "github.com/miekg/coredns/core/parse" + "github.com/miekg/coredns/middleware" +) + +var ( + supportedPolicies = make(map[string]func() Policy) +) + +type staticUpstream struct { + from string + // TODO(miek): allows use to added headers + proxyHeaders http.Header // TODO(miek): kill + Hosts HostPool + Policy Policy + + FailTimeout time.Duration + MaxFails int32 + HealthCheck struct { + Path string + Interval time.Duration + } + WithoutPathPrefix string + IgnoredSubPaths []string +} + +// NewStaticUpstreams parses the configuration input and sets up +// static upstreams for the proxy middleware. +func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { + var upstreams []Upstream + for c.Next() { + upstream := &staticUpstream{ + from: "", + proxyHeaders: make(http.Header), + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + + if !c.Args(&upstream.from) { + return upstreams, c.ArgErr() + } + to := c.RemainingArgs() + if len(to) == 0 { + return upstreams, c.ArgErr() + } + + for c.NextBlock() { + if err := parseBlock(&c, upstream); err != nil { + return upstreams, err + } + } + + upstream.Hosts = make([]*UpstreamHost, len(to)) + for i, host := range to { + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: upstream.FailTimeout, + Unhealthy: false, + ExtraHeaders: upstream.proxyHeaders, + CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= upstream.MaxFails && + upstream.MaxFails != 0 { + return true + } + return false + } + }(upstream), + WithoutPathPrefix: upstream.WithoutPathPrefix, + } + upstream.Hosts[i] = uh + } + + if upstream.HealthCheck.Path != "" { + go upstream.HealthCheckWorker(nil) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy +} + +func (u *staticUpstream) From() string { + return u.from +} + +func parseBlock(c *parse.Dispenser, u *staticUpstream) error { + switch c.Val() { + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + policyCreateFunc, ok := supportedPolicies[c.Val()] + if !ok { + return c.ArgErr() + } + u.Policy = policyCreateFunc() + case "fail_timeout": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.FailTimeout = dur + case "max_fails": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + u.MaxFails = int32(n) + case "health_check": + if !c.NextArg() { + return c.ArgErr() + } + u.HealthCheck.Path = c.Val() + u.HealthCheck.Interval = 30 * time.Second + if c.NextArg() { + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.HealthCheck.Interval = dur + } + case "proxy_header": + var header, value string + if !c.Args(&header, &value) { + return c.ArgErr() + } + u.proxyHeaders.Add(header, value) + case "websocket": + u.proxyHeaders.Add("Connection", "{>Connection}") + u.proxyHeaders.Add("Upgrade", "{>Upgrade}") + case "without": + if !c.NextArg() { + return c.ArgErr() + } + u.WithoutPathPrefix = c.Val() + case "except": + ignoredPaths := c.RemainingArgs() + if len(ignoredPaths) == 0 { + return c.ArgErr() + } + u.IgnoredSubPaths = ignoredPaths + default: + return c.Errf("unknown property '%s'", c.Val()) + } + return nil +} + +func (u *staticUpstream) healthCheck() { + for _, host := range u.Hosts { + hostURL := host.Name + u.HealthCheck.Path + if r, err := http.Get(hostURL); err == nil { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 + } else { + host.Unhealthy = true + } + } +} + +func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { + ticker := time.NewTicker(u.HealthCheck.Interval) + u.healthCheck() + for { + select { + case <-ticker.C: + u.healthCheck() + case <-stop: + // TODO: the library should provide a stop channel and global + // waitgroup to allow goroutines started by plugins a chance + // to clean themselves up. + } + } +} + +func (u *staticUpstream) Select() *UpstreamHost { + pool := u.Hosts + if len(pool) == 1 { + if pool[0].Down() { + return nil + } + return pool[0] + } + allDown := true + for _, host := range pool { + if !host.Down() { + allDown = false + break + } + } + if allDown { + return nil + } + + if u.Policy == nil { + return (&Random{}).Select(pool) + } + return u.Policy.Select(pool) +} + +func (u *staticUpstream) IsAllowedPath(requestPath string) bool { + for _, ignoredSubPath := range u.IgnoredSubPaths { + if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) { + return false + } + } + return true +} diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go new file mode 100644 index 000000000..5b2fdb1da --- /dev/null +++ b/middleware/proxy/upstream_test.go @@ -0,0 +1,83 @@ +package proxy + +import ( + "testing" + "time" +) + +func TestHealthCheck(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool(), + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.healthCheck() + if upstream.Hosts[0].Down() { + t.Error("Expected first host in testpool to not fail healthcheck.") + } + if !upstream.Hosts[1].Down() { + t.Error("Expected second host in testpool to fail healthcheck.") + } +} + +func TestSelect(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool()[:3], + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.Hosts[0].Unhealthy = true + upstream.Hosts[1].Unhealthy = true + upstream.Hosts[2].Unhealthy = true + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all host are down") + } + upstream.Hosts[2].Unhealthy = false + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } +} + +func TestRegisterPolicy(t *testing.T) { + name := "custom" + customPolicy := &customPolicy{} + RegisterPolicy(name, func() Policy { return customPolicy }) + if _, ok := supportedPolicies[name]; !ok { + t.Error("Expected supportedPolicies to have a custom policy.") + } + +} + +func TestAllowedPaths(t *testing.T) { + upstream := &staticUpstream{ + from: "/proxy", + IgnoredSubPaths: []string{"/download", "/static"}, + } + tests := []struct { + url string + expected bool + }{ + {"/proxy", true}, + {"/proxy/dl", true}, + {"/proxy/download", false}, + {"/proxy/download/static", false}, + {"/proxy/static", false}, + {"/proxy/static/download", false}, + {"/proxy/something/download", true}, + {"/proxy/something/static", true}, + {"/proxy//static", false}, + {"/proxy//static//download", false}, + {"/proxy//download", false}, + } + + for i, test := range tests { + isAllowed := upstream.IsAllowedPath(test.url) + if test.expected != isAllowed { + t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed) + } + } +} |