diff options
Diffstat (limited to 'middleware')
40 files changed, 4422 insertions, 0 deletions
diff --git a/middleware/commands.go b/middleware/commands.go new file mode 100644 index 000000000..5c241161e --- /dev/null +++ b/middleware/commands.go @@ -0,0 +1,120 @@ +package middleware + +import ( + "errors" + "runtime" + "unicode" + + "github.com/flynn/go-shlex" +) + +var runtimeGoos = runtime.GOOS + +// SplitCommandAndArgs takes a command string and parses it +// shell-style into the command and its separate arguments. +func SplitCommandAndArgs(command string) (cmd string, args []string, err error) { + var parts []string + + if runtimeGoos == "windows" { + parts = parseWindowsCommand(command) // parse it Windows-style + } else { + parts, err = parseUnixCommand(command) // parse it Unix-style + if err != nil { + err = errors.New("error parsing command: " + err.Error()) + return + } + } + + if len(parts) == 0 { + err = errors.New("no command contained in '" + command + "'") + return + } + + cmd = parts[0] + if len(parts) > 1 { + args = parts[1:] + } + + return +} + +// parseUnixCommand parses a unix style command line and returns the +// command and its arguments or an error +func parseUnixCommand(cmd string) ([]string, error) { + return shlex.Split(cmd) +} + +// parseWindowsCommand parses windows command lines and +// returns the command and the arguments as an array. It +// should be able to parse commonly used command lines. +// Only basic syntax is supported: +// - spaces in double quotes are not token delimiters +// - double quotes are escaped by either backspace or another double quote +// - except for the above case backspaces are path separators (not special) +// +// Many sources point out that escaping quotes using backslash can be unsafe. +// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 ) +// +// This function has to be used on Windows instead +// of the shlex package because this function treats backslash +// characters properly. +func parseWindowsCommand(cmd string) []string { + const backslash = '\\' + const quote = '"' + + var parts []string + var part string + var inQuotes bool + var lastRune rune + + for i, ch := range cmd { + + if i != 0 { + lastRune = rune(cmd[i-1]) + } + + if ch == backslash { + // put it in the part - for now we don't know if it's an + // escaping char or path separator + part += string(ch) + continue + } + + if ch == quote { + if lastRune == backslash { + // remove the backslash from the part and add the escaped quote instead + part = part[:len(part)-1] + part += string(ch) + continue + } + + if lastRune == quote { + // revert the last change of the inQuotes state + // it was an escaping quote + inQuotes = !inQuotes + part += string(ch) + continue + } + + // normal escaping quotes + inQuotes = !inQuotes + continue + + } + + if unicode.IsSpace(ch) && !inQuotes && len(part) > 0 { + parts = append(parts, part) + part = "" + continue + } + + part += string(ch) + } + + if len(part) > 0 { + parts = append(parts, part) + part = "" + } + + return parts +} diff --git a/middleware/commands_test.go b/middleware/commands_test.go new file mode 100644 index 000000000..3001e65a5 --- /dev/null +++ b/middleware/commands_test.go @@ -0,0 +1,291 @@ +package middleware + +import ( + "fmt" + "runtime" + "strings" + "testing" +) + +func TestParseUnixCommand(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + // 0 - emtpy command + { + input: ``, + expected: []string{}, + }, + // 1 - command without arguments + { + input: `command`, + expected: []string{`command`}, + }, + // 2 - command with single argument + { + input: `command arg1`, + expected: []string{`command`, `arg1`}, + }, + // 3 - command with multiple arguments + { + input: `command arg1 arg2`, + expected: []string{`command`, `arg1`, `arg2`}, + }, + // 4 - command with single argument with space character - in quotes + { + input: `command "arg1 arg1"`, + expected: []string{`command`, `arg1 arg1`}, + }, + // 5 - command with multiple spaces and tab character + { + input: "command arg1 arg2\targ3", + expected: []string{`command`, `arg1`, `arg2`, `arg3`}, + }, + // 6 - command with single argument with space character - escaped with backspace + { + input: `command arg1\ arg2`, + expected: []string{`command`, `arg1 arg2`}, + }, + // 7 - single quotes should escape special chars + { + input: `command 'arg1\ arg2'`, + expected: []string{`command`, `arg1\ arg2`}, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + actual, _ := parseUnixCommand(test.input) + if len(actual) != len(test.expected) { + t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual) + continue + } + for j := 0; j < len(actual); j++ { + if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart { + t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j) + } + } + } +} + +func TestParseWindowsCommand(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + { // 0 - empty command - do not fail + input: ``, + expected: []string{}, + }, + { // 1 - cmd without args + input: `cmd`, + expected: []string{`cmd`}, + }, + { // 2 - multiple args + input: `cmd arg1 arg2`, + expected: []string{`cmd`, `arg1`, `arg2`}, + }, + { // 3 - multiple args with space + input: `cmd "combined arg" arg2`, + expected: []string{`cmd`, `combined arg`, `arg2`}, + }, + { // 4 - path without spaces + input: `mkdir C:\Windows\foo\bar`, + expected: []string{`mkdir`, `C:\Windows\foo\bar`}, + }, + { // 5 - command with space in quotes + input: `"command here"`, + expected: []string{`command here`}, + }, + { // 6 - argument with escaped quotes (two quotes) + input: `cmd ""arg""`, + expected: []string{`cmd`, `"arg"`}, + }, + { // 7 - argument with escaped quotes (backslash) + input: `cmd \"arg\"`, + expected: []string{`cmd`, `"arg"`}, + }, + { // 8 - two quotes (escaped) inside an inQuote element + input: `cmd "a ""quoted value"`, + expected: []string{`cmd`, `a "quoted value`}, + }, + // TODO - see how many quotes are dislayed if we use "", """, """"""" + { // 9 - two quotes outside an inQuote element + input: `cmd a ""quoted value`, + expected: []string{`cmd`, `a`, `"quoted`, `value`}, + }, + { // 10 - path with space in quotes + input: `mkdir "C:\directory name\foobar"`, + expected: []string{`mkdir`, `C:\directory name\foobar`}, + }, + { // 11 - space without quotes + input: `mkdir C:\ space`, + expected: []string{`mkdir`, `C:\`, `space`}, + }, + { // 12 - space in quotes + input: `mkdir "C:\ space"`, + expected: []string{`mkdir`, `C:\ space`}, + }, + { // 13 - UNC + input: `mkdir \\?\C:\Users`, + expected: []string{`mkdir`, `\\?\C:\Users`}, + }, + { // 14 - UNC with space + input: `mkdir "\\?\C:\Program Files"`, + expected: []string{`mkdir`, `\\?\C:\Program Files`}, + }, + + { // 15 - unclosed quotes - treat as if the path ends with quote + input: `mkdir "c:\Program files`, + expected: []string{`mkdir`, `c:\Program files`}, + }, + { // 16 - quotes used inside the argument + input: `mkdir "c:\P"rogra"m f"iles`, + expected: []string{`mkdir`, `c:\Program files`}, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + + actual := parseWindowsCommand(test.input) + if len(actual) != len(test.expected) { + t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual) + continue + } + for j := 0; j < len(actual); j++ { + if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart { + t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j) + } + } + } +} + +func TestSplitCommandAndArgs(t *testing.T) { + + // force linux parsing. It's more robust and covers error cases + runtimeGoos = "linux" + defer func() { + runtimeGoos = runtime.GOOS + }() + + var parseErrorContent = "error parsing command:" + var noCommandErrContent = "no command contained in" + + tests := []struct { + input string + expectedCommand string + expectedArgs []string + expectedErrContent string + }{ + // 0 - emtpy command + { + input: ``, + expectedCommand: ``, + expectedArgs: nil, + expectedErrContent: noCommandErrContent, + }, + // 1 - command without arguments + { + input: `command`, + expectedCommand: `command`, + expectedArgs: nil, + expectedErrContent: ``, + }, + // 2 - command with single argument + { + input: `command arg1`, + expectedCommand: `command`, + expectedArgs: []string{`arg1`}, + expectedErrContent: ``, + }, + // 3 - command with multiple arguments + { + input: `command arg1 arg2`, + expectedCommand: `command`, + expectedArgs: []string{`arg1`, `arg2`}, + expectedErrContent: ``, + }, + // 4 - command with unclosed quotes + { + input: `command "arg1 arg2`, + expectedCommand: "", + expectedArgs: nil, + expectedErrContent: parseErrorContent, + }, + // 5 - command with unclosed quotes + { + input: `command 'arg1 arg2"`, + expectedCommand: "", + expectedArgs: nil, + expectedErrContent: parseErrorContent, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + actualCommand, actualArgs, actualErr := SplitCommandAndArgs(test.input) + + // test if error matches expectation + if test.expectedErrContent != "" { + if actualErr == nil { + t.Errorf(errorPrefix+"Expected error with content [%s], found no error."+errorSuffix, test.expectedErrContent) + } else if !strings.Contains(actualErr.Error(), test.expectedErrContent) { + t.Errorf(errorPrefix+"Expected error with content [%s], found [%v]."+errorSuffix, test.expectedErrContent, actualErr) + } + } else if actualErr != nil { + t.Errorf(errorPrefix+"Expected no error, found [%v]."+errorSuffix, actualErr) + } + + // test if command matches + if test.expectedCommand != actualCommand { + t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand) + } + + // test if arguments match + if len(test.expectedArgs) != len(actualArgs) { + t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs) + } else { + // test args only if the count matches. + for j, actualArg := range actualArgs { + expectedArg := test.expectedArgs[j] + if actualArg != expectedArg { + t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg) + } + } + } + } +} + +func ExampleSplitCommandAndArgs() { + var commandLine string + var command string + var args []string + + // just for the test - change GOOS and reset it at the end of the test + runtimeGoos = "windows" + defer func() { + runtimeGoos = runtime.GOOS + }() + + commandLine = `mkdir /P "C:\Program Files"` + command, args, _ = SplitCommandAndArgs(commandLine) + + fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ",")) + + // set GOOS to linux + runtimeGoos = "linux" + + commandLine = `mkdir -p /path/with\ space` + command, args, _ = SplitCommandAndArgs(commandLine) + + fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ",")) + + // Output: + // Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files] + // Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space] +} diff --git a/middleware/context.go b/middleware/context.go new file mode 100644 index 000000000..8868c1c03 --- /dev/null +++ b/middleware/context.go @@ -0,0 +1,135 @@ +package middleware + +import ( + "net" + "net/http" + "strings" + "time" + + "github.com/miekg/dns" +) + +// This file contains the context and functions available for +// use in the templates. + +// Context is the context with which Caddy templates are executed. +type Context struct { + Root http.FileSystem // TODO(miek): needed + Req *dns.Msg + W dns.ResponseWriter +} + +// Now returns the current timestamp in the specified format. +func (c Context) Now(format string) string { + return time.Now().Format(format) +} + +// NowDate returns the current date/time that can be used +// in other time functions. +func (c Context) NowDate() time.Time { + return time.Now() +} + +// Header gets the value of a header. +func (c Context) Header() *dns.RR_Header { + // TODO(miek) + return nil +} + +// IP gets the (remote) IP address of the client making the request. +func (c Context) IP() string { + ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String()) + if err != nil { + return c.W.RemoteAddr().String() + } + return ip +} + +// Post gets the (remote) Port of the client making the request. +func (c Context) Port() (string, error) { + _, port, err := net.SplitHostPort(c.W.RemoteAddr().String()) + if err != nil { + return "0", err + } + return port, nil +} + +// Proto gets the protocol used as the transport. This +// will be udp or tcp. +func (c Context) Proto() string { + if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok { + return "udp" + } + if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok { + return "tcp" + } + return "udp" +} + +// Family returns the family of the transport. +// 1 for IPv4 and 2 for IPv6. +func (c Context) Family() int { + var a net.IP + ip := c.W.RemoteAddr() + if i, ok := ip.(*net.UDPAddr); ok { + a = i.IP + } + if i, ok := ip.(*net.TCPAddr); ok { + a = i.IP + } + + if a.To4() != nil { + return 1 + } + return 2 +} + +// Type returns the type of the question as a string. +func (c Context) Type() string { + return dns.Type(c.Req.Question[0].Qtype).String() +} + +// QType returns the type of the question as a uint16. +func (c Context) QType() uint16 { + return c.Req.Question[0].Qtype +} + +// Name returns the name of the question in the request. Note +// this name will always have a closing dot and will be lower cased. +func (c Context) Name() string { + return strings.ToLower(dns.Name(c.Req.Question[0].Name).String()) +} + +// QName returns the name of the question in the request. +func (c Context) QName() string { + return dns.Name(c.Req.Question[0].Name).String() +} + +// Class returns the class of the question in the request. +func (c Context) Class() string { + return dns.Class(c.Req.Question[0].Qclass).String() +} + +// QClass returns the class of the question in the request. +func (c Context) QClass() uint16 { + return c.Req.Question[0].Qclass +} + +// More convience types for extracting stuff from a message? +// Header? + +// ErrorMessage returns an error message suitable for sending +// back to the client. +func (c Context) ErrorMessage(rcode int) *dns.Msg { + m := new(dns.Msg) + m.SetRcode(c.Req, rcode) + return m +} + +// AnswerMessage returns an error message suitable for sending +// back to the client. +func (c Context) AnswerMessage() *dns.Msg { + m := new(dns.Msg) + m.SetReply(c.Req) + return m +} diff --git a/middleware/context_test.go b/middleware/context_test.go new file mode 100644 index 000000000..689c47c13 --- /dev/null +++ b/middleware/context_test.go @@ -0,0 +1,613 @@ +package middleware + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestInclude(t *testing.T) { + context := getContextOrFail(t) + + inputFilename := "test_file" + absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) + defer func() { + err := os.Remove(absInFilePath) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("Failed to clean test file!") + } + }() + + tests := []struct { + fileContent string + expectedContent string + shouldErr bool + expectedErrorContent string + }{ + // Test 0 - all good + { + fileContent: `str1 {{ .Root }} str2`, + expectedContent: fmt.Sprintf("str1 %s str2", context.Root), + shouldErr: false, + expectedErrorContent: "", + }, + // Test 1 - failure on template.Parse + { + fileContent: `str1 {{ .Root } str2`, + expectedContent: "", + shouldErr: true, + expectedErrorContent: `unexpected "}" in operand`, + }, + // Test 3 - failure on template.Execute + { + fileContent: `str1 {{ .InvalidField }} str2`, + expectedContent: "", + shouldErr: true, + expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`, + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // WriteFile truncates the contentt + err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) + if err != nil { + t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) + } + + content, err := context.Include(inputFilename) + if err != nil { + if !test.shouldErr { + t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error()) + } + if !strings.Contains(err.Error(), test.expectedErrorContent) { + t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error()) + } + } + + if err == nil && test.shouldErr { + t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename) + } + + if content != test.expectedContent { + t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) + } + } +} + +func TestIncludeNotExisting(t *testing.T) { + context := getContextOrFail(t) + + _, err := context.Include("not_existing") + if err == nil { + t.Errorf("Expected error but found nil!") + } +} + +func TestMarkdown(t *testing.T) { + context := getContextOrFail(t) + + inputFilename := "test_file" + absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) + defer func() { + err := os.Remove(absInFilePath) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("Failed to clean test file!") + } + }() + + tests := []struct { + fileContent string + expectedContent string + }{ + // Test 0 - test parsing of markdown + { + fileContent: "* str1\n* str2\n", + expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n", + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // WriteFile truncates the contentt + err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) + if err != nil { + t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) + } + + content, _ := context.Markdown(inputFilename) + if content != test.expectedContent { + t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) + } + } +} + +func TestCookie(t *testing.T) { + + tests := []struct { + cookie *http.Cookie + cookieName string + expectedValue string + }{ + // Test 0 - happy path + { + cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, + cookieName: "cookieName", + expectedValue: "cookieValue", + }, + // Test 1 - try to get a non-existing cookie + { + cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, + cookieName: "notExisting", + expectedValue: "", + }, + // Test 2 - partial name match + { + cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"}, + cookieName: "cook", + expectedValue: "", + }, + // Test 3 - cookie with optional fields + { + cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120}, + cookieName: "cookie", + expectedValue: "cookieValue", + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // reinitialize the context for each test + context := getContextOrFail(t) + + context.Req.AddCookie(test.cookie) + + actualCookieVal := context.Cookie(test.cookieName) + + if actualCookieVal != test.expectedValue { + t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName) + } + } +} + +func TestCookieMultipleCookies(t *testing.T) { + context := getContextOrFail(t) + + cookieNameBase, cookieValueBase := "cookieName", "cookieValue" + + // make sure that there's no state and multiple requests for different cookies return the correct result + for i := 0; i < 10; i++ { + context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)}) + } + + for i := 0; i < 10; i++ { + expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i) + actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i)) + if actualCookieVal != expectedCookieVal { + t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal) + } + } +} + +func TestHeader(t *testing.T) { + context := getContextOrFail(t) + + headerKey, headerVal := "Header1", "HeaderVal1" + context.Req.Header.Add(headerKey, headerVal) + + actualHeaderVal := context.Header(headerKey) + if actualHeaderVal != headerVal { + t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) + } + + missingHeaderVal := context.Header("not-existing") + if missingHeaderVal != "" { + t.Errorf("Expected empty header value, found %s", missingHeaderVal) + } +} + +func TestIP(t *testing.T) { + context := getContextOrFail(t) + + tests := []struct { + inputRemoteAddr string + expectedIP string + }{ + // Test 0 - ipv4 with port + {"1.1.1.1:1111", "1.1.1.1"}, + // Test 1 - ipv4 without port + {"1.1.1.1", "1.1.1.1"}, + // Test 2 - ipv6 with port + {"[::1]:11", "::1"}, + // Test 3 - ipv6 without port and brackets + {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, + // Test 4 - ipv6 with zone and port + {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + context.Req.RemoteAddr = test.inputRemoteAddr + actualIP := context.IP() + + if actualIP != test.expectedIP { + t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) + } + } +} + +func TestURL(t *testing.T) { + context := getContextOrFail(t) + + inputURL := "http://localhost" + context.Req.RequestURI = inputURL + + if inputURL != context.URI() { + t.Errorf("Expected url %s, found %s", inputURL, context.URI()) + } +} + +func TestHost(t *testing.T) { + tests := []struct { + input string + expectedHost string + shouldErr bool + }{ + { + input: "localhost:123", + expectedHost: "localhost", + shouldErr: false, + }, + { + input: "localhost", + expectedHost: "localhost", + shouldErr: false, + }, + { + input: "[::]", + expectedHost: "", + shouldErr: true, + }, + } + + for _, test := range tests { + testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) + } +} + +func TestPort(t *testing.T) { + tests := []struct { + input string + expectedPort string + shouldErr bool + }{ + { + input: "localhost:123", + expectedPort: "123", + shouldErr: false, + }, + { + input: "localhost", + expectedPort: "80", // assuming 80 is the default port + shouldErr: false, + }, + { + input: ":8080", + expectedPort: "8080", + shouldErr: false, + }, + { + input: "[::]", + expectedPort: "", + shouldErr: true, + }, + } + + for _, test := range tests { + testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) + } +} + +func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { + context := getContextOrFail(t) + + context.Req.Host = input + var actualResult, testedObject string + var err error + + if isTestingHost { + actualResult, err = context.Host() + testedObject = "host" + } else { + actualResult, err = context.Port() + testedObject = "port" + } + + if shouldErr && err == nil { + t.Errorf("Expected error, found nil!") + return + } + + if !shouldErr && err != nil { + t.Errorf("Expected no error, found %s", err) + return + } + + if actualResult != expectedResult { + t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) + } +} + +func TestMethod(t *testing.T) { + context := getContextOrFail(t) + + method := "POST" + context.Req.Method = method + + if method != context.Method() { + t.Errorf("Expected method %s, found %s", method, context.Method()) + } + +} + +func TestPathMatches(t *testing.T) { + context := getContextOrFail(t) + + tests := []struct { + urlStr string + pattern string + shouldMatch bool + }{ + // Test 0 + { + urlStr: "http://localhost/", + pattern: "", + shouldMatch: true, + }, + // Test 1 + { + urlStr: "http://localhost", + pattern: "", + shouldMatch: true, + }, + // Test 1 + { + urlStr: "http://localhost/", + pattern: "/", + shouldMatch: true, + }, + // Test 3 + { + urlStr: "http://localhost/?param=val", + pattern: "/", + shouldMatch: true, + }, + // Test 4 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "/dir2", + shouldMatch: false, + }, + // Test 5 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "/dir1", + shouldMatch: true, + }, + // Test 6 + { + urlStr: "http://localhost:444/dir1/dir2", + pattern: "/dir1", + shouldMatch: true, + }, + // Test 7 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "*/dir2", + shouldMatch: false, + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + var err error + context.Req.URL, err = url.Parse(test.urlStr) + if err != nil { + t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) + } + + matches := context.PathMatches(test.pattern) + if matches != test.shouldMatch { + t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) + } + } +} + +func TestTruncate(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + inputString string + inputLength int + expected string + }{ + // Test 0 - small length + { + inputString: "string", + inputLength: 1, + expected: "s", + }, + // Test 1 - exact length + { + inputString: "string", + inputLength: 6, + expected: "string", + }, + // Test 2 - bigger length + { + inputString: "string", + inputLength: 10, + expected: "string", + }, + // Test 3 - zero length + { + inputString: "string", + inputLength: 0, + expected: "", + }, + // Test 4 - negative, smaller length + { + inputString: "string", + inputLength: -5, + expected: "tring", + }, + // Test 5 - negative, exact length + { + inputString: "string", + inputLength: -6, + expected: "string", + }, + // Test 6 - negative, bigger length + { + inputString: "string", + inputLength: -7, + expected: "string", + }, + } + + for i, test := range tests { + actual := context.Truncate(test.inputString, test.inputLength) + if actual != test.expected { + t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength) + } + } +} + +func TestStripHTML(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + input string + expected string + }{ + // Test 0 - no tags + { + input: `h1`, + expected: `h1`, + }, + // Test 1 - happy path + { + input: `<h1>h1</h1>`, + expected: `h1`, + }, + // Test 2 - tag in quotes + { + input: `<h1">">h1</h1>`, + expected: `h1`, + }, + // Test 3 - multiple tags + { + input: `<h1><b>h1</b></h1>`, + expected: `h1`, + }, + // Test 4 - tags not closed + { + input: `<h1`, + expected: `<h1`, + }, + // Test 5 - false start + { + input: `<h1<b>hi`, + expected: `<h1hi`, + }, + } + + for i, test := range tests { + actual := context.StripHTML(test.input) + if actual != test.expected { + t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input) + } + } +} + +func TestStripExt(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + input string + expected string + }{ + // Test 0 - empty input + { + input: "", + expected: "", + }, + // Test 1 - relative file with ext + { + input: "file.ext", + expected: "file", + }, + // Test 2 - relative file without ext + { + input: "file", + expected: "file", + }, + // Test 3 - absolute file without ext + { + input: "/file", + expected: "/file", + }, + // Test 4 - absolute file with ext + { + input: "/file.ext", + expected: "/file", + }, + // Test 5 - with ext but ends with / + { + input: "/dir.ext/", + expected: "/dir.ext/", + }, + // Test 6 - file with ext under dir with ext + { + input: "/dir.ext/file.ext", + expected: "/dir.ext/file", + }, + } + + for i, test := range tests { + actual := context.StripExt(test.input) + if actual != test.expected { + t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input) + } + } +} + +func initTestContext() (Context, error) { + body := bytes.NewBufferString("request body") + request, err := http.NewRequest("GET", "https://localhost", body) + if err != nil { + return Context{}, err + } + + return Context{Root: http.Dir(os.TempDir()), Req: request}, nil +} + +func getContextOrFail(t *testing.T) Context { + context, err := initTestContext() + if err != nil { + t.Fatalf("Failed to prepare test context") + } + return context +} + +func getTestPrefix(testN int) string { + return fmt.Sprintf("Test [%d]: ", testN) +} diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go new file mode 100644 index 000000000..bf5bc7aae --- /dev/null +++ b/middleware/errors/errors.go @@ -0,0 +1,100 @@ +// Package errors implements an HTTP error handling middleware. +package errors + +import ( + "fmt" + "log" + "runtime" + "strings" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// ErrorHandler handles DNS errors (and errors from other middleware). +type ErrorHandler struct { + Next middleware.Handler + LogFile string + Log *log.Logger + LogRoller *middleware.LogRoller + Debug bool // if true, errors are written out to client rather than to a log +} + +func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + defer h.recovery(w, r) + + rcode, err := h.Next.ServeDNS(w, r) + + if err != nil { + errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err) + + if h.Debug { + // Write error to response as a txt message instead of to log + answer := debugMsg(rcode, r) + txt, _ := dns.NewRR(". IN 0 TXT " + errMsg) + answer.Answer = append(answer.Answer, txt) + w.WriteMsg(answer) + return 0, err + } + h.Log.Println(errMsg) + } + + return rcode, err +} + +func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) { + rec := recover() + if rec == nil { + return + } + + // Obtain source of panic + // From: https://gist.github.com/swdunlop/9629168 + var name, file string // function name, file name + var line int + var pc [16]uintptr + n := runtime.Callers(3, pc[:]) + for _, pc := range pc[:n] { + fn := runtime.FuncForPC(pc) + if fn == nil { + continue + } + file, line = fn.FileLine(pc) + name = fn.Name() + if !strings.HasPrefix(name, "runtime.") { + break + } + } + + // Trim file path + delim := "/coredns/" + pkgPathPos := strings.Index(file, delim) + if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) { + file = file[pkgPathPos+len(delim):] + } + + panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec) + if h.Debug { + // Write error and stack trace to the response rather than to a log + var stackBuf [4096]byte + stack := stackBuf[:runtime.Stack(stackBuf[:], false)] + answer := debugMsg(dns.RcodeServerFailure, r) + // add stack buf in TXT, limited to 255 chars for now. + txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255])) + answer.Answer = append(answer.Answer, txt) + w.WriteMsg(answer) + } else { + // Currently we don't use the function name, since file:line is more conventional + h.Log.Printf(panicMsg) + } +} + +// debugMsg creates a debug message that gets send back to the client. +func debugMsg(rcode int, r *dns.Msg) *dns.Msg { + answer := new(dns.Msg) + answer.SetRcode(r, rcode) + return answer +} + +const timeFormat = "02/Jan/2006:15:04:05 -0700" diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go new file mode 100644 index 000000000..4434e835c --- /dev/null +++ b/middleware/errors/errors_test.go @@ -0,0 +1,168 @@ +package errors + +import ( + "bytes" + "errors" + "fmt" + "log" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/miekg/coredns/middleware" +) + +func TestErrors(t *testing.T) { + // create a temporary page + path := filepath.Join(os.TempDir(), "errors_test.html") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + defer os.Remove(path) + + const content = "This is a error page" + _, err = f.WriteString(content) + if err != nil { + t.Fatal(err) + } + f.Close() + + buf := bytes.Buffer{} + em := ErrorHandler{ + ErrorPages: map[int]string{ + http.StatusNotFound: path, + http.StatusForbidden: "not_exist_file", + }, + Log: log.New(&buf, "", 0), + } + _, notExistErr := os.Open("not_exist_file") + + testErr := errors.New("test error") + tests := []struct { + next middleware.Handler + expectedCode int + expectedBody string + expectedLog string + expectedErr error + }{ + { + next: genErrorHandler(http.StatusOK, nil, "normal"), + expectedCode: http.StatusOK, + expectedBody: "normal", + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusMovedPermanently, testErr, ""), + expectedCode: http.StatusMovedPermanently, + expectedBody: "", + expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr), + expectedErr: testErr, + }, + { + next: genErrorHandler(http.StatusBadRequest, nil, ""), + expectedCode: 0, + expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest, + http.StatusText(http.StatusBadRequest)), + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusNotFound, nil, ""), + expectedCode: 0, + expectedBody: content, + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusForbidden, nil, ""), + expectedCode: 0, + expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden, + http.StatusText(http.StatusForbidden)), + expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n", + http.StatusForbidden, notExistErr), + expectedErr: nil, + }, + } + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + for i, test := range tests { + em.Next = test.next + buf.Reset() + rec := httptest.NewRecorder() + code, err := em.ServeHTTP(rec, req) + + if err != test.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", + i, test.expectedErr, err) + } + if code != test.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", + i, test.expectedCode, code) + } + if body := rec.Body.String(); body != test.expectedBody { + t.Errorf("Test %d: Expected body %q, but got %q", + i, test.expectedBody, body) + } + if log := buf.String(); !strings.Contains(log, test.expectedLog) { + t.Errorf("Test %d: Expected log %q, but got %q", + i, test.expectedLog, log) + } + } +} + +func TestVisibleErrorWithPanic(t *testing.T) { + const panicMsg = "I'm a panic" + eh := ErrorHandler{ + ErrorPages: make(map[int]string), + Debug: true, + Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + panic(panicMsg) + }), + } + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + rec := httptest.NewRecorder() + + code, err := eh.ServeHTTP(rec, req) + + if code != 0 { + t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code) + } + if err != nil { + t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err) + } + + body := rec.Body.String() + + if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") { + t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body) + } + if !strings.Contains(body, panicMsg) { + t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body) + } + if len(body) < 500 { + t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body)) + } +} + +func genErrorHandler(status int, err error, body string) middleware.Handler { + return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + if len(body) > 0 { + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + fmt.Fprint(w, body) + } + return status, err + }) +} diff --git a/middleware/etcd/TODO b/middleware/etcd/TODO new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/middleware/etcd/TODO diff --git a/middleware/exchange.go b/middleware/exchange.go new file mode 100644 index 000000000..837fa3cdc --- /dev/null +++ b/middleware/exchange.go @@ -0,0 +1,10 @@ +package middleware + +import "github.com/miekg/dns" + +// Exchang sends message m to the server. +// TODO(miek): optionally it can do retries of other silly stuff. +func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) { + r, _, err := c.Exchange(m, server) + return r, err +} diff --git a/middleware/file/file.go b/middleware/file/file.go new file mode 100644 index 000000000..5bc5a3a3a --- /dev/null +++ b/middleware/file/file.go @@ -0,0 +1,89 @@ +package file + +// TODO(miek): the zone's implementation is basically non-existent +// we return a list and when searching for an answer we iterate +// over the list. This must be moved to a tree-like structure and +// have some fluff for DNSSEC (and be memory efficient). + +import ( + "strings" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +type ( + File struct { + Next middleware.Handler + Zones Zones + // Maybe a list of all zones as well, as a []string? + } + + Zone []dns.RR + Zones struct { + Z map[string]Zone // utterly braindead impl. TODO(miek): fix + Names []string + } +) + +func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + context := middleware.Context{W: w, Req: r} + qname := context.Name() + zone := middleware.Zones(f.Zones.Names).Matches(qname) + if zone == "" { + return f.Next.ServeDNS(w, r) + } + + names, nodata := f.Zones.Z[zone].lookup(qname, context.QType()) + var answer *dns.Msg + switch { + case nodata: + answer = context.AnswerMessage() + answer.Ns = names + case len(names) == 0: + answer = context.AnswerMessage() + answer.Ns = names + answer.Rcode = dns.RcodeNameError + case len(names) > 0: + answer = context.AnswerMessage() + answer.Answer = names + default: + answer = context.ErrorMessage(dns.RcodeServerFailure) + } + // Check return size, etc. TODO(miek) + w.WriteMsg(answer) + return 0, nil +} + +// Lookup will try to find qname and qtype in z. It returns the +// records found *or* a boolean saying NODATA. If the answer +// is NODATA then the RR returned is the SOA record. +// +// TODO(miek): EXTREMELY STUPID IMPLEMENTATION. +// Doesn't do much, no delegation, no cname, nothing really, etc. +// TODO(miek): even NODATA looks broken +func (z Zone) lookup(qname string, qtype uint16) ([]dns.RR, bool) { + var ( + nodata bool + rep []dns.RR + soa dns.RR + ) + + for _, rr := range z { + if rr.Header().Rrtype == dns.TypeSOA { + soa = rr + } + // Match function in Go DNS? + if strings.ToLower(rr.Header().Name) == qname { + if rr.Header().Rrtype == qtype { + rep = append(rep, rr) + nodata = false + } + + } + } + if nodata { + return []dns.RR{soa}, true + } + return rep, false +} diff --git a/middleware/file/file_test.go b/middleware/file/file_test.go new file mode 100644 index 000000000..54584b5cc --- /dev/null +++ b/middleware/file/file_test.go @@ -0,0 +1,325 @@ +package file + +import ( + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +var testDir = filepath.Join(os.TempDir(), "caddy_testdir") +var ErrCustom = errors.New("Custom Error") + +// testFiles is a map with relative paths to test files as keys and file content as values. +// The map represents the following structure: +// - $TEMP/caddy_testdir/ +// '-- file1.html +// '-- dirwithindex/ +// '---- index.html +// '-- dir/ +// '---- file2.html +// '---- hidden.html +var testFiles = map[string]string{ + "file1.html": "<h1>file1.html</h1>", + filepath.Join("dirwithindex", "index.html"): "<h1>dirwithindex/index.html</h1>", + filepath.Join("dir", "file2.html"): "<h1>dir/file2.html</h1>", + filepath.Join("dir", "hidden.html"): "<h1>dir/hidden.html</h1>", +} + +// TestServeHTTP covers positive scenarios when serving files. +func TestServeHTTP(t *testing.T) { + + beforeServeHTTPTest(t) + defer afterServeHTTPTest(t) + + fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"}) + + movedPermanently := "Moved Permanently" + + tests := []struct { + url string + + expectedStatus int + expectedBodyContent string + }{ + // Test 0 - access without any path + { + url: "https://foo", + expectedStatus: http.StatusNotFound, + }, + // Test 1 - access root (without index.html) + { + url: "https://foo/", + expectedStatus: http.StatusNotFound, + }, + // Test 2 - access existing file + { + url: "https://foo/file1.html", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles["file1.html"], + }, + // Test 3 - access folder with index file with trailing slash + { + url: "https://foo/dirwithindex/", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")], + }, + // Test 4 - access folder with index file without trailing slash + { + url: "https://foo/dirwithindex", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 5 - access folder without index file + { + url: "https://foo/dir/", + expectedStatus: http.StatusNotFound, + }, + // Test 6 - access folder without trailing slash + { + url: "https://foo/dir", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 6 - access file with trailing slash + { + url: "https://foo/file1.html/", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 7 - access not existing path + { + url: "https://foo/not_existing", + expectedStatus: http.StatusNotFound, + }, + // Test 8 - access a file, marked as hidden + { + url: "https://foo/dir/hidden.html", + expectedStatus: http.StatusNotFound, + }, + // Test 9 - access a index file directly + { + url: "https://foo/dirwithindex/index.html", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")], + }, + // Test 10 - send a request with query params + { + url: "https://foo/dir?param1=val", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + } + + for i, test := range tests { + responseRecorder := httptest.NewRecorder() + request, err := http.NewRequest("GET", test.url, strings.NewReader("")) + status, err := fileserver.ServeHTTP(responseRecorder, request) + + // check if error matches expectations + if err != nil { + t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err) + } + + // check status code + if test.expectedStatus != status { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check body content + if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) { + t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String()) + } + } + +} + +// beforeServeHTTPTest creates a test directory with the structure, defined in the variable testFiles +func beforeServeHTTPTest(t *testing.T) { + // make the root test dir + err := os.Mkdir(testDir, os.ModePerm) + if err != nil { + if !os.IsExist(err) { + t.Fatalf("Failed to create test dir. Error was: %v", err) + return + } + } + + for relFile, fileContent := range testFiles { + absFile := filepath.Join(testDir, relFile) + + // make sure the parent directories exist + parentDir := filepath.Dir(absFile) + _, err = os.Stat(parentDir) + if err != nil { + os.MkdirAll(parentDir, os.ModePerm) + } + + // now create the test files + f, err := os.Create(absFile) + if err != nil { + t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err) + return + } + + // and fill them with content + _, err = f.WriteString(fileContent) + if err != nil { + t.Fatalf("Failed to write to %s. Error was: %v", absFile, err) + return + } + f.Close() + } + +} + +// afterServeHTTPTest removes the test dir and all its content +func afterServeHTTPTest(t *testing.T) { + // cleans up everything under the test dir. No need to clean the individual files. + err := os.RemoveAll(testDir) + if err != nil { + t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err) + } +} + +// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err +type failingFS struct { + err error // the error to return when Open is called + fileImpl http.File // inject the file implementation +} + +// Open returns the assigned failingFile and error +func (f failingFS) Open(path string) (http.File, error) { + return f.fileImpl, f.err +} + +// failingFile implements http.File but returns a predefined error on every Stat() method call. +type failingFile struct { + http.File + err error +} + +// Stat returns nil FileInfo and the provided error on every call +func (ff failingFile) Stat() (os.FileInfo, error) { + return nil, ff.err +} + +// Close is noop and returns no error +func (ff failingFile) Close() error { + return nil +} + +// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors. +func TestServeHTTPFailingFS(t *testing.T) { + + tests := []struct { + fsErr error + expectedStatus int + expectedErr error + expectedHeaders map[string]string + }{ + { + fsErr: os.ErrNotExist, + expectedStatus: http.StatusNotFound, + expectedErr: nil, + }, + { + fsErr: os.ErrPermission, + expectedStatus: http.StatusForbidden, + expectedErr: os.ErrPermission, + }, + { + fsErr: ErrCustom, + expectedStatus: http.StatusServiceUnavailable, + expectedErr: ErrCustom, + expectedHeaders: map[string]string{"Retry-After": "5"}, + }, + } + + for i, test := range tests { + // initialize a file server with the failing FileSystem + fileserver := FileServer(failingFS{err: test.fsErr}, nil) + + // prepare the request and response + request, err := http.NewRequest("GET", "https://foo/", nil) + if err != nil { + t.Fatalf("Failed to build request. Error was: %v", err) + } + responseRecorder := httptest.NewRecorder() + + status, actualErr := fileserver.ServeHTTP(responseRecorder, request) + + // check the status + if status != test.expectedStatus { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check the error + if actualErr != test.expectedErr { + t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + } + + // check the headers - a special case for server under load + if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 { + for expectedKey, expectedVal := range test.expectedHeaders { + actualVal := responseRecorder.Header().Get(expectedKey) + if expectedVal != actualVal { + t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal) + } + } + } + } +} + +// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails. +func TestServeHTTPFailingStat(t *testing.T) { + + tests := []struct { + statErr error + expectedStatus int + expectedErr error + }{ + { + statErr: os.ErrNotExist, + expectedStatus: http.StatusNotFound, + expectedErr: nil, + }, + { + statErr: os.ErrPermission, + expectedStatus: http.StatusForbidden, + expectedErr: os.ErrPermission, + }, + { + statErr: ErrCustom, + expectedStatus: http.StatusInternalServerError, + expectedErr: ErrCustom, + }, + } + + for i, test := range tests { + // initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will + fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil) + + // prepare the request and response + request, err := http.NewRequest("GET", "https://foo/", nil) + if err != nil { + t.Fatalf("Failed to build request. Error was: %v", err) + } + responseRecorder := httptest.NewRecorder() + + status, actualErr := fileserver.ServeHTTP(responseRecorder, request) + + // check the status + if status != test.expectedStatus { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check the error + if actualErr != test.expectedErr { + t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + } + } +} diff --git a/middleware/host.go b/middleware/host.go new file mode 100644 index 000000000..17ecedb5f --- /dev/null +++ b/middleware/host.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net" + "strings" + + "github.com/miekg/dns" +) + +// Host represents a host from the Caddyfile, may contain port. +type Host string + +// Standard host will return the host portion of host, stripping +// of any port. The host will also be fully qualified and lowercased. +func (h Host) StandardHost() string { + // separate host and port + host, _, err := net.SplitHostPort(string(h)) + if err != nil { + host, _, _ = net.SplitHostPort(string(h) + ":") + } + return strings.ToLower(dns.Fqdn(host)) +} diff --git a/middleware/log/log.go b/middleware/log/log.go new file mode 100644 index 000000000..109add9f5 --- /dev/null +++ b/middleware/log/log.go @@ -0,0 +1,66 @@ +// Package log implements basic but useful request (access) logging middleware. +package log + +import ( + "log" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// Logger is a basic request logging middleware. +type Logger struct { + Next middleware.Handler + Rules []Rule + ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler +} + +func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + for _, rule := range l.Rules { + /* + if middleware.Path(r.URL.Path).Matches(rule.PathScope) { + responseRecorder := middleware.NewResponseRecorder(w) + status, err := l.Next.ServeHTTP(responseRecorder, r) + if status >= 400 { + // There was an error up the chain, but no response has been written yet. + // The error must be handled here so the log entry will record the response size. + if l.ErrorFunc != nil { + l.ErrorFunc(responseRecorder, r, status) + } else { + // Default failover error handler + responseRecorder.WriteHeader(status) + fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status)) + } + status = 0 + } + rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue) + rule.Log.Println(rep.Replace(rule.Format)) + return status, err + } + */ + rule = rule + } + return l.Next.ServeDNS(w, r) +} + +// Rule configures the logging middleware. +type Rule struct { + PathScope string + OutputFile string + Format string + Log *log.Logger + Roller *middleware.LogRoller +} + +const ( + // DefaultLogFilename is the default log filename. + DefaultLogFilename = "access.log" + // CommonLogFormat is the common log format. + CommonLogFormat = `{remote} ` + CommonLogEmptyValue + ` [{when}] "{type} {name} {proto}" {rcode} {size}` + // CommonLogEmptyValue is the common empty log value. + CommonLogEmptyValue = "-" + // CombinedLogFormat is the combined log format. + CombinedLogFormat = CommonLogFormat + ` "{>Referer}" "{>User-Agent}"` // Something here as well + // DefaultLogFormat is the default log format. + DefaultLogFormat = CommonLogFormat +) diff --git a/middleware/log/log_test.go b/middleware/log/log_test.go new file mode 100644 index 000000000..40560e4c0 --- /dev/null +++ b/middleware/log/log_test.go @@ -0,0 +1,48 @@ +package log + +import ( + "bytes" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type erroringMiddleware struct{} + +func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + return http.StatusNotFound, nil +} + +func TestLoggedStatus(t *testing.T) { + var f bytes.Buffer + var next erroringMiddleware + rule := Rule{ + PathScope: "/", + Format: DefaultLogFormat, + Log: log.New(&f, "", 0), + } + + logger := Logger{ + Rules: []Rule{rule}, + Next: next, + } + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + rec := httptest.NewRecorder() + + status, err := logger.ServeHTTP(rec, r) + if status != 0 { + t.Error("Expected status to be 0 - was", status) + } + + logged := f.String() + if !strings.Contains(logged, "404 13") { + t.Error("Expected 404 to be logged. Logged string -", logged) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 000000000..436ec86e9 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,105 @@ +// Package middleware provides some types and functions common among middleware. +package middleware + +import ( + "time" + + "github.com/miekg/dns" +) + +type ( + // Middleware is the middle layer which represents the traditional + // idea of middleware: it chains one Handler to the next by being + // passed the next Handler in the chain. + Middleware func(Handler) Handler + + // Handler is like dns.Handler except ServeDNS may return an rcode + // and/or error. + // + // If ServeDNS writes to the response body, it should return a status + // code of 0. This signals to other handlers above it that the response + // body is already written, and that they should not write to it also. + // + // If ServeDNS encounters an error, it should return the error value + // so it can be logged by designated error-handling middleware. + // + // If writing a response after calling another ServeDNS method, the + // returned rcode SHOULD be used when writing the response. + // + // If handling errors after calling another ServeDNS method, the + // returned error value SHOULD be logged or handled accordingly. + // + // Otherwise, return values should be propagated down the middleware + // chain by returning them unchanged. + Handler interface { + ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error) + } + + // HandlerFunc is a convenience type like dns.HandlerFunc, except + // ServeDNS returns an rcode and an error. See Handler + // documentation for more information. + HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error) +) + +// ServeDNS implements the Handler interface. +func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + return f(w, r) +} + +// IndexFile looks for a file in /root/fpath/indexFile for each string +// in indexFiles. If an index file is found, it returns the root-relative +// path to the file and true. If no index file is found, empty string +// and false is returned. fpath must end in a forward slash '/' +// otherwise no index files will be tried (directory paths must end +// in a forward slash according to HTTP). +// +// All paths passed into and returned from this function use '/' as the +// path separator, just like URLs. IndexFle handles path manipulation +// internally for systems that use different path separators. +/* +func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, bool) { + if fpath[len(fpath)-1] != '/' || root == nil { + return "", false + } + for _, indexFile := range indexFiles { + // func (http.FileSystem).Open wants all paths separated by "/", + // regardless of operating system convention, so use + // path.Join instead of filepath.Join + fp := path.Join(fpath, indexFile) + f, err := root.Open(fp) + if err == nil { + f.Close() + return fp, true + } + } + return "", false +} + +// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it +// as a Last-Modified header to the ResponseWriter. If the modTime is in the future +// the current time is used instead. +func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) { + if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) { + // the time does not appear to be valid. Don't put it in the response + return + } + + // RFC 2616 - Section 14.29 - Last-Modified: + // An origin server MUST NOT send a Last-Modified date which is later than the + // server's time of message origination. In such cases, where the resource's last + // modification would indicate some time in the future, the server MUST replace + // that date with the message origination date. + now := currentTime() + if modTime.After(now) { + modTime = now + } + + w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat)) +} +*/ + +// currentTime, as it is defined here, returns time.Now(). +// It's defined as a variable for mocking time in tests. +var currentTime = func() time.Time { + return time.Now() +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 000000000..62fa4e250 --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,108 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestIndexfile(t *testing.T) { + tests := []struct { + rootDir http.FileSystem + fpath string + indexFiles []string + shouldErr bool + expectedFilePath string //retun value + expectedBoolValue bool //return value + }{ + { + http.Dir("./templates/testdata"), + "/images/", + []string{"img.htm"}, + false, + "/images/img.htm", + true, + }, + } + for i, test := range tests { + actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles) + if actualBoolValue == true && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if actualBoolValue != true && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist") + } + if actualFilePath != test.expectedFilePath { + t.Fatalf("Test %d expected returned filepath to be %s, but got %s ", + i, test.expectedFilePath, actualFilePath) + + } + if actualBoolValue != test.expectedBoolValue { + t.Fatalf("Test %d expected returned bool value to be %v, but got %v ", + i, test.expectedBoolValue, actualBoolValue) + + } + } +} + +func TestSetLastModified(t *testing.T) { + nowTime := time.Now() + + // ovewrite the function to return reliable time + originalGetCurrentTimeFunc := currentTime + currentTime = func() time.Time { + return nowTime + } + defer func() { + currentTime = originalGetCurrentTimeFunc + }() + + pastTime := nowTime.Truncate(1 * time.Hour) + futureTime := nowTime.Add(1 * time.Hour) + + tests := []struct { + inputModTime time.Time + expectedIsHeaderSet bool + expectedLastModified string + }{ + { + inputModTime: pastTime, + expectedIsHeaderSet: true, + expectedLastModified: pastTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: nowTime, + expectedIsHeaderSet: true, + expectedLastModified: nowTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: futureTime, + expectedIsHeaderSet: true, + expectedLastModified: nowTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: time.Time{}, + expectedIsHeaderSet: false, + }, + } + + for i, test := range tests { + responseRecorder := httptest.NewRecorder() + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + SetLastModifiedHeader(responseRecorder, test.inputModTime) + actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified") + + if test.expectedIsHeaderSet && actualLastModifiedHeader == "" { + t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing") + } + + if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" { + t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader) + } + + if test.expectedLastModified != actualLastModifiedHeader { + t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader) + } + } +} diff --git a/middleware/path.go b/middleware/path.go new file mode 100644 index 000000000..1ffb64b76 --- /dev/null +++ b/middleware/path.go @@ -0,0 +1,18 @@ +package middleware + +import "strings" + + +// TODO(miek): matches for names. + +// Path represents a URI path, maybe with pattern characters. +type Path string + +// Matches checks to see if other matches p. +// +// Path matching will probably not always be a direct +// comparison; this method assures that paths can be +// easily and consistently matched. +func (p Path) Matches(other string) bool { + return strings.HasPrefix(string(p), other) +} diff --git a/middleware/prometheus/handler.go b/middleware/prometheus/handler.go new file mode 100644 index 000000000..eb82b8aff --- /dev/null +++ b/middleware/prometheus/handler.go @@ -0,0 +1,31 @@ +package metrics + +import ( + "strconv" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + context := middleware.Context{W: w, Req: r} + + qname := context.Name() + qtype := context.Type() + zone := middleware.Zones(m.ZoneNames).Matches(qname) + if zone == "" { + zone = "." + } + + // Record response to get status code and size of the reply. + rw := middleware.NewResponseRecorder(w) + status, err := m.Next.ServeDNS(rw, r) + + requestCount.WithLabelValues(zone, qtype).Inc() + requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second)) + responseSize.WithLabelValues(zone).Observe(float64(rw.Size())) + responseRcode.WithLabelValues(zone, strconv.Itoa(rw.Rcode())).Inc() + + return status, err +} diff --git a/middleware/prometheus/metrics.go b/middleware/prometheus/metrics.go new file mode 100644 index 000000000..4c989f640 --- /dev/null +++ b/middleware/prometheus/metrics.go @@ -0,0 +1,80 @@ +package metrics + +import ( + "fmt" + "net/http" + "sync" + + "github.com/miekg/coredns/middleware" + "github.com/prometheus/client_golang/prometheus" +) + +const namespace = "daddy" + +var ( + requestCount *prometheus.CounterVec + requestDuration *prometheus.HistogramVec + responseSize *prometheus.HistogramVec + responseRcode *prometheus.CounterVec +) + +const path = "/metrics" + +// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics +type Metrics struct { + Next middleware.Handler + Addr string // where to we listen + Once sync.Once + ZoneNames []string +} + +func (m *Metrics) Start() error { + m.Once.Do(func() { + define("") + + prometheus.MustRegister(requestCount) + prometheus.MustRegister(requestDuration) + prometheus.MustRegister(responseSize) + prometheus.MustRegister(responseRcode) + + http.Handle(path, prometheus.Handler()) + go func() { + fmt.Errorf("%s", http.ListenAndServe(m.Addr, nil)) + }() + }) + return nil +} + +func define(subsystem string) { + if subsystem == "" { + subsystem = "dns" + } + requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "request_count_total", + Help: "Counter of DNS requests made per zone and type.", + }, []string{"zone", "qtype"}) + + requestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "request_duration_seconds", + Help: "Histogram of the time (in seconds) each request took.", + }, []string{"zone"}) + + responseSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "response_size_bytes", + Help: "Size of the returns response in bytes.", + Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3}, + }, []string{"zone"}) + + responseRcode = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "rcode_code_count_total", + Help: "Counter of response status codes.", + }, []string{"zone", "rcode"}) +} diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go new file mode 100644 index 000000000..a2522bcb1 --- /dev/null +++ b/middleware/proxy/policy.go @@ -0,0 +1,101 @@ +package proxy + +import ( + "math/rand" + "sync/atomic" +) + +// HostPool is a collection of UpstreamHosts. +type HostPool []*UpstreamHost + +// Policy decides how a host will be selected from a pool. +type Policy interface { + Select(pool HostPool) *UpstreamHost +} + +func init() { + RegisterPolicy("random", func() Policy { return &Random{} }) + RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) +} + +// Random is a policy that selects up hosts from a pool at random. +type Random struct{} + +// Select selects an up host at random from the specified pool. +func (r *Random) Select(pool HostPool) *UpstreamHost { + // instead of just generating a random index + // this is done to prevent selecting a down host + var randHost *UpstreamHost + count := 0 + for _, host := range pool { + if host.Down() { + continue + } + count++ + if count == 1 { + randHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + randHost = host + } + } + } + return randHost +} + +// LeastConn is a policy that selects the host with the least connections. +type LeastConn struct{} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (r *LeastConn) Select(pool HostPool) *UpstreamHost { + var bestHost *UpstreamHost + count := 0 + leastConn := int64(1<<63 - 1) + for _, host := range pool { + if host.Down() { + continue + } + hostConns := host.Conns + if hostConns < leastConn { + bestHost = host + leastConn = hostConns + count = 1 + } else if hostConns == leastConn { + // randomly select host among hosts with least connections + count++ + if count == 1 { + bestHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + bestHost = host + } + } + } + } + return bestHost +} + +// RoundRobin is a policy that selects hosts based on round robin ordering. +type RoundRobin struct { + Robin uint32 +} + +// Select selects an up host from the pool using a round robin ordering scheme. +func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { + poolLen := uint32(len(pool)) + selection := atomic.AddUint32(&r.Robin, 1) % poolLen + host := pool[selection] + // if the currently selected host is down, just ffwd to up host + for i := uint32(1); host.Down() && i < poolLen; i++ { + host = pool[(selection+i)%poolLen] + } + if host.Down() { + return nil + } + return host +} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go new file mode 100644 index 000000000..8f4f1f792 --- /dev/null +++ b/middleware/proxy/policy_test.go @@ -0,0 +1,87 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +var workableServer *httptest.Server + +func TestMain(m *testing.M) { + workableServer = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // do nothing + })) + r := m.Run() + workableServer.Close() + os.Exit(r) +} + +type customPolicy struct{} + +func (r *customPolicy) Select(pool HostPool) *UpstreamHost { + return pool[0] +} + +func testPool() HostPool { + pool := []*UpstreamHost{ + { + Name: workableServer.URL, // this should resolve (healthcheck test) + }, + { + Name: "http://shouldnot.resolve", // this shouldn't + }, + { + Name: "http://C", + }, + } + return HostPool(pool) +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := &RoundRobin{} + h := rrPolicy.Select(pool) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + // mark host as down + pool[0].Unhealthy = true + h = rrPolicy.Select(pool) + if h != pool[1] { + t.Error("Expected third round robin host to be first host in the pool.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := &LeastConn{} + pool[0].Conns = 10 + pool[1].Conns = 10 + h := lcPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].Conns = 100 + h = lcPolicy.Select(pool) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} + +func TestCustomPolicy(t *testing.T) { + pool := testPool() + customPolicy := &customPolicy{} + h := customPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected custom policy host to be the first host.") + } +} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go new file mode 100644 index 000000000..169e41b61 --- /dev/null +++ b/middleware/proxy/proxy.go @@ -0,0 +1,120 @@ +// Package proxy is middleware that proxies requests. +package proxy + +import ( + "errors" + "net/http" + "sync/atomic" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +var errUnreachable = errors.New("unreachable backend") + +// Proxy represents a middleware instance that can proxy requests. +type Proxy struct { + Next middleware.Handler + Client Client + Upstreams []Upstream +} + +type Client struct { + UDP *dns.Client + TCP *dns.Client +} + +// Upstream manages a pool of proxy upstream hosts. Select should return a +// suitable upstream host, or nil if no such hosts are available. +type Upstream interface { + // The domain name this upstream host should be routed on. + From() string + // Selects an upstream host to be routed to. + Select() *UpstreamHost + // Checks if subpdomain is not an ignored. + IsAllowedPath(string) bool +} + +// UpstreamHostDownFunc can be used to customize how Down behaves. +type UpstreamHostDownFunc func(*UpstreamHost) bool + +// UpstreamHost represents a single proxy upstream +type UpstreamHost struct { + Conns int64 // must be first field to be 64-bit aligned on 32-bit systems + Name string // IP address (and port) of this upstream host + Fails int32 + FailTimeout time.Duration + Unhealthy bool + ExtraHeaders http.Header + CheckDown UpstreamHostDownFunc + WithoutPathPrefix string +} + +// Down checks whether the upstream host is down or not. +// Down will try to use uh.CheckDown first, and will fall +// back to some default criteria if necessary. +func (uh *UpstreamHost) Down() bool { + if uh.CheckDown == nil { + // Default settings + return uh.Unhealthy || uh.Fails > 0 + } + return uh.CheckDown(uh) +} + +// tryDuration is how long to try upstream hosts; failures result in +// immediate retries until this duration ends or we get a nil host. +var tryDuration = 60 * time.Second + +// ServeDNS satisfies the middleware.Handler interface. +func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + for _, upstream := range p.Upstreams { + // allowed bla bla bla TODO(miek): fix full proxy spec from caddy + start := time.Now() + + // Since Select() should give us "up" hosts, keep retrying + // hosts until timeout (or until we get a nil host). + for time.Now().Sub(start) < tryDuration { + host := upstream.Select() + if host == nil { + return dns.RcodeServerFailure, errUnreachable + } + // TODO(miek): PORT! + reverseproxy := ReverseProxy{Host: host.Name, Client: p.Client} + + atomic.AddInt64(&host.Conns, 1) + backendErr := reverseproxy.ServeDNS(w, r, nil) + atomic.AddInt64(&host.Conns, -1) + if backendErr == nil { + return 0, nil + } + timeout := host.FailTimeout + if timeout == 0 { + timeout = 10 * time.Second + } + atomic.AddInt32(&host.Fails, 1) + go func(host *UpstreamHost, timeout time.Duration) { + time.Sleep(timeout) + atomic.AddInt32(&host.Fails, -1) + }(host, timeout) + } + return dns.RcodeServerFailure, errUnreachable + } + return p.Next.ServeDNS(w, r) +} + +func Clients() Client { + udp := newClient("udp", defaultTimeout) + tcp := newClient("tcp", defaultTimeout) + return Client{UDP: udp, TCP: tcp} +} + +// newClient returns a new client for proxy requests. +func newClient(net string, timeout time.Duration) *dns.Client { + if timeout == 0 { + timeout = defaultTimeout + } + return &dns.Client{Net: net, ReadTimeout: timeout, WriteTimeout: timeout, SingleInflight: true} +} + +const defaultTimeout = 5 * time.Second diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go new file mode 100644 index 000000000..8066874d2 --- /dev/null +++ b/middleware/proxy/proxy_test.go @@ -0,0 +1,317 @@ +package proxy + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +func init() { + tryDuration = 50 * time.Millisecond // prevent tests from hanging +} + +func TestReverseProxy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Expected backend to receive request, but it didn't") + } +} + +func TestReverseProxyInsecureSkipVerify(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, true)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't") + } +} + +func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { + // No-op websocket backend simply allows the WS connection to be + // accepted then it will be immediately closed. Perfect for testing. + wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {})) + defer wsNop.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsNop.URL) + + // Create client request + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + r.Header = http.Header{ + "Connection": {"Upgrade"}, + "Upgrade": {"websocket"}, + "Origin": {wsNop.URL}, + "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="}, + "Sec-WebSocket-Version": {"13"}, + } + + // Capture the request + w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)} + + // Booya! Do the test. + p.ServeHTTP(w, r) + + // Make sure the backend accepted the WS connection. + // Mostly interested in the Upgrade and Connection response headers + // and the 101 status code. + expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n") + actual := w.fakeConn.writeBuf.Bytes() + if !bytes.Equal(actual, expected) { + t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual) + } +} + +func TestWebSocketReverseProxyFromWSClient(t *testing.T) { + // Echo server allows us to test that socket bytes are properly + // being proxied. + wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsEcho.URL) + + // This is a full end-end test, so the proxy handler + // has to be part of a server listening on a port. Our + // WS client will connect to this test server, not + // the echo client directly. + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "http://", "ws://", 1) + ws, err := websocket.Dial(url, "", echoProxy.URL) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + websocket.Message.Send(ws, trialMsg) + + // It should be echoed back to us + var actualMsg string + websocket.Message.Receive(ws, &actualMsg) + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func TestUnixSocketProxy(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + + trialMsg := "Is it working?" + + var proxySuccess bool + + // This is our fake "application" we want to proxy to + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Request was proxied when this is called + proxySuccess = true + + fmt.Fprint(w, trialMsg) + })) + + // Get absolute path for unix: socket + socketPath, err := filepath.Abs("./test_socket") + if err != nil { + t.Fatalf("Unable to get absolute path: %v", err) + } + + // Change httptest.Server listener to listen to unix: socket + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Unable to listen: %v", err) + } + ts.Listener = ln + + ts.Start() + defer ts.Close() + + url := strings.Replace(ts.URL, "http://", "unix:", 1) + p := newWebSocketTestProxy(url) + + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + res, err := http.Get(echoProxy.URL) + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + actualMsg := fmt.Sprintf("%s", greeting) + + if !proxySuccess { + t.Errorf("Expected request to be proxied, but it wasn't") + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func newFakeUpstream(name string, insecure bool) *fakeUpstream { + uri, _ := url.Parse(name) + u := &fakeUpstream{ + name: name, + host: &UpstreamHost{ + Name: name, + ReverseProxy: NewSingleHostReverseProxy(uri, ""), + }, + } + if insecure { + u.host.ReverseProxy.Transport = InsecureTransport + } + return u +} + +type fakeUpstream struct { + name string + host *UpstreamHost +} + +func (u *fakeUpstream) From() string { + return "/" +} + +func (u *fakeUpstream) Select() *UpstreamHost { + return u.host +} + +func (u *fakeUpstream) IsAllowedPath(requestPath string) bool { + return true +} + +// newWebSocketTestProxy returns a test proxy that will +// redirect to the specified backendAddr. The function +// also sets up the rules/environment for testing WebSocket +// proxy. +func newWebSocketTestProxy(backendAddr string) *Proxy { + return &Proxy{ + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}}, + } +} + +type fakeWsUpstream struct { + name string +} + +func (u *fakeWsUpstream) From() string { + return "/" +} + +func (u *fakeWsUpstream) Select() *UpstreamHost { + uri, _ := url.Parse(u.name) + return &UpstreamHost{ + Name: u.name, + ReverseProxy: NewSingleHostReverseProxy(uri, ""), + ExtraHeaders: http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}}, + } +} + +func (u *fakeWsUpstream) IsAllowedPath(requestPath string) bool { + return true +} + +// recorderHijacker is a ResponseRecorder that can +// be hijacked. +type recorderHijacker struct { + *httptest.ResponseRecorder + fakeConn *fakeConn +} + +func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rh.fakeConn, nil, nil +} + +type fakeConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +func (c *fakeConn) LocalAddr() net.Addr { return nil } +func (c *fakeConn) RemoteAddr() net.Addr { return nil } +func (c *fakeConn) SetDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } +func (c *fakeConn) Close() error { return nil } +func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } +func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go new file mode 100644 index 000000000..6d27da042 --- /dev/null +++ b/middleware/proxy/reverseproxy.go @@ -0,0 +1,36 @@ +// Package proxy is middleware that proxies requests. +package proxy + +import ( + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +type ReverseProxy struct { + Host string + Client Client +} + +func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error { + // TODO(miek): use extra! + var ( + reply *dns.Msg + err error + ) + context := middleware.Context{W: w, Req: r} + + // tls+tcp ? + if context.Proto() == "tcp" { + reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) + } else { + reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) + } + + if err != nil { + return err + } + reply.Compress = true + reply.Id = r.Id + w.WriteMsg(reply) + return nil +} diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go new file mode 100644 index 000000000..092e2351d --- /dev/null +++ b/middleware/proxy/upstream.go @@ -0,0 +1,235 @@ +package proxy + +import ( + "io" + "io/ioutil" + "net/http" + "path" + "strconv" + "time" + + "github.com/miekg/coredns/core/parse" + "github.com/miekg/coredns/middleware" +) + +var ( + supportedPolicies = make(map[string]func() Policy) +) + +type staticUpstream struct { + from string + // TODO(miek): allows use to added headers + proxyHeaders http.Header // TODO(miek): kill + Hosts HostPool + Policy Policy + + FailTimeout time.Duration + MaxFails int32 + HealthCheck struct { + Path string + Interval time.Duration + } + WithoutPathPrefix string + IgnoredSubPaths []string +} + +// NewStaticUpstreams parses the configuration input and sets up +// static upstreams for the proxy middleware. +func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { + var upstreams []Upstream + for c.Next() { + upstream := &staticUpstream{ + from: "", + proxyHeaders: make(http.Header), + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + + if !c.Args(&upstream.from) { + return upstreams, c.ArgErr() + } + to := c.RemainingArgs() + if len(to) == 0 { + return upstreams, c.ArgErr() + } + + for c.NextBlock() { + if err := parseBlock(&c, upstream); err != nil { + return upstreams, err + } + } + + upstream.Hosts = make([]*UpstreamHost, len(to)) + for i, host := range to { + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: upstream.FailTimeout, + Unhealthy: false, + ExtraHeaders: upstream.proxyHeaders, + CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= upstream.MaxFails && + upstream.MaxFails != 0 { + return true + } + return false + } + }(upstream), + WithoutPathPrefix: upstream.WithoutPathPrefix, + } + upstream.Hosts[i] = uh + } + + if upstream.HealthCheck.Path != "" { + go upstream.HealthCheckWorker(nil) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy +} + +func (u *staticUpstream) From() string { + return u.from +} + +func parseBlock(c *parse.Dispenser, u *staticUpstream) error { + switch c.Val() { + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + policyCreateFunc, ok := supportedPolicies[c.Val()] + if !ok { + return c.ArgErr() + } + u.Policy = policyCreateFunc() + case "fail_timeout": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.FailTimeout = dur + case "max_fails": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + u.MaxFails = int32(n) + case "health_check": + if !c.NextArg() { + return c.ArgErr() + } + u.HealthCheck.Path = c.Val() + u.HealthCheck.Interval = 30 * time.Second + if c.NextArg() { + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.HealthCheck.Interval = dur + } + case "proxy_header": + var header, value string + if !c.Args(&header, &value) { + return c.ArgErr() + } + u.proxyHeaders.Add(header, value) + case "websocket": + u.proxyHeaders.Add("Connection", "{>Connection}") + u.proxyHeaders.Add("Upgrade", "{>Upgrade}") + case "without": + if !c.NextArg() { + return c.ArgErr() + } + u.WithoutPathPrefix = c.Val() + case "except": + ignoredPaths := c.RemainingArgs() + if len(ignoredPaths) == 0 { + return c.ArgErr() + } + u.IgnoredSubPaths = ignoredPaths + default: + return c.Errf("unknown property '%s'", c.Val()) + } + return nil +} + +func (u *staticUpstream) healthCheck() { + for _, host := range u.Hosts { + hostURL := host.Name + u.HealthCheck.Path + if r, err := http.Get(hostURL); err == nil { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 + } else { + host.Unhealthy = true + } + } +} + +func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { + ticker := time.NewTicker(u.HealthCheck.Interval) + u.healthCheck() + for { + select { + case <-ticker.C: + u.healthCheck() + case <-stop: + // TODO: the library should provide a stop channel and global + // waitgroup to allow goroutines started by plugins a chance + // to clean themselves up. + } + } +} + +func (u *staticUpstream) Select() *UpstreamHost { + pool := u.Hosts + if len(pool) == 1 { + if pool[0].Down() { + return nil + } + return pool[0] + } + allDown := true + for _, host := range pool { + if !host.Down() { + allDown = false + break + } + } + if allDown { + return nil + } + + if u.Policy == nil { + return (&Random{}).Select(pool) + } + return u.Policy.Select(pool) +} + +func (u *staticUpstream) IsAllowedPath(requestPath string) bool { + for _, ignoredSubPath := range u.IgnoredSubPaths { + if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) { + return false + } + } + return true +} diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go new file mode 100644 index 000000000..5b2fdb1da --- /dev/null +++ b/middleware/proxy/upstream_test.go @@ -0,0 +1,83 @@ +package proxy + +import ( + "testing" + "time" +) + +func TestHealthCheck(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool(), + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.healthCheck() + if upstream.Hosts[0].Down() { + t.Error("Expected first host in testpool to not fail healthcheck.") + } + if !upstream.Hosts[1].Down() { + t.Error("Expected second host in testpool to fail healthcheck.") + } +} + +func TestSelect(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool()[:3], + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.Hosts[0].Unhealthy = true + upstream.Hosts[1].Unhealthy = true + upstream.Hosts[2].Unhealthy = true + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all host are down") + } + upstream.Hosts[2].Unhealthy = false + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } +} + +func TestRegisterPolicy(t *testing.T) { + name := "custom" + customPolicy := &customPolicy{} + RegisterPolicy(name, func() Policy { return customPolicy }) + if _, ok := supportedPolicies[name]; !ok { + t.Error("Expected supportedPolicies to have a custom policy.") + } + +} + +func TestAllowedPaths(t *testing.T) { + upstream := &staticUpstream{ + from: "/proxy", + IgnoredSubPaths: []string{"/download", "/static"}, + } + tests := []struct { + url string + expected bool + }{ + {"/proxy", true}, + {"/proxy/dl", true}, + {"/proxy/download", false}, + {"/proxy/download/static", false}, + {"/proxy/static", false}, + {"/proxy/static/download", false}, + {"/proxy/something/download", true}, + {"/proxy/something/static", true}, + {"/proxy//static", false}, + {"/proxy//static//download", false}, + {"/proxy//download", false}, + } + + for i, test := range tests { + isAllowed := upstream.IsAllowedPath(test.url) + if test.expected != isAllowed { + t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed) + } + } +} diff --git a/middleware/recorder.go b/middleware/recorder.go new file mode 100644 index 000000000..38a7e0e82 --- /dev/null +++ b/middleware/recorder.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "time" + + "github.com/miekg/dns" +) + +// ResponseRecorder is a type of ResponseWriter that captures +// the rcode code written to it and also the size of the message +// written in the response. A rcode code does not have +// to be written, however, in which case 0 must be assumed. +// It is best to have the constructor initialize this type +// with that default status code. +type ResponseRecorder struct { + dns.ResponseWriter + rcode int + size int + start time.Time +} + +// NewResponseRecorder makes and returns a new responseRecorder, +// which captures the DNS rcode from the ResponseWriter +// and also the length of the response message written through it. +func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder { + return &ResponseRecorder{ + ResponseWriter: w, + rcode: 0, + start: time.Now(), + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error { + r.rcode = res.Rcode + r.size = res.Len() + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *ResponseRecorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.size += n + } + return n, err +} + +// Size returns the size. +func (r *ResponseRecorder) Size() int { + return r.size +} + +// Rcode returns the rcode. +func (r *ResponseRecorder) Rcode() int { + return r.rcode +} + +// Start returns the start time of the ResponseRecorder. +func (r *ResponseRecorder) Start() time.Time { + return r.start +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *ResponseRecorder) Hijack() { + r.ResponseWriter.Hijack() + return +} diff --git a/middleware/recorder_test.go b/middleware/recorder_test.go new file mode 100644 index 000000000..a8c8a5d04 --- /dev/null +++ b/middleware/recorder_test.go @@ -0,0 +1,32 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewResponseRecorder(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + if !(recordRequest.ResponseWriter == w) { + t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n") + } + if recordRequest.status != http.StatusOK { + t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status) + } +} + +func TestWrite(t *testing.T) { + w := httptest.NewRecorder() + responseTestString := "test" + recordRequest := NewResponseRecorder(w) + buf := []byte(responseTestString) + recordRequest.Write(buf) + if recordRequest.size != len(buf) { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size) + } + if w.Body.String() != responseTestString { + t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String()) + } +} diff --git a/middleware/reflect/reflect.go b/middleware/reflect/reflect.go new file mode 100644 index 000000000..6d5847b81 --- /dev/null +++ b/middleware/reflect/reflect.go @@ -0,0 +1,84 @@ +// Reflect provides middleware that reflects back some client properties. +// This is the default middleware when Caddy is run without configuration. +// +// The left-most label must be `who`. +// When queried for type A (resp. AAAA), it sends back the IPv4 (resp. v6) address. +// In the additional section the port number and transport are shown. +// Basic use pattern: +// +// dig @localhost -p 1053 who.miek.nl A +// +// ;; ANSWER SECTION: +// who.miek.nl. 0 IN A 127.0.0.1 +// +// ;; ADDITIONAL SECTION: +// who.miek.nl. 0 IN TXT "Port: 56195 (udp)" +package reflect + +import ( + "errors" + "net" + "strings" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +type Reflect struct { + Next middleware.Handler +} + +func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + context := middleware.Context{Req: r, W: w} + + class := r.Question[0].Qclass + qname := r.Question[0].Name + i, ok := dns.NextLabel(qname, 0) + + if strings.ToLower(qname[:i]) != who || ok { + err := context.ErrorMessage(dns.RcodeFormatError) + w.WriteMsg(err) + return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError]) + } + + answer := new(dns.Msg) + answer.SetReply(r) + answer.Compress = true + answer.Authoritative = true + + ip := context.IP() + proto := context.Proto() + port, _ := context.Port() + family := context.Family() + var rr dns.RR + + switch family { + case 1: + rr = new(dns.A) + rr.(*dns.A).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeA, Class: class, Ttl: 0} + rr.(*dns.A).A = net.ParseIP(ip).To4() + case 2: + rr = new(dns.AAAA) + rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeAAAA, Class: class, Ttl: 0} + rr.(*dns.AAAA).AAAA = net.ParseIP(ip) + } + + t := new(dns.TXT) + t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0} + t.Txt = []string{"Port: " + port + " (" + proto + ")"} + + switch context.Type() { + case "TXT": + answer.Answer = append(answer.Answer, t) + answer.Extra = append(answer.Extra, rr) + default: + fallthrough + case "AAAA", "A": + answer.Answer = append(answer.Answer, rr) + answer.Extra = append(answer.Extra, t) + } + w.WriteMsg(answer) + return 0, nil +} + +const who = "who." diff --git a/middleware/reflect/reflect_test.go b/middleware/reflect/reflect_test.go new file mode 100644 index 000000000..477a3a573 --- /dev/null +++ b/middleware/reflect/reflect_test.go @@ -0,0 +1 @@ +package reflect diff --git a/middleware/replacer.go b/middleware/replacer.go new file mode 100644 index 000000000..133da74c5 --- /dev/null +++ b/middleware/replacer.go @@ -0,0 +1,98 @@ +package middleware + +import ( + "strconv" + "strings" + "time" + + "github.com/miekg/dns" +) + +// Replacer is a type which can replace placeholder +// substrings in a string with actual values from a +// http.Request and responseRecorder. Always use +// NewReplacer to get one of these. +type Replacer interface { + Replace(string) string + Set(key, value string) +} + +type replacer struct { + replacements map[string]string + emptyValue string +} + +// NewReplacer makes a new replacer based on r and rr. +// Do not create a new replacer until r and rr have all +// the needed values, because this function copies those +// values into the replacer. rr may be nil if it is not +// available. emptyValue should be the string that is used +// in place of empty string (can still be empty string). +func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { + context := Context{W: rr, Req: r} + rep := replacer{ + replacements: map[string]string{ + "{type}": context.Type(), + "{name}": context.Name(), + "{class}": context.Class(), + "{proto}": context.Proto(), + "{when}": func() string { + return time.Now().Format(timeFormat) + }(), + "{remote}": context.IP(), + "{port}": func() string { + p, _ := context.Port() + return p + }(), + }, + emptyValue: emptyValue, + } + if rr != nil { + rep.replacements["{rcode}"] = strconv.Itoa(rr.rcode) + rep.replacements["{size}"] = strconv.Itoa(rr.size) + rep.replacements["{latency}"] = time.Since(rr.start).String() + } + + return rep +} + +// Replace performs a replacement of values on s and returns +// the string with the replaced values. +func (r replacer) Replace(s string) string { + // Header replacements - these are case-insensitive, so we can't just use strings.Replace() + for strings.Contains(s, headerReplacer) { + idxStart := strings.Index(s, headerReplacer) + endOffset := idxStart + len(headerReplacer) + idxEnd := strings.Index(s[endOffset:], "}") + if idxEnd > -1 { + placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1]) + replacement := r.replacements[placeholder] + if replacement == "" { + replacement = r.emptyValue + } + s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:] + } else { + break + } + } + + // Regular replacements - these are easier because they're case-sensitive + for placeholder, replacement := range r.replacements { + if replacement == "" { + replacement = r.emptyValue + } + s = strings.Replace(s, placeholder, replacement, -1) + } + + return s +} + +// Set sets key to value in the replacements map. +func (r replacer) Set(key, value string) { + r.replacements["{"+key+"}"] = value +} + +const ( + timeFormat = "02/Jan/2006:15:04:05 -0700" + headerReplacer = "{>" +) diff --git a/middleware/replacer_test.go b/middleware/replacer_test.go new file mode 100644 index 000000000..d98bd2de1 --- /dev/null +++ b/middleware/replacer_test.go @@ -0,0 +1,124 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNewReplacer(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + reader := strings.NewReader(`{"username": "dennis"}`) + + request, err := http.NewRequest("POST", "http://localhost", reader) + if err != nil { + t.Fatal("Request Formation Failed\n") + } + replaceValues := NewReplacer(request, recordRequest, "") + + switch v := replaceValues.(type) { + case replacer: + + if v.replacements["{host}"] != "localhost" { + t.Error("Expected host to be localhost") + } + if v.replacements["{method}"] != "POST" { + t.Error("Expected request method to be POST") + } + if v.replacements["{status}"] != "200" { + t.Error("Expected status to be 200") + } + + default: + t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n") + } +} + +func TestReplace(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + reader := strings.NewReader(`{"username": "dennis"}`) + + request, err := http.NewRequest("POST", "http://localhost", reader) + if err != nil { + t.Fatal("Request Formation Failed\n") + } + request.Header.Set("Custom", "foobarbaz") + request.Header.Set("ShorterVal", "1") + repl := NewReplacer(request, recordRequest, "-") + + if expected, actual := "This host is localhost.", repl.Replace("This host is {host}."); expected != actual { + t.Errorf("{host} replacement: expected '%s', got '%s'", expected, actual) + } + if expected, actual := "This request method is POST.", repl.Replace("This request method is {method}."); expected != actual { + t.Errorf("{method} replacement: expected '%s', got '%s'", expected, actual) + } + if expected, actual := "The response status is 200.", repl.Replace("The response status is {status}."); expected != actual { + t.Errorf("{status} replacement: expected '%s', got '%s'", expected, actual) + } + if expected, actual := "The Custom header is foobarbaz.", repl.Replace("The Custom header is {>Custom}."); expected != actual { + t.Errorf("{>Custom} replacement: expected '%s', got '%s'", expected, actual) + } + + // Test header case-insensitivity + if expected, actual := "The cUsToM header is foobarbaz...", repl.Replace("The cUsToM header is {>cUsToM}..."); expected != actual { + t.Errorf("{>cUsToM} replacement: expected '%s', got '%s'", expected, actual) + } + + // Test non-existent header/value + if expected, actual := "The Non-Existent header is -.", repl.Replace("The Non-Existent header is {>Non-Existent}."); expected != actual { + t.Errorf("{>Non-Existent} replacement: expected '%s', got '%s'", expected, actual) + } + + // Test bad placeholder + if expected, actual := "Bad {host placeholder...", repl.Replace("Bad {host placeholder..."); expected != actual { + t.Errorf("bad placeholder: expected '%s', got '%s'", expected, actual) + } + + // Test bad header placeholder + if expected, actual := "Bad {>Custom placeholder", repl.Replace("Bad {>Custom placeholder"); expected != actual { + t.Errorf("bad header placeholder: expected '%s', got '%s'", expected, actual) + } + + // Test bad header placeholder with valid one later + if expected, actual := "Bad -", repl.Replace("Bad {>Custom placeholder {>ShorterVal}"); expected != actual { + t.Errorf("bad header placeholders: expected '%s', got '%s'", expected, actual) + } + + // Test shorter header value with multiple placeholders + if expected, actual := "Short value 1 then foobarbaz.", repl.Replace("Short value {>ShorterVal} then {>Custom}."); expected != actual { + t.Errorf("short value: expected '%s', got '%s'", expected, actual) + } +} + +func TestSet(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + reader := strings.NewReader(`{"username": "dennis"}`) + + request, err := http.NewRequest("POST", "http://localhost", reader) + if err != nil { + t.Fatalf("Request Formation Failed \n") + } + repl := NewReplacer(request, recordRequest, "") + + repl.Set("host", "getcaddy.com") + repl.Set("method", "GET") + repl.Set("status", "201") + repl.Set("variable", "value") + + if repl.Replace("This host is {host}") != "This host is getcaddy.com" { + t.Error("Expected host replacement failed") + } + if repl.Replace("This request method is {method}") != "This request method is GET" { + t.Error("Expected method replacement failed") + } + if repl.Replace("The response status is {status}") != "The response status is 201" { + t.Error("Expected status replacement failed") + } + if repl.Replace("The value of variable is {variable}") != "The value of variable is value" { + t.Error("Expected variable replacement failed") + } +} diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go new file mode 100644 index 000000000..ddd4c38b1 --- /dev/null +++ b/middleware/rewrite/condition.go @@ -0,0 +1,130 @@ +package rewrite + +import ( + "fmt" + "regexp" + "strings" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// Operators +const ( + Is = "is" + Not = "not" + Has = "has" + NotHas = "not_has" + StartsWith = "starts_with" + EndsWith = "ends_with" + Match = "match" + NotMatch = "not_match" +) + +func operatorError(operator string) error { + return fmt.Errorf("Invalid operator %v", operator) +} + +func newReplacer(r *dns.Msg) middleware.Replacer { + return middleware.NewReplacer(r, nil, "") +} + +// condition is a rewrite condition. +type condition func(string, string) bool + +var conditions = map[string]condition{ + Is: isFunc, + Not: notFunc, + Has: hasFunc, + NotHas: notHasFunc, + StartsWith: startsWithFunc, + EndsWith: endsWithFunc, + Match: matchFunc, + NotMatch: notMatchFunc, +} + +// isFunc is condition for Is operator. +// It checks for equality. +func isFunc(a, b string) bool { + return a == b +} + +// notFunc is condition for Not operator. +// It checks for inequality. +func notFunc(a, b string) bool { + return a != b +} + +// hasFunc is condition for Has operator. +// It checks if b is a substring of a. +func hasFunc(a, b string) bool { + return strings.Contains(a, b) +} + +// notHasFunc is condition for NotHas operator. +// It checks if b is not a substring of a. +func notHasFunc(a, b string) bool { + return !strings.Contains(a, b) +} + +// startsWithFunc is condition for StartsWith operator. +// It checks if b is a prefix of a. +func startsWithFunc(a, b string) bool { + return strings.HasPrefix(a, b) +} + +// endsWithFunc is condition for EndsWith operator. +// It checks if b is a suffix of a. +func endsWithFunc(a, b string) bool { + return strings.HasSuffix(a, b) +} + +// matchFunc is condition for Match operator. +// It does regexp matching of a against pattern in b +// and returns if they match. +func matchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return matched +} + +// notMatchFunc is condition for NotMatch operator. +// It does regexp matching of a against pattern in b +// and returns if they do not match. +func notMatchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return !matched +} + +// If is statement for a rewrite condition. +type If struct { + A string + Operator string + B string +} + +// True returns true if the condition is true and false otherwise. +// If r is not nil, it replaces placeholders before comparison. +func (i If) True(r *dns.Msg) bool { + if c, ok := conditions[i.Operator]; ok { + a, b := i.A, i.B + if r != nil { + replacer := newReplacer(r) + a = replacer.Replace(i.A) + b = replacer.Replace(i.B) + } + return c(a, b) + } + return false +} + +// NewIf creates a new If condition. +func NewIf(a, operator, b string) (If, error) { + if _, ok := conditions[operator]; !ok { + return If{}, operatorError(operator) + } + return If{ + A: a, + Operator: operator, + B: b, + }, nil +} diff --git a/middleware/rewrite/condition_test.go b/middleware/rewrite/condition_test.go new file mode 100644 index 000000000..3c3b6053a --- /dev/null +++ b/middleware/rewrite/condition_test.go @@ -0,0 +1,106 @@ +package rewrite + +import ( + "net/http" + "strings" + "testing" +) + +func TestConditions(t *testing.T) { + tests := []struct { + condition string + isTrue bool + }{ + {"a is b", false}, + {"a is a", true}, + {"a not b", true}, + {"a not a", false}, + {"a has a", true}, + {"a has b", false}, + {"ba has b", true}, + {"bab has b", true}, + {"bab has bb", false}, + {"a not_has a", false}, + {"a not_has b", true}, + {"ba not_has b", false}, + {"bab not_has b", false}, + {"bab not_has bb", true}, + {"bab starts_with bb", false}, + {"bab starts_with ba", true}, + {"bab starts_with bab", true}, + {"bab ends_with bb", false}, + {"bab ends_with bab", true}, + {"bab ends_with ab", true}, + {"a match *", false}, + {"a match a", true}, + {"a match .*", true}, + {"a match a.*", true}, + {"a match b.*", false}, + {"ba match b.*", true}, + {"ba match b[a-z]", true}, + {"b0 match b[a-z]", false}, + {"b0a match b[a-z]", false}, + {"b0a match b[a-z]+", false}, + {"b0a match b[a-z0-9]+", true}, + {"a not_match *", true}, + {"a not_match a", false}, + {"a not_match .*", false}, + {"a not_match a.*", false}, + {"a not_match b.*", true}, + {"ba not_match b.*", false}, + {"ba not_match b[a-z]", false}, + {"b0 not_match b[a-z]", true}, + {"b0a not_match b[a-z]", true}, + {"b0a not_match b[a-z]+", true}, + {"b0a not_match b[a-z0-9]+", false}, + } + + for i, test := range tests { + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(nil) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } + + invalidOperators := []string{"ss", "and", "if"} + for _, op := range invalidOperators { + _, err := NewIf("a", op, "b") + if err == nil { + t.Errorf("Invalid operator %v used, expected error.", op) + } + } + + replaceTests := []struct { + url string + condition string + isTrue bool + }{ + {"/home", "{uri} match /home", true}, + {"/hom", "{uri} match /home", false}, + {"/hom", "{uri} starts_with /home", false}, + {"/hom", "{uri} starts_with /h", true}, + {"/home/.hiddenfile", `{uri} match \/\.(.*)`, true}, + {"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true}, + } + + for i, test := range replaceTests { + r, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Error(err) + } + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(r) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } +} diff --git a/middleware/rewrite/reverter.go b/middleware/rewrite/reverter.go new file mode 100644 index 000000000..c3425866e --- /dev/null +++ b/middleware/rewrite/reverter.go @@ -0,0 +1,38 @@ +package rewrite + +import "github.com/miekg/dns" + +// ResponseRevert reverses the operations done on the question section of a packet. +// This is need because the client will otherwise disregards the response, i.e. +// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN' +type ResponseReverter struct { + dns.ResponseWriter + original dns.Question +} + +func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter { + return &ResponseReverter{ + ResponseWriter: w, + original: r.Question[0], + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *ResponseReverter) WriteMsg(res *dns.Msg) error { + res.Question[0] = r.original + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *ResponseReverter) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + return n, err +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *ResponseReverter) Hijack() { + r.ResponseWriter.Hijack() + return +} diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go new file mode 100644 index 000000000..b3039615b --- /dev/null +++ b/middleware/rewrite/rewrite.go @@ -0,0 +1,223 @@ +// Package rewrite is middleware for rewriting requests internally to +// something different. +package rewrite + +import ( + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// Result is the result of a rewrite +type Result int + +const ( + // RewriteIgnored is returned when rewrite is not done on request. + RewriteIgnored Result = iota + // RewriteDone is returned when rewrite is done on request. + RewriteDone + // RewriteStatus is returned when rewrite is not needed and status code should be set + // for the request. + RewriteStatus +) + +// Rewrite is middleware to rewrite requests internally before being handled. +type Rewrite struct { + Next middleware.Handler + Rules []Rule +} + +// ServeHTTP implements the middleware.Handler interface. +func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + wr := NewResponseReverter(w, r) + for _, rule := range rw.Rules { + switch result := rule.Rewrite(r); result { + case RewriteDone: + return rw.Next.ServeDNS(wr, r) + case RewriteIgnored: + break + case RewriteStatus: + // only valid for complex rules. + // if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 { + // return cRule.Status, nil + // } + } + } + return rw.Next.ServeDNS(w, r) +} + +// Rule describes an internal location rewrite rule. +type Rule interface { + // Rewrite rewrites the internal location of the current request. + Rewrite(*dns.Msg) Result +} + +// SimpleRule is a simple rewrite rule. If the From and To look like a type +// the type of the request is rewritten, otherwise the name is. +// Note: TSIG signed requests will be invalid. +type SimpleRule struct { + From, To string + fromType, toType uint16 +} + +// NewSimpleRule creates a new Simple Rule +func NewSimpleRule(from, to string) SimpleRule { + tpf := dns.StringToType[from] + tpt := dns.StringToType[to] + + return SimpleRule{From: from, To: to, fromType: tpf, toType: tpt} +} + +// Rewrite rewrites the the current request. +func (s SimpleRule) Rewrite(r *dns.Msg) Result { + if s.fromType > 0 && s.toType > 0 { + if r.Question[0].Qtype == s.fromType { + r.Question[0].Qtype = s.toType + return RewriteDone + } + + } + + // if the question name matches the full name, or subset rewrite that + // s.Question[0].Name + return RewriteIgnored +} + +/* +// ComplexRule is a rewrite rule based on a regular expression +type ComplexRule struct { + // Path base. Request to this path and subpaths will be rewritten + Base string + + // Path to rewrite to + To string + + // If set, neither performs rewrite nor proceeds + // with request. Only returns code. + Status int + + // Extensions to filter by + Exts []string + + // Rewrite conditions + Ifs []If + + *regexp.Regexp +} + +// NewComplexRule creates a new RegexpRule. It returns an error if regexp +// pattern (pattern) or extensions (ext) are invalid. +func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { + // validate regexp if present + var r *regexp.Regexp + if pattern != "" { + var err error + r, err = regexp.Compile(pattern) + if err != nil { + return nil, err + } + } + + // validate extensions if present + for _, v := range ext { + if len(v) < 2 || (len(v) < 3 && v[0] == '!') { + // check if no extension is specified + if v != "/" && v != "!/" { + return nil, fmt.Errorf("invalid extension %v", v) + } + } + } + + return &ComplexRule{ + Base: base, + To: to, + Status: status, + Exts: ext, + Ifs: ifs, + Regexp: r, + }, nil +} + +// Rewrite rewrites the internal location of the current request. +func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) { + rPath := req.URL.Path + replacer := newReplacer(req) + + // validate base + if !middleware.Path(rPath).Matches(r.Base) { + return + } + + // validate extensions + if !r.matchExt(rPath) { + return + } + + // validate regexp if present + if r.Regexp != nil { + // include trailing slash in regexp if present + start := len(r.Base) + if strings.HasSuffix(r.Base, "/") { + start-- + } + + matches := r.FindStringSubmatch(rPath[start:]) + switch len(matches) { + case 0: + // no match + return + default: + // set regexp match variables {1}, {2} ... + for i := 1; i < len(matches); i++ { + replacer.Set(fmt.Sprint(i), matches[i]) + } + } + } + + // validate rewrite conditions + for _, i := range r.Ifs { + if !i.True(req) { + return + } + } + + // if status is present, stop rewrite and return it. + if r.Status != 0 { + return RewriteStatus + } + + // attempt rewrite + return To(fs, req, r.To, replacer) +} + +// matchExt matches rPath against registered file extensions. +// Returns true if a match is found and false otherwise. +func (r *ComplexRule) matchExt(rPath string) bool { + f := filepath.Base(rPath) + ext := path.Ext(f) + if ext == "" { + ext = "/" + } + + mustUse := false + for _, v := range r.Exts { + use := true + if v[0] == '!' { + use = false + v = v[1:] + } + + if use { + mustUse = true + } + + if ext == v { + return use + } + } + + if mustUse { + return false + } + return true +} +*/ diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go new file mode 100644 index 000000000..f57dfd602 --- /dev/null +++ b/middleware/rewrite/rewrite_test.go @@ -0,0 +1,159 @@ +package rewrite + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/miekg/coredns/middleware" +) + +func TestRewrite(t *testing.T) { + rw := Rewrite{ + Next: middleware.HandlerFunc(urlPrinter), + Rules: []Rule{ + NewSimpleRule("/from", "/to"), + NewSimpleRule("/a", "/b"), + NewSimpleRule("/b", "/b{uri}"), + }, + FileSys: http.Dir("."), + } + + regexps := [][]string{ + {"/reg/", ".*", "/to", ""}, + {"/r/", "[a-z]+", "/toaz", "!.html|"}, + {"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, + {"/ab/", "ab", "/ab?{query}", ".txt|"}, + {"/ab/", "ab", "/ab?type=html&{query}", ".html|"}, + {"/abc/", "ab", "/abc/{file}", ".html|"}, + {"/abcd/", "ab", "/a/{dir}/{file}", ".html|"}, + {"/abcde/", "ab", "/a#{fragment}", ".html|"}, + {"/ab/", `.*\.jpg`, "/ajpg", ""}, + {"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""}, + {"/reg2grp", `(.*)`, "/{1}", ""}, + {"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""}, + } + + for _, regexpRule := range regexps { + var ext []string + if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { + ext = s[:len(s)-1] + } + rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil) + if err != nil { + t.Fatal(err) + } + rw.Rules = append(rw.Rules, rule) + } + + tests := []struct { + from string + expectedTo string + }{ + {"/from", "/to"}, + {"/a", "/b"}, + {"/b", "/b/b"}, + {"/aa", "/aa"}, + {"/", "/"}, + {"/a?foo=bar", "/b?foo=bar"}, + {"/asdf?foo=bar", "/asdf?foo=bar"}, + {"/foo#bar", "/foo#bar"}, + {"/a#foo", "/b#foo"}, + {"/reg/foo", "/to"}, + {"/re", "/re"}, + {"/r/", "/r/"}, + {"/r/123", "/r/123"}, + {"/r/a123", "/toaz"}, + {"/r/abcz", "/toaz"}, + {"/r/z", "/toaz"}, + {"/r/z.html", "/r/z.html"}, + {"/r/z.js", "/toaz"}, + {"/url/asAB", "/to/url/asAB"}, + {"/url/aBsAB", "/url/aBsAB"}, + {"/url/a00sAB", "/to/url/a00sAB"}, + {"/url/a0z0sAB", "/to/url/a0z0sAB"}, + {"/ab/aa", "/ab/aa"}, + {"/ab/ab", "/ab/ab"}, + {"/ab/ab.txt", "/ab"}, + {"/ab/ab.txt?name=name", "/ab?name=name"}, + {"/ab/ab.html?name=name", "/ab?type=html&name=name"}, + {"/abc/ab.html", "/abc/ab.html"}, + {"/abcd/abcd.html", "/a/abcd/abcd.html"}, + {"/abcde/abcde.html", "/a"}, + {"/abcde/abcde.html#1234", "/a#1234"}, + {"/ab/ab.jpg", "/ajpg"}, + {"/reggrp/ad/12", "/a12"}, + {"/reggrp/ad/124a", "/a124/a"}, + {"/reggrp/ad/124abc", "/a124/abc"}, + {"/reg2grp/ad/124abc", "/ad/124abc"}, + {"/reg3grp/ad/aa/66", "/adaa66"}, + {"/reg3grp/ad612/n1n/ab", "/ad612n1nab"}, + } + + for i, test := range tests { + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + rw.ServeHTTP(rec, req) + + if rec.Body.String() != test.expectedTo { + t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", + i, test.expectedTo, rec.Body.String()) + } + } + + statusTests := []struct { + status int + base string + to string + regexp string + statusExpected bool + }{ + {400, "/status", "", "", true}, + {400, "/ignore", "", "", false}, + {400, "/", "", "^/ignore", false}, + {400, "/", "", "(.*)", true}, + {400, "/status", "", "", true}, + } + + for i, s := range statusTests { + urlPath := fmt.Sprintf("/status%d", i) + rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil) + if err != nil { + t.Fatalf("Test %d: No error expected for rule but found %v", i, err) + } + rw.Rules = []Rule{rule} + req, err := http.NewRequest("GET", urlPath, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + code, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: No error expected for handler but found %v", i, err) + } + if s.statusExpected { + if rec.Body.String() != "" { + t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String()) + } + if code != s.status { + t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code) + } + } else { + if code != 0 { + t.Errorf("Test %d: Expected no status code found %d", i, code) + } + } + } +} + +func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprintf(w, r.URL.String()) + return 0, nil +} diff --git a/middleware/rewrite/testdata/testdir/empty b/middleware/rewrite/testdata/testdir/empty new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/middleware/rewrite/testdata/testdir/empty diff --git a/middleware/rewrite/testdata/testfile b/middleware/rewrite/testdata/testfile new file mode 100644 index 000000000..7b4d68d70 --- /dev/null +++ b/middleware/rewrite/testdata/testfile @@ -0,0 +1 @@ +empty
\ No newline at end of file diff --git a/middleware/roller.go b/middleware/roller.go new file mode 100644 index 000000000..995cabf91 --- /dev/null +++ b/middleware/roller.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "io" + + "gopkg.in/natefinch/lumberjack.v2" +) + +// LogRoller implements a middleware that provides a rolling logger. +type LogRoller struct { + Filename string + MaxSize int + MaxAge int + MaxBackups int + LocalTime bool +} + +// GetLogWriter returns an io.Writer that writes to a rolling logger. +func (l LogRoller) GetLogWriter() io.Writer { + return &lumberjack.Logger{ + Filename: l.Filename, + MaxSize: l.MaxSize, + MaxAge: l.MaxAge, + MaxBackups: l.MaxBackups, + LocalTime: l.LocalTime, + } +} diff --git a/middleware/zone.go b/middleware/zone.go new file mode 100644 index 000000000..6798bca8e --- /dev/null +++ b/middleware/zone.go @@ -0,0 +1,21 @@ +package middleware + +import "strings" + +type Zones []string + +// Matches checks to see if other matches p. +// The match will return the most specific zones +// that matches other. The empty string signals a not found +// condition. +func (z Zones) Matches(qname string) string { + zone := "" + for _, zname := range z { + if strings.HasSuffix(qname, zname) { + if len(zname) > len(zone) { + zone = zname + } + } + } + return zone +} |