diff options
-rw-r--r-- | middleware/etcd/cname_test.go | 96 | ||||
-rw-r--r-- | middleware/loadbalance/loadbalance.go | 23 | ||||
-rw-r--r-- | middleware/loadbalance/loadbalance.md | 4 | ||||
-rw-r--r-- | middleware/loadbalance/loadbalance_test.go | 104 |
4 files changed, 216 insertions, 11 deletions
diff --git a/middleware/etcd/cname_test.go b/middleware/etcd/cname_test.go new file mode 100644 index 000000000..7d53bfef6 --- /dev/null +++ b/middleware/etcd/cname_test.go @@ -0,0 +1,96 @@ +// +build etcd + +package etcd + +// etcd needs to be running on http://127.0.0.1:2379 + +import ( + "testing" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/etcd/msg" + + "github.com/miekg/dns" +) + +// Check the ordering of returned cname. +func TestCnameLookup(t *testing.T) { + for _, serv := range servicesCname { + set(t, etc, serv.Key, 0, serv) + defer delete(t, etc, serv.Key) + } + for _, tc := range dnsTestCasesCname { + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(tc.Qname), tc.Qtype) + + rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) + _, err := etc.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("expected no error, got %v\n", err) + return + } + resp := rec.Msg() + + if resp.Rcode != tc.Rcode { + t.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode]) + t.Logf("%v\n", resp) + continue + } + + if len(resp.Answer) != len(tc.Answer) { + t.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer)) + t.Logf("%v\n", resp) + continue + } + if len(resp.Ns) != len(tc.Ns) { + t.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns)) + t.Logf("%v\n", resp) + continue + } + if len(resp.Extra) != len(tc.Extra) { + t.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra)) + t.Logf("%v\n", resp) + continue + } + + if !checkSection(t, tc, Answer, resp.Answer) { + t.Logf("%v\n", resp) + } + if !checkSection(t, tc, Ns, resp.Ns) { + t.Logf("%v\n", resp) + + } + if !checkSection(t, tc, Extra, resp.Extra) { + t.Logf("%v\n", resp) + } + } +} + +var servicesCname = []*msg.Service{ + {Host: "cname1.region2.skydns.test", Key: "a.server1.dev.region1.skydns.test."}, + {Host: "cname2.region2.skydns.test", Key: "cname1.region2.skydns.test."}, + {Host: "cname3.region2.skydns.test", Key: "cname2.region2.skydns.test."}, + {Host: "cname4.region2.skydns.test", Key: "cname3.region2.skydns.test."}, + {Host: "cname5.region2.skydns.test", Key: "cname4.region2.skydns.test."}, + {Host: "cname6.region2.skydns.test", Key: "cname5.region2.skydns.test."}, + {Host: "endpoint.region2.skydns.test", Key: "cname6.region2.skydns.test."}, + {Host: "10.240.0.1", Key: "endpoint.region2.skydns.test."}, +} + +var dnsTestCasesCname = []dnsTestCase{ + { + Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + newSRV("a.server1.dev.region1.skydns.test. 300 IN SRV 10 100 0 cname1.region2.skydns.test."), + }, + Extra: []dns.RR{ + newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + newCNAME("cname3.region2.skydns.test. 300 IN CNAME cname4.region2.skydns.test."), + newCNAME("cname4.region2.skydns.test. 300 IN CNAME cname5.region2.skydns.test."), + newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + }, + }, +} diff --git a/middleware/loadbalance/loadbalance.go b/middleware/loadbalance/loadbalance.go index c81ad0c8a..e1bee25fd 100644 --- a/middleware/loadbalance/loadbalance.go +++ b/middleware/loadbalance/loadbalance.go @@ -14,18 +14,21 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { if res.Rcode != dns.RcodeSuccess { return r.ResponseWriter.WriteMsg(res) } - if len(res.Answer) < 2 { // don't even bother - return r.ResponseWriter.WriteMsg(res) - } - // put CNAMEs first, randomize a/aaaa's and put packet back together. - // TODO(miek): check family and give v6 more prio? + res.Answer = roundRobin(res.Answer) + res.Extra = roundRobin(res.Extra) + + return r.ResponseWriter.WriteMsg(res) +} + +func roundRobin(in []dns.RR) []dns.RR { cname := []dns.RR{} address := []dns.RR{} rest := []dns.RR{} - for _, r := range res.Answer { + for _, r := range in { switch r.Header().Rrtype { case dns.TypeCNAME: + // d d d d DNAME and friends here as well? cname = append(cname, r) case dns.TypeA, dns.TypeAAAA: address = append(address, r) @@ -36,7 +39,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { switch l := len(address); l { case 0, 1: - return r.ResponseWriter.WriteMsg(res) + break case 2: if dns.Id()%2 == 0 { address[0], address[1] = address[1], address[0] @@ -51,9 +54,9 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { address[q], address[p] = address[p], address[q] } } - res.Answer = append(cname, rest...) - res.Answer = append(res.Answer, address...) - return r.ResponseWriter.WriteMsg(res) + out := append(cname, rest...) + out = append(out, address...) + return out } // Should we pack and unpack here to fiddle with the packet... Not likely. diff --git a/middleware/loadbalance/loadbalance.md b/middleware/loadbalance/loadbalance.md index 0e931fb53..5c381135d 100644 --- a/middleware/loadbalance/loadbalance.md +++ b/middleware/loadbalance/loadbalance.md @@ -4,13 +4,15 @@ message. See [Wikipedia](https://en.wikipedia.org/wiki/Round-robin_DNS) about the pros and cons on this setup. +It will take care to sort any CNAMEs before any address records. + ## Syntax ~~~ loadbalance [policy] ~~~ -* policy is how to balance, the default is "round_robin" +* `policy` is how to balance, the default is "round_robin" ## Examples diff --git a/middleware/loadbalance/loadbalance_test.go b/middleware/loadbalance/loadbalance_test.go new file mode 100644 index 000000000..dc027607c --- /dev/null +++ b/middleware/loadbalance/loadbalance_test.go @@ -0,0 +1,104 @@ +package loadbalance + +import ( + "testing" + + "github.com/miekg/coredns/middleware" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +func TestLoadBalance(t *testing.T) { + rm := RoundRobin{Next: handler()} + + // the first X records must be cnames after this test + tests := []struct { + answer []dns.RR + extra []dns.RR + cnameAnswer int + cnameExtra int + }{ + { + answer: []dns.RR{ + newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + }, + cnameAnswer: 4, + }, + { + answer: []dns.RR{ + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + }, + cnameAnswer: 1, + }, + { + answer: []dns.RR{ + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + }, + extra: []dns.RR{ + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"), + newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"), + }, + cnameAnswer: 1, + cnameExtra: 1, + }, + } + + rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) + + for i, test := range tests { + req := new(dns.Msg) + req.SetQuestion("region2.skydns.test.", dns.TypeSRV) + req.Answer = test.answer + req.Extra = test.extra + + _, err := rm.ServeDNS(context.TODO(), rec, req) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + + } + cname := 0 + for _, r := range rec.Msg().Answer { + if r.Header().Rrtype != dns.TypeCNAME { + break + } + cname++ + } + if cname != test.cnameAnswer { + t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname) + } + cname = 0 + for _, r := range rec.Msg().Extra { + if r.Header().Rrtype != dns.TypeCNAME { + break + } + cname++ + } + if cname != test.cnameExtra { + t.Errorf("Test %d: Expected %d cname in Extra, but got %d", i, test.cnameExtra, cname) + } + } +} + +func handler() middleware.Handler { + return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + w.WriteMsg(r) + return dns.RcodeSuccess, nil + }) +} + +func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } +func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) } +func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) } |