diff options
Diffstat (limited to 'plugin/loadbalance/weighted_test.go')
-rw-r--r-- | plugin/loadbalance/weighted_test.go | 424 |
1 files changed, 424 insertions, 0 deletions
diff --git a/plugin/loadbalance/weighted_test.go b/plugin/loadbalance/weighted_test.go new file mode 100644 index 000000000..e502c2772 --- /dev/null +++ b/plugin/loadbalance/weighted_test.go @@ -0,0 +1,424 @@ +package loadbalance + +import ( + "context" + "errors" + "net" + "strings" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + testutil "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +const oneDomainWRR = ` +w1,example.org +192.168.1.15 10 +192.168.1.14 20 +` + +var testOneDomainWRR = map[string]weights{ + "w1,example.org.": weights{ + &weightItem{net.ParseIP("192.168.1.15"), uint8(10)}, + &weightItem{net.ParseIP("192.168.1.14"), uint8(20)}, + }, +} + +const twoDomainsWRR = ` +# domain 1 +w1.example.org +192.168.1.15 10 +192.168.1.14 20 + +# domain 2 +w2.example.org + # domain 3 + w3.example.org + 192.168.2.16 11 + 192.168.2.15 12 + 192.168.2.14 13 +` + +var testTwoDomainsWRR = map[string]weights{ + "w1.example.org.": weights{ + &weightItem{net.ParseIP("192.168.1.15"), uint8(10)}, + &weightItem{net.ParseIP("192.168.1.14"), uint8(20)}, + }, + "w2.example.org.": weights{}, + "w3.example.org.": weights{ + &weightItem{net.ParseIP("192.168.2.16"), uint8(11)}, + &weightItem{net.ParseIP("192.168.2.15"), uint8(12)}, + &weightItem{net.ParseIP("192.168.2.14"), uint8(13)}, + }, +} + +const missingWeightWRR = ` +w1,example.org +192.168.1.14 +192.168.1.15 20 +` + +const missingDomainWRR = ` +# missing domain +192.168.1.14 10 +w2,example.org +192.168.2.14 11 +192.168.2.15 12 +` + +const wrongIpWRR = ` +w1,example.org +192.168.1.300 10 +` + +const wrongWeightWRR = ` +w1,example.org +192.168.1.14 300 +` + +func TestWeightFileUpdate(t *testing.T) { + tests := []struct { + weightFilContent string + shouldErr bool + expectedDomains map[string]weights + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + {"", false, nil, ""}, + {oneDomainWRR, false, testOneDomainWRR, ""}, + {twoDomainsWRR, false, testTwoDomainsWRR, ""}, + // negative + {missingWeightWRR, true, nil, "Wrong domain name"}, + {missingDomainWRR, true, nil, "Missing domain name"}, + {wrongIpWRR, true, nil, "Wrong IP address"}, + {wrongWeightWRR, true, nil, "Wrong weight value"}, + } + + for i, test := range tests { + testFile, rm, err := testutil.TempFile(".", test.weightFilContent) + if err != nil { + t.Fatal(err) + } + defer rm() + weighted := &weightedRR{fileName: testFile} + err = weighted.updateWeights() + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s", i, err) + } + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found error: %v", i, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v", + i, test.expectedErrContent, err) + } + } + if test.expectedDomains != nil { + if len(test.expectedDomains) != len(weighted.domains) { + t.Errorf("Test %d: Expected len(domains): %d but got %d", + i, len(test.expectedDomains), len(weighted.domains)) + } else { + _ = checkDomainsWRR(t, i, test.expectedDomains, weighted.domains) + } + } + } +} + +func checkDomainsWRR(t *testing.T, testIndex int, expectedDomains, domains map[string]weights) error { + var ret error + retError := errors.New("Check domains failed") + for dname, expectedWeights := range expectedDomains { + ws, ok := domains[dname] + if !ok { + t.Errorf("Test %d: Expected domain %s but not found it", testIndex, dname) + ret = retError + } else { + if len(expectedWeights) != len(ws) { + t.Errorf("Test %d: Expected len(weights): %d for domain %s but got %d", + testIndex, len(expectedWeights), dname, len(ws)) + ret = retError + } else { + for i, w := range expectedWeights { + if !w.address.Equal(ws[i].address) || w.value != ws[i].value { + t.Errorf("Test %d: Weight list differs at index %d for domain %s. "+ + "Expected: %v got: %v", testIndex, i, dname, expectedWeights[i], ws[i]) + ret = retError + } + } + } + } + } + + return ret +} + +func TestPeriodicWeightUpdate(t *testing.T) { + testFile1, rm, err := testutil.TempFile(".", oneDomainWRR) + if err != nil { + t.Fatal(err) + } + defer rm() + testFile2, rm, err := testutil.TempFile(".", twoDomainsWRR) + if err != nil { + t.Fatal(err) + } + defer rm() + + // configure weightedRR with "oneDomainWRR" weight file content + weighted := &weightedRR{fileName: testFile1} + + err = weighted.updateWeights() + if err != nil { + t.Fatal(err) + } else { + err = checkDomainsWRR(t, 0, testOneDomainWRR, weighted.domains) + if err != nil { + t.Fatalf("Initial check domains failed") + } + } + + // change weight file + weighted.fileName = testFile2 + // start periodic update + weighted.reload = 10 * time.Millisecond + stopChan := make(chan bool) + weighted.periodicWeightUpdate(stopChan) + time.Sleep(20 * time.Millisecond) + // stop periodic update + close(stopChan) + // check updated config + weighted.mutex.Lock() + err = checkDomainsWRR(t, 0, testTwoDomainsWRR, weighted.domains) + weighted.mutex.Unlock() + if err != nil { + t.Fatalf("Final check domains failed") + } +} + +// Fake random number generator for testing +type fakeRandomGen struct { + expectedLimit uint + testIndex int + queryIndex int + randv uint + t *testing.T +} + +func (r *fakeRandomGen) randInit() { +} + +func (r *fakeRandomGen) randUint(limit uint) uint { + if limit != r.expectedLimit { + r.t.Errorf("Test %d query %d: Expected weights sum %d but got %d", + r.testIndex, r.queryIndex, r.expectedLimit, limit) + } + return r.randv +} + +func TestLoadBalanceWRR(t *testing.T) { + type testQuery struct { + randv uint // fake random value for selecting the top IP + topIP string // top (first) address record in the answer + } + + // domain maps to test + oneDomain := map[string]weights{ + "endpoint.region2.skydns.test.": weights{ + &weightItem{net.ParseIP("10.240.0.2"), uint8(3)}, + &weightItem{net.ParseIP("10.240.0.1"), uint8(2)}, + }, + } + twoDomains := map[string]weights{ + "endpoint.region2.skydns.test.": weights{ + &weightItem{net.ParseIP("10.240.0.2"), uint8(5)}, + &weightItem{net.ParseIP("10.240.0.1"), uint8(2)}, + }, + "endpoint.region1.skydns.test.": weights{ + &weightItem{net.ParseIP("::2"), uint8(4)}, + &weightItem{net.ParseIP("::1"), uint8(3)}, + }, + } + + // the first X records must be cnames after this test + tests := []struct { + answer []dns.RR + extra []dns.RR + cnameAnswer int + cnameExtra int + addressAnswer int + addressExtra int + mxAnswer int + mxExtra int + domains map[string]weights + sumWeights uint // sum of weights in the answer + queries []testQuery + }{ + { + answer: []dns.RR{ + testutil.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + testutil.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + testutil.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.MX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."), + testutil.MX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."), + }, + extra: []dns.RR{ + testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + }, + cnameAnswer: 4, + cnameExtra: 1, + addressAnswer: 5, + addressExtra: 5, + mxAnswer: 3, + mxExtra: 1, + domains: twoDomains, + sumWeights: 15, + queries: []testQuery{ + {0, "10.240.0.2"}, // domain 1 weight 5 + {4, "10.240.0.2"}, // domain 1 weight 5 + {5, "::2"}, // domain 2 weight 4 + {8, "::2"}, // domain 2 weight 4 + {9, "::1"}, // domain 2 weight 3 + {11, "::1"}, // domain 2 weight 3 + {12, "10.240.0.1"}, // domain 1 weight 2 + {13, "10.240.0.1"}, // domain 1 weight 2 + {14, "10.240.0.3"}, // domain 1 no weight -> default weight + }, + }, + { + answer: []dns.RR{ + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region1.skydns.test. 300 IN A 10.240.0.3"), + }, + cnameAnswer: 1, + addressAnswer: 3, + mxAnswer: 1, + domains: oneDomain, + sumWeights: 6, + queries: []testQuery{ + {0, "10.240.0.2"}, // weight 3 + {2, "10.240.0.2"}, // weight 3 + {3, "10.240.0.1"}, // weight 2 + {4, "10.240.0.1"}, // weight 2 + {5, "10.240.0.3"}, // no domain -> default weight + }, + }, + { + answer: []dns.RR{ + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + }, + cnameAnswer: 1, + mxAnswer: 1, + domains: oneDomain, + queries: []testQuery{ + {0, ""}, // no address records -> answer unaltered + }, + }, + } + + testRand := &fakeRandomGen{t: t} + weighted := &weightedRR{randomGen: testRand} + shuffle := func(res *dns.Msg) *dns.Msg { + return weightedShuffle(res, weighted) + } + rm := LoadBalance{Next: handler(), shuffle: shuffle} + + rec := dnstest.NewRecorder(&testutil.ResponseWriter{}) + + for i, test := range tests { + // set domain map for weighted round robin + weighted.domains = test.domains + testRand.testIndex = i + testRand.expectedLimit = test.sumWeights + + for j, query := range test.queries { + req := new(dns.Msg) + req.SetQuestion("endpoint.region2.skydns.test", dns.TypeSRV) + req.Answer = test.answer + req.Extra = test.extra + + // Set fake random number + testRand.randv = query.randv + testRand.queryIndex = j + + _, err := rm.ServeDNS(context.TODO(), rec, req) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + } + + checkTopIP(t, i, j, rec.Msg.Answer, query.topIP) + checkTopIP(t, i, j, rec.Msg.Extra, query.topIP) + + cname, address, mx, sorted := countRecords(rec.Msg.Answer) + if query.topIP != "" && !sorted { + t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j) + } + if cname != test.cnameAnswer { + t.Errorf("Test %d query %d: Expected %d CNAMEs in Answer, but got %d", i, j, test.cnameAnswer, cname) + } + if address != test.addressAnswer { + t.Errorf("Test %d query %d: Expected %d A/AAAAs in Answer, but got %d", i, j, test.addressAnswer, address) + } + if mx != test.mxAnswer { + t.Errorf("Test %d query %d: Expected %d MXs in Answer, but got %d", i, j, test.mxAnswer, mx) + } + + cname, address, mx, sorted = countRecords(rec.Msg.Extra) + if query.topIP != "" && !sorted { + t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j) + } + + if cname != test.cnameExtra { + t.Errorf("Test %d query %d: Expected %d CNAMEs in Extra, but got %d", i, j, test.cnameAnswer, cname) + } + if address != test.addressExtra { + t.Errorf("Test %d query %d: Expected %d A/AAAAs in Extra, but got %d", i, j, test.addressAnswer, address) + } + if mx != test.mxExtra { + t.Errorf("Test %d query %d: Expected %d MXs in Extra, but got %d", i, j, test.mxAnswer, mx) + } + } + } +} + +func checkTopIP(t *testing.T, i, j int, result []dns.RR, expectedTopIP string) { + expected := net.ParseIP(expectedTopIP) + for _, r := range result { + switch r.Header().Rrtype { + case dns.TypeA: + ar := r.(*dns.A) + if !ar.A.Equal(expected) { + t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.A) + } + return + case dns.TypeAAAA: + ar := r.(*dns.AAAA) + if !ar.AAAA.Equal(expected) { + t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.AAAA) + } + return + } + } +} |