diff options
Diffstat (limited to 'plugin')
-rw-r--r-- | plugin/loadbalance/README.md | 61 | ||||
-rw-r--r-- | plugin/loadbalance/handler.go | 13 | ||||
-rw-r--r-- | plugin/loadbalance/loadbalance.go | 25 | ||||
-rw-r--r-- | plugin/loadbalance/loadbalance_test.go | 6 | ||||
-rw-r--r-- | plugin/loadbalance/setup.go | 82 | ||||
-rw-r--r-- | plugin/loadbalance/setup_test.go | 71 | ||||
-rw-r--r-- | plugin/loadbalance/weighted.go | 329 | ||||
-rw-r--r-- | plugin/loadbalance/weighted_test.go | 424 |
8 files changed, 975 insertions, 36 deletions
diff --git a/plugin/loadbalance/README.md b/plugin/loadbalance/README.md index 81a6580c4..fe29b19fc 100644 --- a/plugin/loadbalance/README.md +++ b/plugin/loadbalance/README.md @@ -16,10 +16,43 @@ implementations (like glibc) are particular about that. ## Syntax ~~~ -loadbalance [POLICY] +loadbalance [round_robin | weighted WEIGHTFILE] { + reload DURATION +} +~~~ +* `round_robin` policy randomizes the order of A, AAAA, and MX records applying a uniform probability distribution. This is the default load balancing policy. + +* `weighted` policy assigns weight values to IPs to control the relative likelihood of particular IPs to be returned as the first +(top) A/AAAA record in the answer. Note that it does not shuffle all the records in the answer, it is only concerned about the first A/AAAA record +returned in the answer. + + * **WEIGHTFILE** is the file containing the weight values assigned to IPs for various domain names. If the path is relative, the path from the **root** plugin will be prepended to it. The format is explained below in the *Weightfile* section. + + * **DURATION** interval to reload `WEIGHTFILE` and update weight assignments if there are changes in the file. The default value is `30s`. A value of `0s` means to not scan for changes and reload. + + +## Weightfile + +The generic weight file syntax: + +~~~ +# Comment lines are ignored + +domain-name1 +ip11 weight11 +ip12 weight12 +ip13 weight13 + +domain-name2 +ip21 weight21 +ip22 weight22 +# ... etc. ~~~ -* **POLICY** is how to balance. The default, and only option, is "round_robin". +where `ipXY` is an IP address for `domain-nameX` and `weightXY` is the weight value associated with that IP. The weight values are in the range of [1,255]. + +The `weighted` policy selects one of the address record in the result list and moves it to the top (first) position in the list. The random selection takes into account the weight values assigned to the addresses in the weight file. If an address in the result list is associated with no weight value in the weight file then the default weight value "1" is assumed for it when the selection is performed. + ## Examples @@ -31,3 +64,27 @@ Load balance replies coming back from Google Public DNS: forward . 8.8.8.8 8.8.4.4 } ~~~ + +Use the `weighted` strategy to load balance replies supplied by the **file** plugin. We assign weight vales `3`, `1` and `2` to the IPs `100.64.1.1`, `100.64.1.2` and `100.64.1.3`, respectively. These IPs are addresses in A records for the domain name `www.example.com` defined in the `./db.example.com` zone file. The ratio between the number of answers in which `100.64.1.1`, `100.64.1.2` or `100.64.1.3` is in the top (first) A record should converge to `3 : 1 : 2`. (E.g. there should be twice as many answers with `100.64.1.3` in the top A record than with `100.64.1.2`). +Corefile: + +~~~ corefile +example.com { + file ./db.example.com { + reload 10s + } + loadbalance weighted ./db.example.com.weights { + reload 10s + } +} +~~~ + +weight file `./db.example.com.weights`: + +~~~ +www.example.com +100.64.1.1 3 +100.64.1.2 1 +100.64.1.3 2 +~~~ + diff --git a/plugin/loadbalance/handler.go b/plugin/loadbalance/handler.go index ac046c8d0..8b84e1c5c 100644 --- a/plugin/loadbalance/handler.go +++ b/plugin/loadbalance/handler.go @@ -10,15 +10,16 @@ import ( ) // RoundRobin is a plugin to rewrite responses for "load balancing". -type RoundRobin struct { - Next plugin.Handler +type LoadBalance struct { + Next plugin.Handler + shuffle func(*dns.Msg) *dns.Msg } // ServeDNS implements the plugin.Handler interface. -func (rr RoundRobin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - wrr := &RoundRobinResponseWriter{w} - return plugin.NextOrFailure(rr.Name(), rr.Next, ctx, wrr, r) +func (lb LoadBalance) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rw := &LoadBalanceResponseWriter{ResponseWriter: w, shuffle: lb.shuffle} + return plugin.NextOrFailure(lb.Name(), lb.Next, ctx, rw, r) } // Name implements the Handler interface. -func (rr RoundRobin) Name() string { return "loadbalance" } +func (lb LoadBalance) Name() string { return "loadbalance" } diff --git a/plugin/loadbalance/loadbalance.go b/plugin/loadbalance/loadbalance.go index 966121d6b..f2a1caed0 100644 --- a/plugin/loadbalance/loadbalance.go +++ b/plugin/loadbalance/loadbalance.go @@ -5,11 +5,19 @@ import ( "github.com/miekg/dns" ) -// RoundRobinResponseWriter is a response writer that shuffles A, AAAA and MX records. -type RoundRobinResponseWriter struct{ dns.ResponseWriter } +const ( + ramdomShufflePolicy = "round_robin" + weightedRoundRobinPolicy = "weighted" +) + +// LoadBalanceResponseWriter is a response writer that shuffles A, AAAA and MX records. +type LoadBalanceResponseWriter struct { + dns.ResponseWriter + shuffle func(*dns.Msg) *dns.Msg +} // WriteMsg implements the dns.ResponseWriter interface. -func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { +func (r *LoadBalanceResponseWriter) WriteMsg(res *dns.Msg) error { if res.Rcode != dns.RcodeSuccess { return r.ResponseWriter.WriteMsg(res) } @@ -18,11 +26,14 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { return r.ResponseWriter.WriteMsg(res) } + return r.ResponseWriter.WriteMsg(r.shuffle(res)) +} + +func randomShuffle(res *dns.Msg) *dns.Msg { res.Answer = roundRobin(res.Answer) res.Ns = roundRobin(res.Ns) res.Extra = roundRobin(res.Extra) - - return r.ResponseWriter.WriteMsg(res) + return res } func roundRobin(in []dns.RR) []dns.RR { @@ -72,9 +83,9 @@ func roundRobinShuffle(records []dns.RR) { } // Write implements the dns.ResponseWriter interface. -func (r *RoundRobinResponseWriter) Write(buf []byte) (int, error) { +func (r *LoadBalanceResponseWriter) Write(buf []byte) (int, error) { // Should we pack and unpack here to fiddle with the packet... Not likely. - log.Warning("RoundRobin called with Write: not shuffling records") + log.Warning("LoadBalance called with Write: not shuffling records") n, err := r.ResponseWriter.Write(buf) return n, err } diff --git a/plugin/loadbalance/loadbalance_test.go b/plugin/loadbalance/loadbalance_test.go index 6f50b6e1a..c46d96842 100644 --- a/plugin/loadbalance/loadbalance_test.go +++ b/plugin/loadbalance/loadbalance_test.go @@ -11,8 +11,8 @@ import ( "github.com/miekg/dns" ) -func TestLoadBalance(t *testing.T) { - rm := RoundRobin{Next: handler()} +func TestLoadBalanceRandom(t *testing.T) { + rm := LoadBalance{Next: handler(), shuffle: randomShuffle} // the first X records must be cnames after this test tests := []struct { @@ -124,7 +124,7 @@ func TestLoadBalance(t *testing.T) { } func TestLoadBalanceXFR(t *testing.T) { - rm := RoundRobin{Next: handler()} + rm := LoadBalance{Next: handler()} answer := []dns.RR{ test.SOA("skydns.test. 30 IN SOA ns.dns.skydns.test. hostmaster.skydns.test. 1542756695 7200 1800 86400 30"), diff --git a/plugin/loadbalance/setup.go b/plugin/loadbalance/setup.go index d8f273aaa..5706aeb41 100644 --- a/plugin/loadbalance/setup.go +++ b/plugin/loadbalance/setup.go @@ -1,43 +1,103 @@ package loadbalance import ( + "errors" "fmt" + "path/filepath" + "time" "github.com/coredns/caddy" "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" ) var log = clog.NewWithPlugin("loadbalance") +var errOpen = errors.New("Weight file open error") func init() { plugin.Register("loadbalance", setup) } +type lbFuncs struct { + shuffleFunc func(*dns.Msg) *dns.Msg + onStartUpFunc func() error + onShutdownFunc func() error + weighted *weightedRR // used in unit tests only +} + func setup(c *caddy.Controller) error { - err := parse(c) + //shuffleFunc, startUpFunc, shutdownFunc, err := parse(c) + lb, err := parse(c) if err != nil { return plugin.Error("loadbalance", err) } + if lb.onStartUpFunc != nil { + c.OnStartup(lb.onStartUpFunc) + } + if lb.onShutdownFunc != nil { + c.OnShutdown(lb.onShutdownFunc) + } dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { - return RoundRobin{Next: next} + return LoadBalance{Next: next, shuffle: lb.shuffleFunc} }) return nil } -func parse(c *caddy.Controller) error { +// func parse(c *caddy.Controller) (string, *weightedRR, error) { +func parse(c *caddy.Controller) (*lbFuncs, error) { + config := dnsserver.GetConfig(c) + for c.Next() { args := c.RemainingArgs() - switch len(args) { - case 0: - return nil - case 1: - if args[0] != "round_robin" { - return fmt.Errorf("unknown policy: %s", args[0]) + if len(args) == 0 { + return &lbFuncs{shuffleFunc: randomShuffle}, nil + } + switch args[0] { + case ramdomShufflePolicy: + if len(args) > 1 { + return nil, c.Errf("unknown property for %s", args[0]) + } + return &lbFuncs{shuffleFunc: randomShuffle}, nil + case weightedRoundRobinPolicy: + if len(args) < 2 { + return nil, c.Err("missing weight file argument") + } + + if len(args) > 2 { + return nil, c.Err("unexpected argument(s)") + } + + weightFileName := args[1] + if !filepath.IsAbs(weightFileName) && config.Root != "" { + weightFileName = filepath.Join(config.Root, weightFileName) + } + reload := 30 * time.Second // default reload period + for c.NextBlock() { + switch c.Val() { + case "reload": + t := c.RemainingArgs() + if len(t) < 1 { + return nil, c.Err("reload duration value is missing") + } + if len(t) > 1 { + return nil, c.Err("unexpected argument") + } + var err error + reload, err = time.ParseDuration(t[0]) + if err != nil { + return nil, c.Errf("invalid reload duration '%s'", t[0]) + } + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } } - return nil + return createWeightedFuncs(weightFileName, reload), nil + default: + return nil, fmt.Errorf("unknown policy: %s", args[0]) } } - return c.ArgErr() + return nil, c.ArgErr() } diff --git a/plugin/loadbalance/setup_test.go b/plugin/loadbalance/setup_test.go index 38cea1478..4e3c99cec 100644 --- a/plugin/loadbalance/setup_test.go +++ b/plugin/loadbalance/setup_test.go @@ -7,24 +7,53 @@ import ( "github.com/coredns/caddy" ) +// weighted round robin specific test data +var testWeighted = []struct { + expectedWeightFile string + expectedWeightReload string +}{ + {"wfile", "30s"}, + {"wf", "10s"}, + {"wf", "0s"}, +} + func TestSetup(t *testing.T) { tests := []struct { input string shouldErr bool expectedPolicy string expectedErrContent string // substring from the expected error. Empty for positive cases. + weightedDataIndex int // weighted round robin specific data index }{ // positive - {`loadbalance`, false, "round_robin", ""}, - {`loadbalance round_robin`, false, "round_robin", ""}, + {`loadbalance`, false, "round_robin", "", -1}, + {`loadbalance round_robin`, false, "round_robin", "", -1}, + {`loadbalance weighted wfile`, false, "weighted", "", 0}, + {`loadbalance weighted wf { + reload 10s + } `, false, "weighted", "", 1}, + {`loadbalance weighted wf { + reload 0s + } `, false, "weighted", "", 2}, // negative - {`loadbalance fleeb`, true, "", "unknown policy"}, - {`loadbalance a b`, true, "", "argument count or unexpected line"}, + {`loadbalance fleeb`, true, "", "unknown policy", -1}, + {`loadbalance round_robin a`, true, "", "unknown property", -1}, + {`loadbalance weighted`, true, "", "missing weight file argument", -1}, + {`loadbalance weighted a b`, true, "", "unexpected argument", -1}, + {`loadbalance weighted wfile { + susu + } `, true, "", "unknown property", -1}, + {`loadbalance weighted wfile { + reload a + } `, true, "", "invalid reload duration", -1}, + {`loadbalance weighted wfile { + reload 30s a + } `, true, "", "unexpected argument", -1}, } for i, test := range tests { c := caddy.NewTestController("dns", test.input) - err := parse(c) + lb, err := parse(c) if test.shouldErr && err == nil { t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) @@ -32,11 +61,39 @@ func TestSetup(t *testing.T) { if err != nil { if !test.shouldErr { - t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", + i, test.input, err) } if !strings.Contains(err.Error(), test.expectedErrContent) { - t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", + i, test.expectedErrContent, err, test.input) + } + continue + } + + if lb == nil { + t.Errorf("Test %d: Expected valid loadbalance funcs but got nil for input %s", + i, test.input) + continue + } + policy := ramdomShufflePolicy + if lb.weighted != nil { + policy = weightedRoundRobinPolicy + } + if policy != test.expectedPolicy { + t.Errorf("Test %d: Expected policy %s but got %s for input %s", i, + test.expectedPolicy, policy, test.input) + } + if policy == weightedRoundRobinPolicy && test.weightedDataIndex >= 0 { + i := test.weightedDataIndex + if testWeighted[i].expectedWeightFile != lb.weighted.fileName { + t.Errorf("Test %d: Expected weight file name %s but got %s for input %s", + i, testWeighted[i].expectedWeightFile, lb.weighted.fileName, test.input) + } + if testWeighted[i].expectedWeightReload != lb.weighted.reload.String() { + t.Errorf("Test %d: Expected weight reload duration %s but got %s for input %s", + i, testWeighted[i].expectedWeightReload, lb.weighted.reload, test.input) } } } diff --git a/plugin/loadbalance/weighted.go b/plugin/loadbalance/weighted.go new file mode 100644 index 000000000..44aef63e8 --- /dev/null +++ b/plugin/loadbalance/weighted.go @@ -0,0 +1,329 @@ +package loadbalance + +import ( + "bufio" + "bytes" + "crypto/md5" + "errors" + "fmt" + "io" + "math/rand" + "net" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +type ( + // "weighted-round-robin" policy specific data + weightedRR struct { + fileName string + reload time.Duration + md5sum [md5.Size]byte + domains map[string]weights + randomGen + mutex sync.Mutex + } + // Per domain weights + weights []*weightItem + // Weight assigned to an address + weightItem struct { + address net.IP + value uint8 + } + // Random uint generator + randomGen interface { + randInit() + randUint(limit uint) uint + } +) + +// Random uint generator +type randomUint struct { + rn *rand.Rand +} + +func (r *randomUint) randInit() { + r.rn = rand.New(rand.NewSource(time.Now().UnixNano())) +} + +func (r *randomUint) randUint(limit uint) uint { + return uint(r.rn.Intn(int(limit))) +} + +func weightedShuffle(res *dns.Msg, w *weightedRR) *dns.Msg { + switch res.Question[0].Qtype { + case dns.TypeA, dns.TypeAAAA, dns.TypeSRV: + res.Answer = w.weightedRoundRobin(res.Answer) + res.Extra = w.weightedRoundRobin(res.Extra) + } + return res +} + +func weightedOnStartUp(w *weightedRR, stopReloadChan chan bool) error { + err := w.updateWeights() + if errors.Is(err, errOpen) && w.reload != 0 { + log.Warningf("Failed to open weight file:%v. Will try again in %v", + err, w.reload) + } else if err != nil { + return plugin.Error("loadbalance", err) + } + // start periodic weight file reload go routine + w.periodicWeightUpdate(stopReloadChan) + return nil +} + +func createWeightedFuncs(weightFileName string, + reload time.Duration) *lbFuncs { + lb := &lbFuncs{ + weighted: &weightedRR{ + fileName: weightFileName, + reload: reload, + randomGen: &randomUint{}, + }, + } + lb.weighted.randomGen.randInit() + + lb.shuffleFunc = func(res *dns.Msg) *dns.Msg { + return weightedShuffle(res, lb.weighted) + } + + stopReloadChan := make(chan bool) + + lb.onStartUpFunc = func() error { + return weightedOnStartUp(lb.weighted, stopReloadChan) + } + + lb.onShutdownFunc = func() error { + // stop periodic weigh reload go routine + close(stopReloadChan) + return nil + } + return lb +} + +// Apply weighted round robin policy to the answer +func (w *weightedRR) weightedRoundRobin(in []dns.RR) []dns.RR { + cname := []dns.RR{} + address := []dns.RR{} + mx := []dns.RR{} + rest := []dns.RR{} + for _, r := range in { + switch r.Header().Rrtype { + case dns.TypeCNAME: + cname = append(cname, r) + case dns.TypeA, dns.TypeAAAA: + address = append(address, r) + case dns.TypeMX: + mx = append(mx, r) + default: + rest = append(rest, r) + } + } + + if len(address) == 0 { + // no change + return in + } + + w.setTopRecord(address) + + out := append(cname, rest...) + out = append(out, address...) + out = append(out, mx...) + return out +} + +// Move the next expected address to the first position in the result list +func (w *weightedRR) setTopRecord(address []dns.RR) { + itop := w.topAddressIndex(address) + + if itop < 0 { + // internal error + return + } + + if itop != 0 { + // swap the selected top entry with the actual one + address[0], address[itop] = address[itop], address[0] + } +} + +// Compute the top (first) address index +func (w *weightedRR) topAddressIndex(address []dns.RR) int { + w.mutex.Lock() + defer w.mutex.Unlock() + + // Dertermine the weight value for each address in the answer + var wsum uint + type waddress struct { + index int + weight uint8 + } + weightedAddr := make([]waddress, len(address)) + for i, ar := range address { + wa := &weightedAddr[i] + wa.index = i + wa.weight = 1 // default weight + var ip net.IP + switch ar.Header().Rrtype { + case dns.TypeA: + ip = ar.(*dns.A).A + case dns.TypeAAAA: + ip = ar.(*dns.AAAA).AAAA + } + ws := w.domains[ar.Header().Name] + for _, w := range ws { + if w.address.Equal(ip) { + wa.weight = w.value + break + } + } + wsum += uint(wa.weight) + } + + // Select the first (top) IP + sort.Slice(weightedAddr, func(i, j int) bool { + return weightedAddr[i].weight > weightedAddr[j].weight + }) + v := w.randUint(wsum) + var psum uint + for _, wa := range weightedAddr { + psum += uint(wa.weight) + if v < psum { + return int(wa.index) + } + } + + // we should never reach this + log.Errorf("Internal error: cannot find top addres (randv:%v wsum:%v)", v, wsum) + return -1 +} + +// Start go routine to update weights from the weight file periodically +func (w *weightedRR) periodicWeightUpdate(stopReload <-chan bool) { + if w.reload == 0 { + return + } + + go func() { + ticker := time.NewTicker(w.reload) + for { + select { + case <-stopReload: + return + case <-ticker.C: + err := w.updateWeights() + if err != nil { + log.Error(err) + } + } + } + }() +} + +// Update weights from weight file +func (w *weightedRR) updateWeights() error { + reader, err := os.Open(filepath.Clean(w.fileName)) + if err != nil { + return errOpen + } + defer reader.Close() + + // check if the contents has changed + var buf bytes.Buffer + tee := io.TeeReader(reader, &buf) + bytes, err := io.ReadAll(tee) + if err != nil { + return err + } + md5sum := md5.Sum(bytes) + if md5sum == w.md5sum { + // file contents has not changed + return nil + } + w.md5sum = md5sum + scanner := bufio.NewScanner(&buf) + + // Parse the weight file contents + err = w.parseWeights(scanner) + if err != nil { + return err + } + + log.Infof("Successfully reloaded weight file %s", w.fileName) + return nil +} + +// Parse the weight file contents +func (w *weightedRR) parseWeights(scanner *bufio.Scanner) error { + // access to weights must be protected + w.mutex.Lock() + defer w.mutex.Unlock() + + // Reset domains + w.domains = make(map[string]weights) + + var dname string + var ws weights + for scanner.Scan() { + nextLine := strings.TrimSpace(scanner.Text()) + if len(nextLine) == 0 || nextLine[0:1] == "#" { + // Empty and comment lines are ignored + continue + } + fields := strings.Fields(nextLine) + switch len(fields) { + case 1: + // (domain) name sanity check + if net.ParseIP(fields[0]) != nil { + return fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)", + fields[0], w.fileName) + } + dname = fields[0] + + // add the root domain if it is missing + if dname[len(dname)-1] != '.' { + dname += "." + } + var ok bool + ws, ok = w.domains[dname] + if !ok { + ws = make(weights, 0) + w.domains[dname] = ws + } + case 2: + // IP address and weight value + ip := net.ParseIP(fields[0]) + if ip == nil { + return fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName) + } + weight, err := strconv.ParseUint(fields[1], 10, 8) + if err != nil { + return fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName) + } + witem := &weightItem{address: ip, value: uint8(weight)} + if dname == "" { + return fmt.Errorf("Missing domain name in weight file %s", w.fileName) + } + ws = append(ws, witem) + w.domains[dname] = ws + default: + return fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName) + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err) + } + + return nil +} diff --git a/plugin/loadbalance/weighted_test.go b/plugin/loadbalance/weighted_test.go new file mode 100644 index 000000000..e502c2772 --- /dev/null +++ b/plugin/loadbalance/weighted_test.go @@ -0,0 +1,424 @@ +package loadbalance + +import ( + "context" + "errors" + "net" + "strings" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + testutil "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +const oneDomainWRR = ` +w1,example.org +192.168.1.15 10 +192.168.1.14 20 +` + +var testOneDomainWRR = map[string]weights{ + "w1,example.org.": weights{ + &weightItem{net.ParseIP("192.168.1.15"), uint8(10)}, + &weightItem{net.ParseIP("192.168.1.14"), uint8(20)}, + }, +} + +const twoDomainsWRR = ` +# domain 1 +w1.example.org +192.168.1.15 10 +192.168.1.14 20 + +# domain 2 +w2.example.org + # domain 3 + w3.example.org + 192.168.2.16 11 + 192.168.2.15 12 + 192.168.2.14 13 +` + +var testTwoDomainsWRR = map[string]weights{ + "w1.example.org.": weights{ + &weightItem{net.ParseIP("192.168.1.15"), uint8(10)}, + &weightItem{net.ParseIP("192.168.1.14"), uint8(20)}, + }, + "w2.example.org.": weights{}, + "w3.example.org.": weights{ + &weightItem{net.ParseIP("192.168.2.16"), uint8(11)}, + &weightItem{net.ParseIP("192.168.2.15"), uint8(12)}, + &weightItem{net.ParseIP("192.168.2.14"), uint8(13)}, + }, +} + +const missingWeightWRR = ` +w1,example.org +192.168.1.14 +192.168.1.15 20 +` + +const missingDomainWRR = ` +# missing domain +192.168.1.14 10 +w2,example.org +192.168.2.14 11 +192.168.2.15 12 +` + +const wrongIpWRR = ` +w1,example.org +192.168.1.300 10 +` + +const wrongWeightWRR = ` +w1,example.org +192.168.1.14 300 +` + +func TestWeightFileUpdate(t *testing.T) { + tests := []struct { + weightFilContent string + shouldErr bool + expectedDomains map[string]weights + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + {"", false, nil, ""}, + {oneDomainWRR, false, testOneDomainWRR, ""}, + {twoDomainsWRR, false, testTwoDomainsWRR, ""}, + // negative + {missingWeightWRR, true, nil, "Wrong domain name"}, + {missingDomainWRR, true, nil, "Missing domain name"}, + {wrongIpWRR, true, nil, "Wrong IP address"}, + {wrongWeightWRR, true, nil, "Wrong weight value"}, + } + + for i, test := range tests { + testFile, rm, err := testutil.TempFile(".", test.weightFilContent) + if err != nil { + t.Fatal(err) + } + defer rm() + weighted := &weightedRR{fileName: testFile} + err = weighted.updateWeights() + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s", i, err) + } + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found error: %v", i, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v", + i, test.expectedErrContent, err) + } + } + if test.expectedDomains != nil { + if len(test.expectedDomains) != len(weighted.domains) { + t.Errorf("Test %d: Expected len(domains): %d but got %d", + i, len(test.expectedDomains), len(weighted.domains)) + } else { + _ = checkDomainsWRR(t, i, test.expectedDomains, weighted.domains) + } + } + } +} + +func checkDomainsWRR(t *testing.T, testIndex int, expectedDomains, domains map[string]weights) error { + var ret error + retError := errors.New("Check domains failed") + for dname, expectedWeights := range expectedDomains { + ws, ok := domains[dname] + if !ok { + t.Errorf("Test %d: Expected domain %s but not found it", testIndex, dname) + ret = retError + } else { + if len(expectedWeights) != len(ws) { + t.Errorf("Test %d: Expected len(weights): %d for domain %s but got %d", + testIndex, len(expectedWeights), dname, len(ws)) + ret = retError + } else { + for i, w := range expectedWeights { + if !w.address.Equal(ws[i].address) || w.value != ws[i].value { + t.Errorf("Test %d: Weight list differs at index %d for domain %s. "+ + "Expected: %v got: %v", testIndex, i, dname, expectedWeights[i], ws[i]) + ret = retError + } + } + } + } + } + + return ret +} + +func TestPeriodicWeightUpdate(t *testing.T) { + testFile1, rm, err := testutil.TempFile(".", oneDomainWRR) + if err != nil { + t.Fatal(err) + } + defer rm() + testFile2, rm, err := testutil.TempFile(".", twoDomainsWRR) + if err != nil { + t.Fatal(err) + } + defer rm() + + // configure weightedRR with "oneDomainWRR" weight file content + weighted := &weightedRR{fileName: testFile1} + + err = weighted.updateWeights() + if err != nil { + t.Fatal(err) + } else { + err = checkDomainsWRR(t, 0, testOneDomainWRR, weighted.domains) + if err != nil { + t.Fatalf("Initial check domains failed") + } + } + + // change weight file + weighted.fileName = testFile2 + // start periodic update + weighted.reload = 10 * time.Millisecond + stopChan := make(chan bool) + weighted.periodicWeightUpdate(stopChan) + time.Sleep(20 * time.Millisecond) + // stop periodic update + close(stopChan) + // check updated config + weighted.mutex.Lock() + err = checkDomainsWRR(t, 0, testTwoDomainsWRR, weighted.domains) + weighted.mutex.Unlock() + if err != nil { + t.Fatalf("Final check domains failed") + } +} + +// Fake random number generator for testing +type fakeRandomGen struct { + expectedLimit uint + testIndex int + queryIndex int + randv uint + t *testing.T +} + +func (r *fakeRandomGen) randInit() { +} + +func (r *fakeRandomGen) randUint(limit uint) uint { + if limit != r.expectedLimit { + r.t.Errorf("Test %d query %d: Expected weights sum %d but got %d", + r.testIndex, r.queryIndex, r.expectedLimit, limit) + } + return r.randv +} + +func TestLoadBalanceWRR(t *testing.T) { + type testQuery struct { + randv uint // fake random value for selecting the top IP + topIP string // top (first) address record in the answer + } + + // domain maps to test + oneDomain := map[string]weights{ + "endpoint.region2.skydns.test.": weights{ + &weightItem{net.ParseIP("10.240.0.2"), uint8(3)}, + &weightItem{net.ParseIP("10.240.0.1"), uint8(2)}, + }, + } + twoDomains := map[string]weights{ + "endpoint.region2.skydns.test.": weights{ + &weightItem{net.ParseIP("10.240.0.2"), uint8(5)}, + &weightItem{net.ParseIP("10.240.0.1"), uint8(2)}, + }, + "endpoint.region1.skydns.test.": weights{ + &weightItem{net.ParseIP("::2"), uint8(4)}, + &weightItem{net.ParseIP("::1"), uint8(3)}, + }, + } + + // the first X records must be cnames after this test + tests := []struct { + answer []dns.RR + extra []dns.RR + cnameAnswer int + cnameExtra int + addressAnswer int + addressExtra int + mxAnswer int + mxExtra int + domains map[string]weights + sumWeights uint // sum of weights in the answer + queries []testQuery + }{ + { + answer: []dns.RR{ + testutil.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + testutil.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + testutil.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.MX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."), + testutil.MX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."), + }, + extra: []dns.RR{ + testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + }, + cnameAnswer: 4, + cnameExtra: 1, + addressAnswer: 5, + addressExtra: 5, + mxAnswer: 3, + mxExtra: 1, + domains: twoDomains, + sumWeights: 15, + queries: []testQuery{ + {0, "10.240.0.2"}, // domain 1 weight 5 + {4, "10.240.0.2"}, // domain 1 weight 5 + {5, "::2"}, // domain 2 weight 4 + {8, "::2"}, // domain 2 weight 4 + {9, "::1"}, // domain 2 weight 3 + {11, "::1"}, // domain 2 weight 3 + {12, "10.240.0.1"}, // domain 1 weight 2 + {13, "10.240.0.1"}, // domain 1 weight 2 + {14, "10.240.0.3"}, // domain 1 no weight -> default weight + }, + }, + { + answer: []dns.RR{ + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region1.skydns.test. 300 IN A 10.240.0.3"), + }, + cnameAnswer: 1, + addressAnswer: 3, + mxAnswer: 1, + domains: oneDomain, + sumWeights: 6, + queries: []testQuery{ + {0, "10.240.0.2"}, // weight 3 + {2, "10.240.0.2"}, // weight 3 + {3, "10.240.0.1"}, // weight 2 + {4, "10.240.0.1"}, // weight 2 + {5, "10.240.0.3"}, // no domain -> default weight + }, + }, + { + answer: []dns.RR{ + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + }, + cnameAnswer: 1, + mxAnswer: 1, + domains: oneDomain, + queries: []testQuery{ + {0, ""}, // no address records -> answer unaltered + }, + }, + } + + testRand := &fakeRandomGen{t: t} + weighted := &weightedRR{randomGen: testRand} + shuffle := func(res *dns.Msg) *dns.Msg { + return weightedShuffle(res, weighted) + } + rm := LoadBalance{Next: handler(), shuffle: shuffle} + + rec := dnstest.NewRecorder(&testutil.ResponseWriter{}) + + for i, test := range tests { + // set domain map for weighted round robin + weighted.domains = test.domains + testRand.testIndex = i + testRand.expectedLimit = test.sumWeights + + for j, query := range test.queries { + req := new(dns.Msg) + req.SetQuestion("endpoint.region2.skydns.test", dns.TypeSRV) + req.Answer = test.answer + req.Extra = test.extra + + // Set fake random number + testRand.randv = query.randv + testRand.queryIndex = j + + _, err := rm.ServeDNS(context.TODO(), rec, req) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + } + + checkTopIP(t, i, j, rec.Msg.Answer, query.topIP) + checkTopIP(t, i, j, rec.Msg.Extra, query.topIP) + + cname, address, mx, sorted := countRecords(rec.Msg.Answer) + if query.topIP != "" && !sorted { + t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j) + } + if cname != test.cnameAnswer { + t.Errorf("Test %d query %d: Expected %d CNAMEs in Answer, but got %d", i, j, test.cnameAnswer, cname) + } + if address != test.addressAnswer { + t.Errorf("Test %d query %d: Expected %d A/AAAAs in Answer, but got %d", i, j, test.addressAnswer, address) + } + if mx != test.mxAnswer { + t.Errorf("Test %d query %d: Expected %d MXs in Answer, but got %d", i, j, test.mxAnswer, mx) + } + + cname, address, mx, sorted = countRecords(rec.Msg.Extra) + if query.topIP != "" && !sorted { + t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j) + } + + if cname != test.cnameExtra { + t.Errorf("Test %d query %d: Expected %d CNAMEs in Extra, but got %d", i, j, test.cnameAnswer, cname) + } + if address != test.addressExtra { + t.Errorf("Test %d query %d: Expected %d A/AAAAs in Extra, but got %d", i, j, test.addressAnswer, address) + } + if mx != test.mxExtra { + t.Errorf("Test %d query %d: Expected %d MXs in Extra, but got %d", i, j, test.mxAnswer, mx) + } + } + } +} + +func checkTopIP(t *testing.T, i, j int, result []dns.RR, expectedTopIP string) { + expected := net.ParseIP(expectedTopIP) + for _, r := range result { + switch r.Header().Rrtype { + case dns.TypeA: + ar := r.(*dns.A) + if !ar.A.Equal(expected) { + t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.A) + } + return + case dns.TypeAAAA: + ar := r.(*dns.AAAA) + if !ar.AAAA.Equal(expected) { + t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.AAAA) + } + return + } + } +} |