aboutsummaryrefslogtreecommitdiff
path: root/plugin/loadbalance/weighted_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/loadbalance/weighted_test.go')
-rw-r--r--plugin/loadbalance/weighted_test.go424
1 files changed, 424 insertions, 0 deletions
diff --git a/plugin/loadbalance/weighted_test.go b/plugin/loadbalance/weighted_test.go
new file mode 100644
index 000000000..e502c2772
--- /dev/null
+++ b/plugin/loadbalance/weighted_test.go
@@ -0,0 +1,424 @@
+package loadbalance
+
+import (
+ "context"
+ "errors"
+ "net"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ testutil "github.com/coredns/coredns/plugin/test"
+
+ "github.com/miekg/dns"
+)
+
+const oneDomainWRR = `
+w1,example.org
+192.168.1.15 10
+192.168.1.14 20
+`
+
+var testOneDomainWRR = map[string]weights{
+ "w1,example.org.": weights{
+ &weightItem{net.ParseIP("192.168.1.15"), uint8(10)},
+ &weightItem{net.ParseIP("192.168.1.14"), uint8(20)},
+ },
+}
+
+const twoDomainsWRR = `
+# domain 1
+w1.example.org
+192.168.1.15 10
+192.168.1.14 20
+
+# domain 2
+w2.example.org
+ # domain 3
+ w3.example.org
+ 192.168.2.16 11
+ 192.168.2.15 12
+ 192.168.2.14 13
+`
+
+var testTwoDomainsWRR = map[string]weights{
+ "w1.example.org.": weights{
+ &weightItem{net.ParseIP("192.168.1.15"), uint8(10)},
+ &weightItem{net.ParseIP("192.168.1.14"), uint8(20)},
+ },
+ "w2.example.org.": weights{},
+ "w3.example.org.": weights{
+ &weightItem{net.ParseIP("192.168.2.16"), uint8(11)},
+ &weightItem{net.ParseIP("192.168.2.15"), uint8(12)},
+ &weightItem{net.ParseIP("192.168.2.14"), uint8(13)},
+ },
+}
+
+const missingWeightWRR = `
+w1,example.org
+192.168.1.14
+192.168.1.15 20
+`
+
+const missingDomainWRR = `
+# missing domain
+192.168.1.14 10
+w2,example.org
+192.168.2.14 11
+192.168.2.15 12
+`
+
+const wrongIpWRR = `
+w1,example.org
+192.168.1.300 10
+`
+
+const wrongWeightWRR = `
+w1,example.org
+192.168.1.14 300
+`
+
+func TestWeightFileUpdate(t *testing.T) {
+ tests := []struct {
+ weightFilContent string
+ shouldErr bool
+ expectedDomains map[string]weights
+ expectedErrContent string // substring from the expected error. Empty for positive cases.
+ }{
+ // positive
+ {"", false, nil, ""},
+ {oneDomainWRR, false, testOneDomainWRR, ""},
+ {twoDomainsWRR, false, testTwoDomainsWRR, ""},
+ // negative
+ {missingWeightWRR, true, nil, "Wrong domain name"},
+ {missingDomainWRR, true, nil, "Missing domain name"},
+ {wrongIpWRR, true, nil, "Wrong IP address"},
+ {wrongWeightWRR, true, nil, "Wrong weight value"},
+ }
+
+ for i, test := range tests {
+ testFile, rm, err := testutil.TempFile(".", test.weightFilContent)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rm()
+ weighted := &weightedRR{fileName: testFile}
+ err = weighted.updateWeights()
+ if test.shouldErr && err == nil {
+ t.Errorf("Test %d: Expected error but found %s", i, err)
+ }
+ if err != nil {
+ if !test.shouldErr {
+ t.Errorf("Test %d: Expected no error but found error: %v", i, err)
+ }
+
+ if !strings.Contains(err.Error(), test.expectedErrContent) {
+ t.Errorf("Test %d: Expected error to contain: %v, found error: %v",
+ i, test.expectedErrContent, err)
+ }
+ }
+ if test.expectedDomains != nil {
+ if len(test.expectedDomains) != len(weighted.domains) {
+ t.Errorf("Test %d: Expected len(domains): %d but got %d",
+ i, len(test.expectedDomains), len(weighted.domains))
+ } else {
+ _ = checkDomainsWRR(t, i, test.expectedDomains, weighted.domains)
+ }
+ }
+ }
+}
+
+func checkDomainsWRR(t *testing.T, testIndex int, expectedDomains, domains map[string]weights) error {
+ var ret error
+ retError := errors.New("Check domains failed")
+ for dname, expectedWeights := range expectedDomains {
+ ws, ok := domains[dname]
+ if !ok {
+ t.Errorf("Test %d: Expected domain %s but not found it", testIndex, dname)
+ ret = retError
+ } else {
+ if len(expectedWeights) != len(ws) {
+ t.Errorf("Test %d: Expected len(weights): %d for domain %s but got %d",
+ testIndex, len(expectedWeights), dname, len(ws))
+ ret = retError
+ } else {
+ for i, w := range expectedWeights {
+ if !w.address.Equal(ws[i].address) || w.value != ws[i].value {
+ t.Errorf("Test %d: Weight list differs at index %d for domain %s. "+
+ "Expected: %v got: %v", testIndex, i, dname, expectedWeights[i], ws[i])
+ ret = retError
+ }
+ }
+ }
+ }
+ }
+
+ return ret
+}
+
+func TestPeriodicWeightUpdate(t *testing.T) {
+ testFile1, rm, err := testutil.TempFile(".", oneDomainWRR)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rm()
+ testFile2, rm, err := testutil.TempFile(".", twoDomainsWRR)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rm()
+
+ // configure weightedRR with "oneDomainWRR" weight file content
+ weighted := &weightedRR{fileName: testFile1}
+
+ err = weighted.updateWeights()
+ if err != nil {
+ t.Fatal(err)
+ } else {
+ err = checkDomainsWRR(t, 0, testOneDomainWRR, weighted.domains)
+ if err != nil {
+ t.Fatalf("Initial check domains failed")
+ }
+ }
+
+ // change weight file
+ weighted.fileName = testFile2
+ // start periodic update
+ weighted.reload = 10 * time.Millisecond
+ stopChan := make(chan bool)
+ weighted.periodicWeightUpdate(stopChan)
+ time.Sleep(20 * time.Millisecond)
+ // stop periodic update
+ close(stopChan)
+ // check updated config
+ weighted.mutex.Lock()
+ err = checkDomainsWRR(t, 0, testTwoDomainsWRR, weighted.domains)
+ weighted.mutex.Unlock()
+ if err != nil {
+ t.Fatalf("Final check domains failed")
+ }
+}
+
+// Fake random number generator for testing
+type fakeRandomGen struct {
+ expectedLimit uint
+ testIndex int
+ queryIndex int
+ randv uint
+ t *testing.T
+}
+
+func (r *fakeRandomGen) randInit() {
+}
+
+func (r *fakeRandomGen) randUint(limit uint) uint {
+ if limit != r.expectedLimit {
+ r.t.Errorf("Test %d query %d: Expected weights sum %d but got %d",
+ r.testIndex, r.queryIndex, r.expectedLimit, limit)
+ }
+ return r.randv
+}
+
+func TestLoadBalanceWRR(t *testing.T) {
+ type testQuery struct {
+ randv uint // fake random value for selecting the top IP
+ topIP string // top (first) address record in the answer
+ }
+
+ // domain maps to test
+ oneDomain := map[string]weights{
+ "endpoint.region2.skydns.test.": weights{
+ &weightItem{net.ParseIP("10.240.0.2"), uint8(3)},
+ &weightItem{net.ParseIP("10.240.0.1"), uint8(2)},
+ },
+ }
+ twoDomains := map[string]weights{
+ "endpoint.region2.skydns.test.": weights{
+ &weightItem{net.ParseIP("10.240.0.2"), uint8(5)},
+ &weightItem{net.ParseIP("10.240.0.1"), uint8(2)},
+ },
+ "endpoint.region1.skydns.test.": weights{
+ &weightItem{net.ParseIP("::2"), uint8(4)},
+ &weightItem{net.ParseIP("::1"), uint8(3)},
+ },
+ }
+
+ // the first X records must be cnames after this test
+ tests := []struct {
+ answer []dns.RR
+ extra []dns.RR
+ cnameAnswer int
+ cnameExtra int
+ addressAnswer int
+ addressExtra int
+ mxAnswer int
+ mxExtra int
+ domains map[string]weights
+ sumWeights uint // sum of weights in the answer
+ queries []testQuery
+ }{
+ {
+ answer: []dns.RR{
+ testutil.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
+ testutil.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ testutil.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
+ testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
+ testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"),
+ testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"),
+ testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ testutil.MX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."),
+ testutil.MX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."),
+ },
+ extra: []dns.RR{
+ testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
+ testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"),
+ testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"),
+ testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ },
+ cnameAnswer: 4,
+ cnameExtra: 1,
+ addressAnswer: 5,
+ addressExtra: 5,
+ mxAnswer: 3,
+ mxExtra: 1,
+ domains: twoDomains,
+ sumWeights: 15,
+ queries: []testQuery{
+ {0, "10.240.0.2"}, // domain 1 weight 5
+ {4, "10.240.0.2"}, // domain 1 weight 5
+ {5, "::2"}, // domain 2 weight 4
+ {8, "::2"}, // domain 2 weight 4
+ {9, "::1"}, // domain 2 weight 3
+ {11, "::1"}, // domain 2 weight 3
+ {12, "10.240.0.1"}, // domain 1 weight 2
+ {13, "10.240.0.1"}, // domain 1 weight 2
+ {14, "10.240.0.3"}, // domain 1 no weight -> default weight
+ },
+ },
+ {
+ answer: []dns.RR{
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
+ testutil.A("endpoint.region1.skydns.test. 300 IN A 10.240.0.3"),
+ },
+ cnameAnswer: 1,
+ addressAnswer: 3,
+ mxAnswer: 1,
+ domains: oneDomain,
+ sumWeights: 6,
+ queries: []testQuery{
+ {0, "10.240.0.2"}, // weight 3
+ {2, "10.240.0.2"}, // weight 3
+ {3, "10.240.0.1"}, // weight 2
+ {4, "10.240.0.1"}, // weight 2
+ {5, "10.240.0.3"}, // no domain -> default weight
+ },
+ },
+ {
+ answer: []dns.RR{
+ testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
+ testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ },
+ cnameAnswer: 1,
+ mxAnswer: 1,
+ domains: oneDomain,
+ queries: []testQuery{
+ {0, ""}, // no address records -> answer unaltered
+ },
+ },
+ }
+
+ testRand := &fakeRandomGen{t: t}
+ weighted := &weightedRR{randomGen: testRand}
+ shuffle := func(res *dns.Msg) *dns.Msg {
+ return weightedShuffle(res, weighted)
+ }
+ rm := LoadBalance{Next: handler(), shuffle: shuffle}
+
+ rec := dnstest.NewRecorder(&testutil.ResponseWriter{})
+
+ for i, test := range tests {
+ // set domain map for weighted round robin
+ weighted.domains = test.domains
+ testRand.testIndex = i
+ testRand.expectedLimit = test.sumWeights
+
+ for j, query := range test.queries {
+ req := new(dns.Msg)
+ req.SetQuestion("endpoint.region2.skydns.test", dns.TypeSRV)
+ req.Answer = test.answer
+ req.Extra = test.extra
+
+ // Set fake random number
+ testRand.randv = query.randv
+ testRand.queryIndex = j
+
+ _, err := rm.ServeDNS(context.TODO(), rec, req)
+ if err != nil {
+ t.Errorf("Test %d: Expected no error, but got %s", i, err)
+ continue
+ }
+
+ checkTopIP(t, i, j, rec.Msg.Answer, query.topIP)
+ checkTopIP(t, i, j, rec.Msg.Extra, query.topIP)
+
+ cname, address, mx, sorted := countRecords(rec.Msg.Answer)
+ if query.topIP != "" && !sorted {
+ t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j)
+ }
+ if cname != test.cnameAnswer {
+ t.Errorf("Test %d query %d: Expected %d CNAMEs in Answer, but got %d", i, j, test.cnameAnswer, cname)
+ }
+ if address != test.addressAnswer {
+ t.Errorf("Test %d query %d: Expected %d A/AAAAs in Answer, but got %d", i, j, test.addressAnswer, address)
+ }
+ if mx != test.mxAnswer {
+ t.Errorf("Test %d query %d: Expected %d MXs in Answer, but got %d", i, j, test.mxAnswer, mx)
+ }
+
+ cname, address, mx, sorted = countRecords(rec.Msg.Extra)
+ if query.topIP != "" && !sorted {
+ t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j)
+ }
+
+ if cname != test.cnameExtra {
+ t.Errorf("Test %d query %d: Expected %d CNAMEs in Extra, but got %d", i, j, test.cnameAnswer, cname)
+ }
+ if address != test.addressExtra {
+ t.Errorf("Test %d query %d: Expected %d A/AAAAs in Extra, but got %d", i, j, test.addressAnswer, address)
+ }
+ if mx != test.mxExtra {
+ t.Errorf("Test %d query %d: Expected %d MXs in Extra, but got %d", i, j, test.mxAnswer, mx)
+ }
+ }
+ }
+}
+
+func checkTopIP(t *testing.T, i, j int, result []dns.RR, expectedTopIP string) {
+ expected := net.ParseIP(expectedTopIP)
+ for _, r := range result {
+ switch r.Header().Rrtype {
+ case dns.TypeA:
+ ar := r.(*dns.A)
+ if !ar.A.Equal(expected) {
+ t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.A)
+ }
+ return
+ case dns.TypeAAAA:
+ ar := r.(*dns.AAAA)
+ if !ar.AAAA.Equal(expected) {
+ t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.AAAA)
+ }
+ return
+ }
+ }
+}