aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jonathan Dickinson <jonathan@dickinsons.co.za> 2016-10-27 09:09:16 +0200
committerGravatar Miek Gieben <miek@miek.nl> 2016-10-27 08:09:16 +0100
commit7ce71001226ecdbac32db3f405afc7ef35f9dbfe (patch)
treeccf3af2c6f5d59987d424000652069ec4991089d
parent219bfd0493124fc2f0170772833a094c3eb9b627 (diff)
downloadcoredns-7ce71001226ecdbac32db3f405afc7ef35f9dbfe.tar.gz
coredns-7ce71001226ecdbac32db3f405afc7ef35f9dbfe.tar.zst
coredns-7ce71001226ecdbac32db3f405afc7ef35f9dbfe.zip
- Adding tests for MX round-robin (#358)
- Implementing MX round-robin - Slight tidy
-rw-r--r--middleware/loadbalance/loadbalance.go22
-rw-r--r--middleware/loadbalance/loadbalance_test.go143
2 files changed, 121 insertions, 44 deletions
diff --git a/middleware/loadbalance/loadbalance.go b/middleware/loadbalance/loadbalance.go
index 59aad8a4f..7df0b31c6 100644
--- a/middleware/loadbalance/loadbalance.go
+++ b/middleware/loadbalance/loadbalance.go
@@ -28,6 +28,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
func roundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{}
address := []dns.RR{}
+ mx := []dns.RR{}
rest := []dns.RR{}
for _, r := range in {
switch r.Header().Rrtype {
@@ -35,17 +36,29 @@ func roundRobin(in []dns.RR) []dns.RR {
cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA:
address = append(address, r)
+ case dns.TypeMX:
+ mx = append(mx, r)
default:
rest = append(rest, r)
}
}
- switch l := len(address); l {
+ roundRobinShuffle(address)
+ roundRobinShuffle(mx)
+
+ out := append(cname, rest...)
+ out = append(out, address...)
+ out = append(out, mx...)
+ return out
+}
+
+func roundRobinShuffle(records []dns.RR) {
+ switch l := len(records); l {
case 0, 1:
break
case 2:
if dns.Id()%2 == 0 {
- address[0], address[1] = address[1], address[0]
+ records[0], records[1] = records[1], records[0]
}
default:
for j := 0; j < l*(int(dns.Id())%4+1); j++ {
@@ -54,12 +67,9 @@ func roundRobin(in []dns.RR) []dns.RR {
if q == p {
p = (p + 1) % l
}
- address[q], address[p] = address[p], address[q]
+ records[q], records[p] = records[p], records[q]
}
}
- out := append(cname, rest...)
- out = append(out, address...)
- return out
}
// Write implements the dns.ResponseWriter interface.
diff --git a/middleware/loadbalance/loadbalance_test.go b/middleware/loadbalance/loadbalance_test.go
index 5e240be13..2a5096004 100644
--- a/middleware/loadbalance/loadbalance_test.go
+++ b/middleware/loadbalance/loadbalance_test.go
@@ -16,44 +16,66 @@ func TestLoadBalance(t *testing.T) {
// the first X records must be cnames after this test
tests := []struct {
- answer []dns.RR
- extra []dns.RR
- cnameAnswer int
- cnameExtra int
+ answer []dns.RR
+ extra []dns.RR
+ cnameAnswer int
+ cnameExtra int
+ addressAnswer int
+ addressExtra int
+ mxAnswer int
+ mxExtra int
}{
{
answer: []dns.RR{
- newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
- newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
- newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
- newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
- newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
+ newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
+ newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ newMX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."),
+ newMX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."),
},
- cnameAnswer: 4,
+ cnameAnswer: 4,
+ addressAnswer: 1,
+ mxAnswer: 3,
},
{
answer: []dns.RR{
- newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
- newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
},
- cnameAnswer: 1,
+ cnameAnswer: 1,
+ addressAnswer: 1,
+ mxAnswer: 1,
},
{
answer: []dns.RR{
- newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
- newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
- newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
- newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."),
+ newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."),
},
extra: []dns.RR{
- newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
- newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"),
- newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
- newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
- newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
+ newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"),
+ newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."),
},
- cnameAnswer: 1,
- cnameExtra: 1,
+ cnameAnswer: 1,
+ cnameExtra: 1,
+ addressAnswer: 3,
+ addressExtra: 4,
+ mxAnswer: 3,
+ mxExtra: 3,
},
}
@@ -71,27 +93,71 @@ func TestLoadBalance(t *testing.T) {
continue
}
- cname := 0
- for _, r := range rec.Msg.Answer {
- if r.Header().Rrtype != dns.TypeCNAME {
- break
- }
- cname++
+
+ cname, address, mx, sorted := countRecords(rec.Msg.Answer)
+ if !sorted {
+ t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i)
}
if cname != test.cnameAnswer {
- t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname)
+ t.Errorf("Test %d: Expected %d CNAMEs in Answer, but got %d", i, test.cnameAnswer, cname)
}
- cname = 0
- for _, r := range rec.Msg.Extra {
- if r.Header().Rrtype != dns.TypeCNAME {
- break
- }
- cname++
+ if address != test.addressAnswer {
+ t.Errorf("Test %d: Expected %d A/AAAAs in Answer, but got %d", i, test.addressAnswer, address)
+ }
+ if mx != test.mxAnswer {
+ t.Errorf("Test %d: Expected %d MXs in Answer, but got %d", i, test.mxAnswer, mx)
+ }
+
+ cname, address, mx, sorted = countRecords(rec.Msg.Extra)
+ if !sorted {
+ t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Extra, but got mixed", i)
}
if cname != test.cnameExtra {
- t.Errorf("Test %d: Expected %d cname in Extra, but got %d", i, test.cnameExtra, cname)
+ t.Errorf("Test %d: Expected %d CNAMEs in Extra, but got %d", i, test.cnameAnswer, cname)
+ }
+ if address != test.addressExtra {
+ t.Errorf("Test %d: Expected %d A/AAAAs in Extra, but got %d", i, test.addressAnswer, address)
+ }
+ if mx != test.mxExtra {
+ t.Errorf("Test %d: Expected %d MXs in Extra, but got %d", i, test.mxAnswer, mx)
+ }
+ }
+}
+
+func countRecords(result []dns.RR) (cname int, address int, mx int, sorted bool) {
+ const (
+ Start = iota
+ CNAMERecords
+ ARecords
+ MXRecords
+ Any
+ )
+
+ // The order of the records is used to determine if the round-robin actually did anything.
+ sorted = true
+ cname = 0
+ address = 0
+ mx = 0
+ state := Start
+ for _, r := range result {
+ switch r.Header().Rrtype {
+ case dns.TypeCNAME:
+ sorted = sorted && state <= CNAMERecords
+ state = CNAMERecords
+ cname++
+ case dns.TypeA, dns.TypeAAAA:
+ sorted = sorted && state <= ARecords
+ state = ARecords
+ address++
+ case dns.TypeMX:
+ sorted = sorted && state <= MXRecords
+ state = MXRecords
+ mx++
+ default:
+ state = Any
}
}
+ return
}
func handler() middleware.Handler {
@@ -104,3 +170,4 @@ func handler() middleware.Handler {
func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
+func newMX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) }