diff options
Diffstat (limited to 'middleware/proxy/upstream_test.go')
-rw-r--r-- | middleware/proxy/upstream_test.go | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go index c48b4446b..bbc51f59a 100644 --- a/middleware/proxy/upstream_test.go +++ b/middleware/proxy/upstream_test.go @@ -1,6 +1,10 @@ package proxy import ( + "io/ioutil" + "os" + "path/filepath" + "strings" "testing" "time" @@ -78,6 +82,19 @@ func TestAllowedPaths(t *testing.T) { } } +func writeTmpFile(t *testing.T, data string) (string, string) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("tempDir: %v", err) + } + + path := filepath.Join(tempDir, "resolv.conf") + if err := ioutil.WriteFile(path, []byte(data), 0644); err != nil { + t.Fatalf("writeFile: %v", err) + } + return tempDir, path +} + func TestProxyParse(t *testing.T) { tests := []struct { inputUpstreams string @@ -143,6 +160,11 @@ proxy . 8.8.8.8:53 { }`, true, }, + { + ` +proxy . some_bogus_filename`, + true, + }, } for i, test := range tests { c := caddy.NewTestController("dns", test.inputUpstreams) @@ -152,3 +174,79 @@ proxy . 8.8.8.8:53 { } } } + +func TestResolvParse(t *testing.T) { + tests := []struct { + inputUpstreams string + filedata string + shouldErr bool + expected []string + }{ + { + ` +proxy . FILE +`, + ` +nameserver 1.2.3.4 +nameserver 4.3.2.1 +`, + false, + []string{"1.2.3.4:53", "4.3.2.1:53"}, + }, + { + ` +proxy example.com 1.1.1.1:5000 +proxy . FILE +proxy example.org 2.2.2.2:1234 +`, + ` +nameserver 1.2.3.4 +`, + false, + []string{"1.1.1.1:5000", "1.2.3.4:53", "2.2.2.2:1234"}, + }, + { + ` +proxy example.com 1.1.1.1:5000 +proxy . FILE +proxy example.org 2.2.2.2:1234 +`, + ` +junky resolve.conf +`, + false, + []string{"1.1.1.1:5000", "2.2.2.2:1234"}, + }, + } + for i, test := range tests { + tempDir, path := writeTmpFile(t, test.filedata) + defer os.RemoveAll(tempDir) + config := strings.Replace(test.inputUpstreams, "FILE", path, -1) + c := caddy.NewTestController("dns", config) + upstreams, err := NewStaticUpstreams(&c.Dispenser) + if (err != nil) != test.shouldErr { + t.Errorf("Test %d expected no error, got %v", i+1, err) + } + var hosts []string + for _, u := range upstreams { + for _, h := range u.(*staticUpstream).Hosts { + hosts = append(hosts, h.Name) + } + } + if !test.shouldErr { + if len(hosts) != len(test.expected) { + t.Errorf("Test %d expected %d hosts got %d", i+1, len(test.expected), len(upstreams)) + } else { + ok := true + for i, v := range test.expected { + if v != hosts[i] { + ok = false + } + } + if !ok { + t.Errorf("Test %d expected %v got %v", i+1, test.expected, upstreams) + } + } + } + } +} |