diff options
Diffstat (limited to 'plugin/loadbalance')
-rw-r--r-- | plugin/loadbalance/README.md | 22 | ||||
-rw-r--r-- | plugin/loadbalance/handler.go | 23 | ||||
-rw-r--r-- | plugin/loadbalance/loadbalance.go | 87 | ||||
-rw-r--r-- | plugin/loadbalance/loadbalance_test.go | 168 | ||||
-rw-r--r-- | plugin/loadbalance/setup.go | 26 |
5 files changed, 326 insertions, 0 deletions
diff --git a/plugin/loadbalance/README.md b/plugin/loadbalance/README.md new file mode 100644 index 000000000..1cce54ebf --- /dev/null +++ b/plugin/loadbalance/README.md @@ -0,0 +1,22 @@ +# loadbalance + +*loadbalance* acts as a round-robin DNS loadbalancer by randomizing the order of A and AAAA records + in the answer. + + See [Wikipedia](https://en.wikipedia.org/wiki/Round-robin_DNS) about the pros and cons on this + setup. It will take care to sort any CNAMEs before any address records, because some stub resolver + implementations (like glibc) are particular about that. + +## Syntax + +~~~ +loadbalance [POLICY] +~~~ + +* **POLICY** is how to balance, the default is "round_robin" + +## Examples + +~~~ +loadbalance round_robin +~~~ diff --git a/plugin/loadbalance/handler.go b/plugin/loadbalance/handler.go new file mode 100644 index 000000000..da4cf1549 --- /dev/null +++ b/plugin/loadbalance/handler.go @@ -0,0 +1,23 @@ +// Package loadbalance is plugin for rewriting responses to do "load balancing" +package loadbalance + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +// RoundRobin is plugin to rewrite responses for "load balancing". +type RoundRobin struct { + Next plugin.Handler +} + +// ServeDNS implements the plugin.Handler interface. +func (rr RoundRobin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + wrr := &RoundRobinResponseWriter{w} + return plugin.NextOrFailure(rr.Name(), rr.Next, ctx, wrr, r) +} + +// Name implements the Handler interface. +func (rr RoundRobin) Name() string { return "loadbalance" } diff --git a/plugin/loadbalance/loadbalance.go b/plugin/loadbalance/loadbalance.go new file mode 100644 index 000000000..7df0b31c6 --- /dev/null +++ b/plugin/loadbalance/loadbalance.go @@ -0,0 +1,87 @@ +// Package loadbalance shuffles A and AAAA records. +package loadbalance + +import ( + "log" + + "github.com/miekg/dns" +) + +// RoundRobinResponseWriter is a response writer that shuffles A and AAAA records. +type RoundRobinResponseWriter struct { + dns.ResponseWriter +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { + if res.Rcode != dns.RcodeSuccess { + return r.ResponseWriter.WriteMsg(res) + } + + res.Answer = roundRobin(res.Answer) + res.Ns = roundRobin(res.Ns) + res.Extra = roundRobin(res.Extra) + + return r.ResponseWriter.WriteMsg(res) +} + +func roundRobin(in []dns.RR) []dns.RR { + cname := []dns.RR{} + address := []dns.RR{} + mx := []dns.RR{} + rest := []dns.RR{} + for _, r := range in { + switch r.Header().Rrtype { + case dns.TypeCNAME: + cname = append(cname, r) + case dns.TypeA, dns.TypeAAAA: + address = append(address, r) + case dns.TypeMX: + mx = append(mx, r) + default: + rest = append(rest, r) + } + } + + roundRobinShuffle(address) + roundRobinShuffle(mx) + + out := append(cname, rest...) + out = append(out, address...) + out = append(out, mx...) + return out +} + +func roundRobinShuffle(records []dns.RR) { + switch l := len(records); l { + case 0, 1: + break + case 2: + if dns.Id()%2 == 0 { + records[0], records[1] = records[1], records[0] + } + default: + for j := 0; j < l*(int(dns.Id())%4+1); j++ { + q := int(dns.Id()) % l + p := int(dns.Id()) % l + if q == p { + p = (p + 1) % l + } + records[q], records[p] = records[p], records[q] + } + } +} + +// Write implements the dns.ResponseWriter interface. +func (r *RoundRobinResponseWriter) Write(buf []byte) (int, error) { + // Should we pack and unpack here to fiddle with the packet... Not likely. + log.Printf("[WARNING] RoundRobin called with Write: no shuffling records") + n, err := r.ResponseWriter.Write(buf) + return n, err +} + +// Hijack implements the dns.ResponseWriter interface. +func (r *RoundRobinResponseWriter) Hijack() { + r.ResponseWriter.Hijack() + return +} diff --git a/plugin/loadbalance/loadbalance_test.go b/plugin/loadbalance/loadbalance_test.go new file mode 100644 index 000000000..bde92b543 --- /dev/null +++ b/plugin/loadbalance/loadbalance_test.go @@ -0,0 +1,168 @@ +package loadbalance + +import ( + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnsrecorder" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +func TestLoadBalance(t *testing.T) { + rm := RoundRobin{Next: handler()} + + // the first X records must be cnames after this test + tests := []struct { + answer []dns.RR + extra []dns.RR + cnameAnswer int + cnameExtra int + addressAnswer int + addressExtra int + mxAnswer int + mxExtra int + }{ + { + answer: []dns.RR{ + test.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + test.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.MX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."), + test.MX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."), + }, + cnameAnswer: 4, + addressAnswer: 1, + mxAnswer: 3, + }, + { + answer: []dns.RR{ + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + }, + cnameAnswer: 1, + addressAnswer: 1, + mxAnswer: 1, + }, + { + answer: []dns.RR{ + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."), + }, + extra: []dns.RR{ + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.AAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + test.AAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."), + }, + cnameAnswer: 1, + cnameExtra: 1, + addressAnswer: 3, + addressExtra: 4, + mxAnswer: 3, + mxExtra: 3, + }, + } + + rec := dnsrecorder.New(&test.ResponseWriter{}) + + for i, test := range tests { + req := new(dns.Msg) + req.SetQuestion("region2.skydns.test.", dns.TypeSRV) + req.Answer = test.answer + req.Extra = test.extra + + _, err := rm.ServeDNS(context.TODO(), rec, req) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + + } + + cname, address, mx, sorted := countRecords(rec.Msg.Answer) + if !sorted { + t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i) + } + if cname != test.cnameAnswer { + t.Errorf("Test %d: Expected %d CNAMEs in Answer, but got %d", i, test.cnameAnswer, cname) + } + if address != test.addressAnswer { + t.Errorf("Test %d: Expected %d A/AAAAs in Answer, but got %d", i, test.addressAnswer, address) + } + if mx != test.mxAnswer { + t.Errorf("Test %d: Expected %d MXs in Answer, but got %d", i, test.mxAnswer, mx) + } + + cname, address, mx, sorted = countRecords(rec.Msg.Extra) + if !sorted { + t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Extra, but got mixed", i) + } + if cname != test.cnameExtra { + t.Errorf("Test %d: Expected %d CNAMEs in Extra, but got %d", i, test.cnameAnswer, cname) + } + if address != test.addressExtra { + t.Errorf("Test %d: Expected %d A/AAAAs in Extra, but got %d", i, test.addressAnswer, address) + } + if mx != test.mxExtra { + t.Errorf("Test %d: Expected %d MXs in Extra, but got %d", i, test.mxAnswer, mx) + } + } +} + +func countRecords(result []dns.RR) (cname int, address int, mx int, sorted bool) { + const ( + Start = iota + CNAMERecords + ARecords + MXRecords + Any + ) + + // The order of the records is used to determine if the round-robin actually did anything. + sorted = true + cname = 0 + address = 0 + mx = 0 + state := Start + for _, r := range result { + switch r.Header().Rrtype { + case dns.TypeCNAME: + sorted = sorted && state <= CNAMERecords + state = CNAMERecords + cname++ + case dns.TypeA, dns.TypeAAAA: + sorted = sorted && state <= ARecords + state = ARecords + address++ + case dns.TypeMX: + sorted = sorted && state <= MXRecords + state = MXRecords + mx++ + default: + state = Any + } + } + return +} + +func handler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + w.WriteMsg(r) + return dns.RcodeSuccess, nil + }) +} diff --git a/plugin/loadbalance/setup.go b/plugin/loadbalance/setup.go new file mode 100644 index 000000000..c2d90958e --- /dev/null +++ b/plugin/loadbalance/setup.go @@ -0,0 +1,26 @@ +package loadbalance + +import ( + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("loadbalance", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + for c.Next() { + // TODO(miek): block and option parsing + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return RoundRobin{Next: next} + }) + + return nil +} |