aboutsummaryrefslogtreecommitdiff
path: root/middleware
diff options
context:
space:
mode:
Diffstat (limited to 'middleware')
-rw-r--r--middleware/commands.go120
-rw-r--r--middleware/commands_test.go291
-rw-r--r--middleware/context.go135
-rw-r--r--middleware/context_test.go613
-rw-r--r--middleware/errors/errors.go100
-rw-r--r--middleware/errors/errors_test.go168
-rw-r--r--middleware/etcd/TODO0
-rw-r--r--middleware/exchange.go10
-rw-r--r--middleware/file/file.go89
-rw-r--r--middleware/file/file_test.go325
-rw-r--r--middleware/host.go22
-rw-r--r--middleware/log/log.go66
-rw-r--r--middleware/log/log_test.go48
-rw-r--r--middleware/middleware.go105
-rw-r--r--middleware/middleware_test.go108
-rw-r--r--middleware/path.go18
-rw-r--r--middleware/prometheus/handler.go31
-rw-r--r--middleware/prometheus/metrics.go80
-rw-r--r--middleware/proxy/policy.go101
-rw-r--r--middleware/proxy/policy_test.go87
-rw-r--r--middleware/proxy/proxy.go120
-rw-r--r--middleware/proxy/proxy_test.go317
-rw-r--r--middleware/proxy/reverseproxy.go36
-rw-r--r--middleware/proxy/upstream.go235
-rw-r--r--middleware/proxy/upstream_test.go83
-rw-r--r--middleware/recorder.go70
-rw-r--r--middleware/recorder_test.go32
-rw-r--r--middleware/reflect/reflect.go84
-rw-r--r--middleware/reflect/reflect_test.go1
-rw-r--r--middleware/replacer.go98
-rw-r--r--middleware/replacer_test.go124
-rw-r--r--middleware/rewrite/condition.go130
-rw-r--r--middleware/rewrite/condition_test.go106
-rw-r--r--middleware/rewrite/reverter.go38
-rw-r--r--middleware/rewrite/rewrite.go223
-rw-r--r--middleware/rewrite/rewrite_test.go159
-rw-r--r--middleware/rewrite/testdata/testdir/empty0
-rw-r--r--middleware/rewrite/testdata/testfile1
-rw-r--r--middleware/roller.go27
-rw-r--r--middleware/zone.go21
40 files changed, 4422 insertions, 0 deletions
diff --git a/middleware/commands.go b/middleware/commands.go
new file mode 100644
index 000000000..5c241161e
--- /dev/null
+++ b/middleware/commands.go
@@ -0,0 +1,120 @@
+package middleware
+
+import (
+ "errors"
+ "runtime"
+ "unicode"
+
+ "github.com/flynn/go-shlex"
+)
+
+var runtimeGoos = runtime.GOOS
+
+// SplitCommandAndArgs takes a command string and parses it
+// shell-style into the command and its separate arguments.
+func SplitCommandAndArgs(command string) (cmd string, args []string, err error) {
+ var parts []string
+
+ if runtimeGoos == "windows" {
+ parts = parseWindowsCommand(command) // parse it Windows-style
+ } else {
+ parts, err = parseUnixCommand(command) // parse it Unix-style
+ if err != nil {
+ err = errors.New("error parsing command: " + err.Error())
+ return
+ }
+ }
+
+ if len(parts) == 0 {
+ err = errors.New("no command contained in '" + command + "'")
+ return
+ }
+
+ cmd = parts[0]
+ if len(parts) > 1 {
+ args = parts[1:]
+ }
+
+ return
+}
+
+// parseUnixCommand parses a unix style command line and returns the
+// command and its arguments or an error
+func parseUnixCommand(cmd string) ([]string, error) {
+ return shlex.Split(cmd)
+}
+
+// parseWindowsCommand parses windows command lines and
+// returns the command and the arguments as an array. It
+// should be able to parse commonly used command lines.
+// Only basic syntax is supported:
+// - spaces in double quotes are not token delimiters
+// - double quotes are escaped by either backspace or another double quote
+// - except for the above case backspaces are path separators (not special)
+//
+// Many sources point out that escaping quotes using backslash can be unsafe.
+// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 )
+//
+// This function has to be used on Windows instead
+// of the shlex package because this function treats backslash
+// characters properly.
+func parseWindowsCommand(cmd string) []string {
+ const backslash = '\\'
+ const quote = '"'
+
+ var parts []string
+ var part string
+ var inQuotes bool
+ var lastRune rune
+
+ for i, ch := range cmd {
+
+ if i != 0 {
+ lastRune = rune(cmd[i-1])
+ }
+
+ if ch == backslash {
+ // put it in the part - for now we don't know if it's an
+ // escaping char or path separator
+ part += string(ch)
+ continue
+ }
+
+ if ch == quote {
+ if lastRune == backslash {
+ // remove the backslash from the part and add the escaped quote instead
+ part = part[:len(part)-1]
+ part += string(ch)
+ continue
+ }
+
+ if lastRune == quote {
+ // revert the last change of the inQuotes state
+ // it was an escaping quote
+ inQuotes = !inQuotes
+ part += string(ch)
+ continue
+ }
+
+ // normal escaping quotes
+ inQuotes = !inQuotes
+ continue
+
+ }
+
+ if unicode.IsSpace(ch) && !inQuotes && len(part) > 0 {
+ parts = append(parts, part)
+ part = ""
+ continue
+ }
+
+ part += string(ch)
+ }
+
+ if len(part) > 0 {
+ parts = append(parts, part)
+ part = ""
+ }
+
+ return parts
+}
diff --git a/middleware/commands_test.go b/middleware/commands_test.go
new file mode 100644
index 000000000..3001e65a5
--- /dev/null
+++ b/middleware/commands_test.go
@@ -0,0 +1,291 @@
+package middleware
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+ "testing"
+)
+
+func TestParseUnixCommand(t *testing.T) {
+ tests := []struct {
+ input string
+ expected []string
+ }{
+ // 0 - emtpy command
+ {
+ input: ``,
+ expected: []string{},
+ },
+ // 1 - command without arguments
+ {
+ input: `command`,
+ expected: []string{`command`},
+ },
+ // 2 - command with single argument
+ {
+ input: `command arg1`,
+ expected: []string{`command`, `arg1`},
+ },
+ // 3 - command with multiple arguments
+ {
+ input: `command arg1 arg2`,
+ expected: []string{`command`, `arg1`, `arg2`},
+ },
+ // 4 - command with single argument with space character - in quotes
+ {
+ input: `command "arg1 arg1"`,
+ expected: []string{`command`, `arg1 arg1`},
+ },
+ // 5 - command with multiple spaces and tab character
+ {
+ input: "command arg1 arg2\targ3",
+ expected: []string{`command`, `arg1`, `arg2`, `arg3`},
+ },
+ // 6 - command with single argument with space character - escaped with backspace
+ {
+ input: `command arg1\ arg2`,
+ expected: []string{`command`, `arg1 arg2`},
+ },
+ // 7 - single quotes should escape special chars
+ {
+ input: `command 'arg1\ arg2'`,
+ expected: []string{`command`, `arg1\ arg2`},
+ },
+ }
+
+ for i, test := range tests {
+ errorPrefix := fmt.Sprintf("Test [%d]: ", i)
+ errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
+ actual, _ := parseUnixCommand(test.input)
+ if len(actual) != len(test.expected) {
+ t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
+ continue
+ }
+ for j := 0; j < len(actual); j++ {
+ if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
+ t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
+ }
+ }
+ }
+}
+
+func TestParseWindowsCommand(t *testing.T) {
+ tests := []struct {
+ input string
+ expected []string
+ }{
+ { // 0 - empty command - do not fail
+ input: ``,
+ expected: []string{},
+ },
+ { // 1 - cmd without args
+ input: `cmd`,
+ expected: []string{`cmd`},
+ },
+ { // 2 - multiple args
+ input: `cmd arg1 arg2`,
+ expected: []string{`cmd`, `arg1`, `arg2`},
+ },
+ { // 3 - multiple args with space
+ input: `cmd "combined arg" arg2`,
+ expected: []string{`cmd`, `combined arg`, `arg2`},
+ },
+ { // 4 - path without spaces
+ input: `mkdir C:\Windows\foo\bar`,
+ expected: []string{`mkdir`, `C:\Windows\foo\bar`},
+ },
+ { // 5 - command with space in quotes
+ input: `"command here"`,
+ expected: []string{`command here`},
+ },
+ { // 6 - argument with escaped quotes (two quotes)
+ input: `cmd ""arg""`,
+ expected: []string{`cmd`, `"arg"`},
+ },
+ { // 7 - argument with escaped quotes (backslash)
+ input: `cmd \"arg\"`,
+ expected: []string{`cmd`, `"arg"`},
+ },
+ { // 8 - two quotes (escaped) inside an inQuote element
+ input: `cmd "a ""quoted value"`,
+ expected: []string{`cmd`, `a "quoted value`},
+ },
+ // TODO - see how many quotes are dislayed if we use "", """, """""""
+ { // 9 - two quotes outside an inQuote element
+ input: `cmd a ""quoted value`,
+ expected: []string{`cmd`, `a`, `"quoted`, `value`},
+ },
+ { // 10 - path with space in quotes
+ input: `mkdir "C:\directory name\foobar"`,
+ expected: []string{`mkdir`, `C:\directory name\foobar`},
+ },
+ { // 11 - space without quotes
+ input: `mkdir C:\ space`,
+ expected: []string{`mkdir`, `C:\`, `space`},
+ },
+ { // 12 - space in quotes
+ input: `mkdir "C:\ space"`,
+ expected: []string{`mkdir`, `C:\ space`},
+ },
+ { // 13 - UNC
+ input: `mkdir \\?\C:\Users`,
+ expected: []string{`mkdir`, `\\?\C:\Users`},
+ },
+ { // 14 - UNC with space
+ input: `mkdir "\\?\C:\Program Files"`,
+ expected: []string{`mkdir`, `\\?\C:\Program Files`},
+ },
+
+ { // 15 - unclosed quotes - treat as if the path ends with quote
+ input: `mkdir "c:\Program files`,
+ expected: []string{`mkdir`, `c:\Program files`},
+ },
+ { // 16 - quotes used inside the argument
+ input: `mkdir "c:\P"rogra"m f"iles`,
+ expected: []string{`mkdir`, `c:\Program files`},
+ },
+ }
+
+ for i, test := range tests {
+ errorPrefix := fmt.Sprintf("Test [%d]: ", i)
+ errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
+
+ actual := parseWindowsCommand(test.input)
+ if len(actual) != len(test.expected) {
+ t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
+ continue
+ }
+ for j := 0; j < len(actual); j++ {
+ if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
+ t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
+ }
+ }
+ }
+}
+
+func TestSplitCommandAndArgs(t *testing.T) {
+
+ // force linux parsing. It's more robust and covers error cases
+ runtimeGoos = "linux"
+ defer func() {
+ runtimeGoos = runtime.GOOS
+ }()
+
+ var parseErrorContent = "error parsing command:"
+ var noCommandErrContent = "no command contained in"
+
+ tests := []struct {
+ input string
+ expectedCommand string
+ expectedArgs []string
+ expectedErrContent string
+ }{
+ // 0 - emtpy command
+ {
+ input: ``,
+ expectedCommand: ``,
+ expectedArgs: nil,
+ expectedErrContent: noCommandErrContent,
+ },
+ // 1 - command without arguments
+ {
+ input: `command`,
+ expectedCommand: `command`,
+ expectedArgs: nil,
+ expectedErrContent: ``,
+ },
+ // 2 - command with single argument
+ {
+ input: `command arg1`,
+ expectedCommand: `command`,
+ expectedArgs: []string{`arg1`},
+ expectedErrContent: ``,
+ },
+ // 3 - command with multiple arguments
+ {
+ input: `command arg1 arg2`,
+ expectedCommand: `command`,
+ expectedArgs: []string{`arg1`, `arg2`},
+ expectedErrContent: ``,
+ },
+ // 4 - command with unclosed quotes
+ {
+ input: `command "arg1 arg2`,
+ expectedCommand: "",
+ expectedArgs: nil,
+ expectedErrContent: parseErrorContent,
+ },
+ // 5 - command with unclosed quotes
+ {
+ input: `command 'arg1 arg2"`,
+ expectedCommand: "",
+ expectedArgs: nil,
+ expectedErrContent: parseErrorContent,
+ },
+ }
+
+ for i, test := range tests {
+ errorPrefix := fmt.Sprintf("Test [%d]: ", i)
+ errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
+ actualCommand, actualArgs, actualErr := SplitCommandAndArgs(test.input)
+
+ // test if error matches expectation
+ if test.expectedErrContent != "" {
+ if actualErr == nil {
+ t.Errorf(errorPrefix+"Expected error with content [%s], found no error."+errorSuffix, test.expectedErrContent)
+ } else if !strings.Contains(actualErr.Error(), test.expectedErrContent) {
+ t.Errorf(errorPrefix+"Expected error with content [%s], found [%v]."+errorSuffix, test.expectedErrContent, actualErr)
+ }
+ } else if actualErr != nil {
+ t.Errorf(errorPrefix+"Expected no error, found [%v]."+errorSuffix, actualErr)
+ }
+
+ // test if command matches
+ if test.expectedCommand != actualCommand {
+ t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand)
+ }
+
+ // test if arguments match
+ if len(test.expectedArgs) != len(actualArgs) {
+ t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs)
+ } else {
+ // test args only if the count matches.
+ for j, actualArg := range actualArgs {
+ expectedArg := test.expectedArgs[j]
+ if actualArg != expectedArg {
+ t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg)
+ }
+ }
+ }
+ }
+}
+
+func ExampleSplitCommandAndArgs() {
+ var commandLine string
+ var command string
+ var args []string
+
+ // just for the test - change GOOS and reset it at the end of the test
+ runtimeGoos = "windows"
+ defer func() {
+ runtimeGoos = runtime.GOOS
+ }()
+
+ commandLine = `mkdir /P "C:\Program Files"`
+ command, args, _ = SplitCommandAndArgs(commandLine)
+
+ fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
+
+ // set GOOS to linux
+ runtimeGoos = "linux"
+
+ commandLine = `mkdir -p /path/with\ space`
+ command, args, _ = SplitCommandAndArgs(commandLine)
+
+ fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
+
+ // Output:
+ // Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files]
+ // Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space]
+}
diff --git a/middleware/context.go b/middleware/context.go
new file mode 100644
index 000000000..8868c1c03
--- /dev/null
+++ b/middleware/context.go
@@ -0,0 +1,135 @@
+package middleware
+
+import (
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// This file contains the context and functions available for
+// use in the templates.
+
+// Context is the context with which Caddy templates are executed.
+type Context struct {
+ Root http.FileSystem // TODO(miek): needed
+ Req *dns.Msg
+ W dns.ResponseWriter
+}
+
+// Now returns the current timestamp in the specified format.
+func (c Context) Now(format string) string {
+ return time.Now().Format(format)
+}
+
+// NowDate returns the current date/time that can be used
+// in other time functions.
+func (c Context) NowDate() time.Time {
+ return time.Now()
+}
+
+// Header gets the value of a header.
+func (c Context) Header() *dns.RR_Header {
+ // TODO(miek)
+ return nil
+}
+
+// IP gets the (remote) IP address of the client making the request.
+func (c Context) IP() string {
+ ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String())
+ if err != nil {
+ return c.W.RemoteAddr().String()
+ }
+ return ip
+}
+
+// Post gets the (remote) Port of the client making the request.
+func (c Context) Port() (string, error) {
+ _, port, err := net.SplitHostPort(c.W.RemoteAddr().String())
+ if err != nil {
+ return "0", err
+ }
+ return port, nil
+}
+
+// Proto gets the protocol used as the transport. This
+// will be udp or tcp.
+func (c Context) Proto() string {
+ if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok {
+ return "udp"
+ }
+ if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok {
+ return "tcp"
+ }
+ return "udp"
+}
+
+// Family returns the family of the transport.
+// 1 for IPv4 and 2 for IPv6.
+func (c Context) Family() int {
+ var a net.IP
+ ip := c.W.RemoteAddr()
+ if i, ok := ip.(*net.UDPAddr); ok {
+ a = i.IP
+ }
+ if i, ok := ip.(*net.TCPAddr); ok {
+ a = i.IP
+ }
+
+ if a.To4() != nil {
+ return 1
+ }
+ return 2
+}
+
+// Type returns the type of the question as a string.
+func (c Context) Type() string {
+ return dns.Type(c.Req.Question[0].Qtype).String()
+}
+
+// QType returns the type of the question as a uint16.
+func (c Context) QType() uint16 {
+ return c.Req.Question[0].Qtype
+}
+
+// Name returns the name of the question in the request. Note
+// this name will always have a closing dot and will be lower cased.
+func (c Context) Name() string {
+ return strings.ToLower(dns.Name(c.Req.Question[0].Name).String())
+}
+
+// QName returns the name of the question in the request.
+func (c Context) QName() string {
+ return dns.Name(c.Req.Question[0].Name).String()
+}
+
+// Class returns the class of the question in the request.
+func (c Context) Class() string {
+ return dns.Class(c.Req.Question[0].Qclass).String()
+}
+
+// QClass returns the class of the question in the request.
+func (c Context) QClass() uint16 {
+ return c.Req.Question[0].Qclass
+}
+
+// More convience types for extracting stuff from a message?
+// Header?
+
+// ErrorMessage returns an error message suitable for sending
+// back to the client.
+func (c Context) ErrorMessage(rcode int) *dns.Msg {
+ m := new(dns.Msg)
+ m.SetRcode(c.Req, rcode)
+ return m
+}
+
+// AnswerMessage returns an error message suitable for sending
+// back to the client.
+func (c Context) AnswerMessage() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetReply(c.Req)
+ return m
+}
diff --git a/middleware/context_test.go b/middleware/context_test.go
new file mode 100644
index 000000000..689c47c13
--- /dev/null
+++ b/middleware/context_test.go
@@ -0,0 +1,613 @@
+package middleware
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestInclude(t *testing.T) {
+ context := getContextOrFail(t)
+
+ inputFilename := "test_file"
+ absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
+ defer func() {
+ err := os.Remove(absInFilePath)
+ if err != nil && !os.IsNotExist(err) {
+ t.Fatalf("Failed to clean test file!")
+ }
+ }()
+
+ tests := []struct {
+ fileContent string
+ expectedContent string
+ shouldErr bool
+ expectedErrorContent string
+ }{
+ // Test 0 - all good
+ {
+ fileContent: `str1 {{ .Root }} str2`,
+ expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
+ shouldErr: false,
+ expectedErrorContent: "",
+ },
+ // Test 1 - failure on template.Parse
+ {
+ fileContent: `str1 {{ .Root } str2`,
+ expectedContent: "",
+ shouldErr: true,
+ expectedErrorContent: `unexpected "}" in operand`,
+ },
+ // Test 3 - failure on template.Execute
+ {
+ fileContent: `str1 {{ .InvalidField }} str2`,
+ expectedContent: "",
+ shouldErr: true,
+ expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`,
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ // WriteFile truncates the contentt
+ err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
+ if err != nil {
+ t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
+ }
+
+ content, err := context.Include(inputFilename)
+ if err != nil {
+ if !test.shouldErr {
+ t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
+ }
+ if !strings.Contains(err.Error(), test.expectedErrorContent) {
+ t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
+ }
+ }
+
+ if err == nil && test.shouldErr {
+ t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
+ }
+
+ if content != test.expectedContent {
+ t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
+ }
+ }
+}
+
+func TestIncludeNotExisting(t *testing.T) {
+ context := getContextOrFail(t)
+
+ _, err := context.Include("not_existing")
+ if err == nil {
+ t.Errorf("Expected error but found nil!")
+ }
+}
+
+func TestMarkdown(t *testing.T) {
+ context := getContextOrFail(t)
+
+ inputFilename := "test_file"
+ absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
+ defer func() {
+ err := os.Remove(absInFilePath)
+ if err != nil && !os.IsNotExist(err) {
+ t.Fatalf("Failed to clean test file!")
+ }
+ }()
+
+ tests := []struct {
+ fileContent string
+ expectedContent string
+ }{
+ // Test 0 - test parsing of markdown
+ {
+ fileContent: "* str1\n* str2\n",
+ expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ // WriteFile truncates the contentt
+ err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
+ if err != nil {
+ t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
+ }
+
+ content, _ := context.Markdown(inputFilename)
+ if content != test.expectedContent {
+ t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
+ }
+ }
+}
+
+func TestCookie(t *testing.T) {
+
+ tests := []struct {
+ cookie *http.Cookie
+ cookieName string
+ expectedValue string
+ }{
+ // Test 0 - happy path
+ {
+ cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
+ cookieName: "cookieName",
+ expectedValue: "cookieValue",
+ },
+ // Test 1 - try to get a non-existing cookie
+ {
+ cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
+ cookieName: "notExisting",
+ expectedValue: "",
+ },
+ // Test 2 - partial name match
+ {
+ cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
+ cookieName: "cook",
+ expectedValue: "",
+ },
+ // Test 3 - cookie with optional fields
+ {
+ cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
+ cookieName: "cookie",
+ expectedValue: "cookieValue",
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ // reinitialize the context for each test
+ context := getContextOrFail(t)
+
+ context.Req.AddCookie(test.cookie)
+
+ actualCookieVal := context.Cookie(test.cookieName)
+
+ if actualCookieVal != test.expectedValue {
+ t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
+ }
+ }
+}
+
+func TestCookieMultipleCookies(t *testing.T) {
+ context := getContextOrFail(t)
+
+ cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
+
+ // make sure that there's no state and multiple requests for different cookies return the correct result
+ for i := 0; i < 10; i++ {
+ context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
+ }
+
+ for i := 0; i < 10; i++ {
+ expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
+ actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
+ if actualCookieVal != expectedCookieVal {
+ t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
+ }
+ }
+}
+
+func TestHeader(t *testing.T) {
+ context := getContextOrFail(t)
+
+ headerKey, headerVal := "Header1", "HeaderVal1"
+ context.Req.Header.Add(headerKey, headerVal)
+
+ actualHeaderVal := context.Header(headerKey)
+ if actualHeaderVal != headerVal {
+ t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
+ }
+
+ missingHeaderVal := context.Header("not-existing")
+ if missingHeaderVal != "" {
+ t.Errorf("Expected empty header value, found %s", missingHeaderVal)
+ }
+}
+
+func TestIP(t *testing.T) {
+ context := getContextOrFail(t)
+
+ tests := []struct {
+ inputRemoteAddr string
+ expectedIP string
+ }{
+ // Test 0 - ipv4 with port
+ {"1.1.1.1:1111", "1.1.1.1"},
+ // Test 1 - ipv4 without port
+ {"1.1.1.1", "1.1.1.1"},
+ // Test 2 - ipv6 with port
+ {"[::1]:11", "::1"},
+ // Test 3 - ipv6 without port and brackets
+ {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
+ // Test 4 - ipv6 with zone and port
+ {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ context.Req.RemoteAddr = test.inputRemoteAddr
+ actualIP := context.IP()
+
+ if actualIP != test.expectedIP {
+ t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
+ }
+ }
+}
+
+func TestURL(t *testing.T) {
+ context := getContextOrFail(t)
+
+ inputURL := "http://localhost"
+ context.Req.RequestURI = inputURL
+
+ if inputURL != context.URI() {
+ t.Errorf("Expected url %s, found %s", inputURL, context.URI())
+ }
+}
+
+func TestHost(t *testing.T) {
+ tests := []struct {
+ input string
+ expectedHost string
+ shouldErr bool
+ }{
+ {
+ input: "localhost:123",
+ expectedHost: "localhost",
+ shouldErr: false,
+ },
+ {
+ input: "localhost",
+ expectedHost: "localhost",
+ shouldErr: false,
+ },
+ {
+ input: "[::]",
+ expectedHost: "",
+ shouldErr: true,
+ },
+ }
+
+ for _, test := range tests {
+ testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
+ }
+}
+
+func TestPort(t *testing.T) {
+ tests := []struct {
+ input string
+ expectedPort string
+ shouldErr bool
+ }{
+ {
+ input: "localhost:123",
+ expectedPort: "123",
+ shouldErr: false,
+ },
+ {
+ input: "localhost",
+ expectedPort: "80", // assuming 80 is the default port
+ shouldErr: false,
+ },
+ {
+ input: ":8080",
+ expectedPort: "8080",
+ shouldErr: false,
+ },
+ {
+ input: "[::]",
+ expectedPort: "",
+ shouldErr: true,
+ },
+ }
+
+ for _, test := range tests {
+ testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
+ }
+}
+
+func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
+ context := getContextOrFail(t)
+
+ context.Req.Host = input
+ var actualResult, testedObject string
+ var err error
+
+ if isTestingHost {
+ actualResult, err = context.Host()
+ testedObject = "host"
+ } else {
+ actualResult, err = context.Port()
+ testedObject = "port"
+ }
+
+ if shouldErr && err == nil {
+ t.Errorf("Expected error, found nil!")
+ return
+ }
+
+ if !shouldErr && err != nil {
+ t.Errorf("Expected no error, found %s", err)
+ return
+ }
+
+ if actualResult != expectedResult {
+ t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
+ }
+}
+
+func TestMethod(t *testing.T) {
+ context := getContextOrFail(t)
+
+ method := "POST"
+ context.Req.Method = method
+
+ if method != context.Method() {
+ t.Errorf("Expected method %s, found %s", method, context.Method())
+ }
+
+}
+
+func TestPathMatches(t *testing.T) {
+ context := getContextOrFail(t)
+
+ tests := []struct {
+ urlStr string
+ pattern string
+ shouldMatch bool
+ }{
+ // Test 0
+ {
+ urlStr: "http://localhost/",
+ pattern: "",
+ shouldMatch: true,
+ },
+ // Test 1
+ {
+ urlStr: "http://localhost",
+ pattern: "",
+ shouldMatch: true,
+ },
+ // Test 1
+ {
+ urlStr: "http://localhost/",
+ pattern: "/",
+ shouldMatch: true,
+ },
+ // Test 3
+ {
+ urlStr: "http://localhost/?param=val",
+ pattern: "/",
+ shouldMatch: true,
+ },
+ // Test 4
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "/dir2",
+ shouldMatch: false,
+ },
+ // Test 5
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "/dir1",
+ shouldMatch: true,
+ },
+ // Test 6
+ {
+ urlStr: "http://localhost:444/dir1/dir2",
+ pattern: "/dir1",
+ shouldMatch: true,
+ },
+ // Test 7
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "*/dir2",
+ shouldMatch: false,
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+ var err error
+ context.Req.URL, err = url.Parse(test.urlStr)
+ if err != nil {
+ t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
+ }
+
+ matches := context.PathMatches(test.pattern)
+ if matches != test.shouldMatch {
+ t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
+ }
+ }
+}
+
+func TestTruncate(t *testing.T) {
+ context := getContextOrFail(t)
+ tests := []struct {
+ inputString string
+ inputLength int
+ expected string
+ }{
+ // Test 0 - small length
+ {
+ inputString: "string",
+ inputLength: 1,
+ expected: "s",
+ },
+ // Test 1 - exact length
+ {
+ inputString: "string",
+ inputLength: 6,
+ expected: "string",
+ },
+ // Test 2 - bigger length
+ {
+ inputString: "string",
+ inputLength: 10,
+ expected: "string",
+ },
+ // Test 3 - zero length
+ {
+ inputString: "string",
+ inputLength: 0,
+ expected: "",
+ },
+ // Test 4 - negative, smaller length
+ {
+ inputString: "string",
+ inputLength: -5,
+ expected: "tring",
+ },
+ // Test 5 - negative, exact length
+ {
+ inputString: "string",
+ inputLength: -6,
+ expected: "string",
+ },
+ // Test 6 - negative, bigger length
+ {
+ inputString: "string",
+ inputLength: -7,
+ expected: "string",
+ },
+ }
+
+ for i, test := range tests {
+ actual := context.Truncate(test.inputString, test.inputLength)
+ if actual != test.expected {
+ t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
+ }
+ }
+}
+
+func TestStripHTML(t *testing.T) {
+ context := getContextOrFail(t)
+ tests := []struct {
+ input string
+ expected string
+ }{
+ // Test 0 - no tags
+ {
+ input: `h1`,
+ expected: `h1`,
+ },
+ // Test 1 - happy path
+ {
+ input: `<h1>h1</h1>`,
+ expected: `h1`,
+ },
+ // Test 2 - tag in quotes
+ {
+ input: `<h1">">h1</h1>`,
+ expected: `h1`,
+ },
+ // Test 3 - multiple tags
+ {
+ input: `<h1><b>h1</b></h1>`,
+ expected: `h1`,
+ },
+ // Test 4 - tags not closed
+ {
+ input: `<h1`,
+ expected: `<h1`,
+ },
+ // Test 5 - false start
+ {
+ input: `<h1<b>hi`,
+ expected: `<h1hi`,
+ },
+ }
+
+ for i, test := range tests {
+ actual := context.StripHTML(test.input)
+ if actual != test.expected {
+ t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input)
+ }
+ }
+}
+
+func TestStripExt(t *testing.T) {
+ context := getContextOrFail(t)
+ tests := []struct {
+ input string
+ expected string
+ }{
+ // Test 0 - empty input
+ {
+ input: "",
+ expected: "",
+ },
+ // Test 1 - relative file with ext
+ {
+ input: "file.ext",
+ expected: "file",
+ },
+ // Test 2 - relative file without ext
+ {
+ input: "file",
+ expected: "file",
+ },
+ // Test 3 - absolute file without ext
+ {
+ input: "/file",
+ expected: "/file",
+ },
+ // Test 4 - absolute file with ext
+ {
+ input: "/file.ext",
+ expected: "/file",
+ },
+ // Test 5 - with ext but ends with /
+ {
+ input: "/dir.ext/",
+ expected: "/dir.ext/",
+ },
+ // Test 6 - file with ext under dir with ext
+ {
+ input: "/dir.ext/file.ext",
+ expected: "/dir.ext/file",
+ },
+ }
+
+ for i, test := range tests {
+ actual := context.StripExt(test.input)
+ if actual != test.expected {
+ t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input)
+ }
+ }
+}
+
+func initTestContext() (Context, error) {
+ body := bytes.NewBufferString("request body")
+ request, err := http.NewRequest("GET", "https://localhost", body)
+ if err != nil {
+ return Context{}, err
+ }
+
+ return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
+}
+
+func getContextOrFail(t *testing.T) Context {
+ context, err := initTestContext()
+ if err != nil {
+ t.Fatalf("Failed to prepare test context")
+ }
+ return context
+}
+
+func getTestPrefix(testN int) string {
+ return fmt.Sprintf("Test [%d]: ", testN)
+}
diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go
new file mode 100644
index 000000000..bf5bc7aae
--- /dev/null
+++ b/middleware/errors/errors.go
@@ -0,0 +1,100 @@
+// Package errors implements an HTTP error handling middleware.
+package errors
+
+import (
+ "fmt"
+ "log"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// ErrorHandler handles DNS errors (and errors from other middleware).
+type ErrorHandler struct {
+ Next middleware.Handler
+ LogFile string
+ Log *log.Logger
+ LogRoller *middleware.LogRoller
+ Debug bool // if true, errors are written out to client rather than to a log
+}
+
+func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ defer h.recovery(w, r)
+
+ rcode, err := h.Next.ServeDNS(w, r)
+
+ if err != nil {
+ errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
+
+ if h.Debug {
+ // Write error to response as a txt message instead of to log
+ answer := debugMsg(rcode, r)
+ txt, _ := dns.NewRR(". IN 0 TXT " + errMsg)
+ answer.Answer = append(answer.Answer, txt)
+ w.WriteMsg(answer)
+ return 0, err
+ }
+ h.Log.Println(errMsg)
+ }
+
+ return rcode, err
+}
+
+func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) {
+ rec := recover()
+ if rec == nil {
+ return
+ }
+
+ // Obtain source of panic
+ // From: https://gist.github.com/swdunlop/9629168
+ var name, file string // function name, file name
+ var line int
+ var pc [16]uintptr
+ n := runtime.Callers(3, pc[:])
+ for _, pc := range pc[:n] {
+ fn := runtime.FuncForPC(pc)
+ if fn == nil {
+ continue
+ }
+ file, line = fn.FileLine(pc)
+ name = fn.Name()
+ if !strings.HasPrefix(name, "runtime.") {
+ break
+ }
+ }
+
+ // Trim file path
+ delim := "/coredns/"
+ pkgPathPos := strings.Index(file, delim)
+ if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) {
+ file = file[pkgPathPos+len(delim):]
+ }
+
+ panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec)
+ if h.Debug {
+ // Write error and stack trace to the response rather than to a log
+ var stackBuf [4096]byte
+ stack := stackBuf[:runtime.Stack(stackBuf[:], false)]
+ answer := debugMsg(dns.RcodeServerFailure, r)
+ // add stack buf in TXT, limited to 255 chars for now.
+ txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255]))
+ answer.Answer = append(answer.Answer, txt)
+ w.WriteMsg(answer)
+ } else {
+ // Currently we don't use the function name, since file:line is more conventional
+ h.Log.Printf(panicMsg)
+ }
+}
+
+// debugMsg creates a debug message that gets send back to the client.
+func debugMsg(rcode int, r *dns.Msg) *dns.Msg {
+ answer := new(dns.Msg)
+ answer.SetRcode(r, rcode)
+ return answer
+}
+
+const timeFormat = "02/Jan/2006:15:04:05 -0700"
diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go
new file mode 100644
index 000000000..4434e835c
--- /dev/null
+++ b/middleware/errors/errors_test.go
@@ -0,0 +1,168 @@
+package errors
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+)
+
+func TestErrors(t *testing.T) {
+ // create a temporary page
+ path := filepath.Join(os.TempDir(), "errors_test.html")
+ f, err := os.Create(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(path)
+
+ const content = "This is a error page"
+ _, err = f.WriteString(content)
+ if err != nil {
+ t.Fatal(err)
+ }
+ f.Close()
+
+ buf := bytes.Buffer{}
+ em := ErrorHandler{
+ ErrorPages: map[int]string{
+ http.StatusNotFound: path,
+ http.StatusForbidden: "not_exist_file",
+ },
+ Log: log.New(&buf, "", 0),
+ }
+ _, notExistErr := os.Open("not_exist_file")
+
+ testErr := errors.New("test error")
+ tests := []struct {
+ next middleware.Handler
+ expectedCode int
+ expectedBody string
+ expectedLog string
+ expectedErr error
+ }{
+ {
+ next: genErrorHandler(http.StatusOK, nil, "normal"),
+ expectedCode: http.StatusOK,
+ expectedBody: "normal",
+ expectedLog: "",
+ expectedErr: nil,
+ },
+ {
+ next: genErrorHandler(http.StatusMovedPermanently, testErr, ""),
+ expectedCode: http.StatusMovedPermanently,
+ expectedBody: "",
+ expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr),
+ expectedErr: testErr,
+ },
+ {
+ next: genErrorHandler(http.StatusBadRequest, nil, ""),
+ expectedCode: 0,
+ expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest,
+ http.StatusText(http.StatusBadRequest)),
+ expectedLog: "",
+ expectedErr: nil,
+ },
+ {
+ next: genErrorHandler(http.StatusNotFound, nil, ""),
+ expectedCode: 0,
+ expectedBody: content,
+ expectedLog: "",
+ expectedErr: nil,
+ },
+ {
+ next: genErrorHandler(http.StatusForbidden, nil, ""),
+ expectedCode: 0,
+ expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden,
+ http.StatusText(http.StatusForbidden)),
+ expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n",
+ http.StatusForbidden, notExistErr),
+ expectedErr: nil,
+ },
+ }
+
+ req, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for i, test := range tests {
+ em.Next = test.next
+ buf.Reset()
+ rec := httptest.NewRecorder()
+ code, err := em.ServeHTTP(rec, req)
+
+ if err != test.expectedErr {
+ t.Errorf("Test %d: Expected error %v, but got %v",
+ i, test.expectedErr, err)
+ }
+ if code != test.expectedCode {
+ t.Errorf("Test %d: Expected status code %d, but got %d",
+ i, test.expectedCode, code)
+ }
+ if body := rec.Body.String(); body != test.expectedBody {
+ t.Errorf("Test %d: Expected body %q, but got %q",
+ i, test.expectedBody, body)
+ }
+ if log := buf.String(); !strings.Contains(log, test.expectedLog) {
+ t.Errorf("Test %d: Expected log %q, but got %q",
+ i, test.expectedLog, log)
+ }
+ }
+}
+
+func TestVisibleErrorWithPanic(t *testing.T) {
+ const panicMsg = "I'm a panic"
+ eh := ErrorHandler{
+ ErrorPages: make(map[int]string),
+ Debug: true,
+ Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
+ panic(panicMsg)
+ }),
+ }
+
+ req, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rec := httptest.NewRecorder()
+
+ code, err := eh.ServeHTTP(rec, req)
+
+ if code != 0 {
+ t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code)
+ }
+ if err != nil {
+ t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err)
+ }
+
+ body := rec.Body.String()
+
+ if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") {
+ t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body)
+ }
+ if !strings.Contains(body, panicMsg) {
+ t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body)
+ }
+ if len(body) < 500 {
+ t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body))
+ }
+}
+
+func genErrorHandler(status int, err error, body string) middleware.Handler {
+ return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
+ if len(body) > 0 {
+ w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+ fmt.Fprint(w, body)
+ }
+ return status, err
+ })
+}
diff --git a/middleware/etcd/TODO b/middleware/etcd/TODO
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/middleware/etcd/TODO
diff --git a/middleware/exchange.go b/middleware/exchange.go
new file mode 100644
index 000000000..837fa3cdc
--- /dev/null
+++ b/middleware/exchange.go
@@ -0,0 +1,10 @@
+package middleware
+
+import "github.com/miekg/dns"
+
+// Exchang sends message m to the server.
+// TODO(miek): optionally it can do retries of other silly stuff.
+func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) {
+ r, _, err := c.Exchange(m, server)
+ return r, err
+}
diff --git a/middleware/file/file.go b/middleware/file/file.go
new file mode 100644
index 000000000..5bc5a3a3a
--- /dev/null
+++ b/middleware/file/file.go
@@ -0,0 +1,89 @@
+package file
+
+// TODO(miek): the zone's implementation is basically non-existent
+// we return a list and when searching for an answer we iterate
+// over the list. This must be moved to a tree-like structure and
+// have some fluff for DNSSEC (and be memory efficient).
+
+import (
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+type (
+ File struct {
+ Next middleware.Handler
+ Zones Zones
+ // Maybe a list of all zones as well, as a []string?
+ }
+
+ Zone []dns.RR
+ Zones struct {
+ Z map[string]Zone // utterly braindead impl. TODO(miek): fix
+ Names []string
+ }
+)
+
+func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ context := middleware.Context{W: w, Req: r}
+ qname := context.Name()
+ zone := middleware.Zones(f.Zones.Names).Matches(qname)
+ if zone == "" {
+ return f.Next.ServeDNS(w, r)
+ }
+
+ names, nodata := f.Zones.Z[zone].lookup(qname, context.QType())
+ var answer *dns.Msg
+ switch {
+ case nodata:
+ answer = context.AnswerMessage()
+ answer.Ns = names
+ case len(names) == 0:
+ answer = context.AnswerMessage()
+ answer.Ns = names
+ answer.Rcode = dns.RcodeNameError
+ case len(names) > 0:
+ answer = context.AnswerMessage()
+ answer.Answer = names
+ default:
+ answer = context.ErrorMessage(dns.RcodeServerFailure)
+ }
+ // Check return size, etc. TODO(miek)
+ w.WriteMsg(answer)
+ return 0, nil
+}
+
+// Lookup will try to find qname and qtype in z. It returns the
+// records found *or* a boolean saying NODATA. If the answer
+// is NODATA then the RR returned is the SOA record.
+//
+// TODO(miek): EXTREMELY STUPID IMPLEMENTATION.
+// Doesn't do much, no delegation, no cname, nothing really, etc.
+// TODO(miek): even NODATA looks broken
+func (z Zone) lookup(qname string, qtype uint16) ([]dns.RR, bool) {
+ var (
+ nodata bool
+ rep []dns.RR
+ soa dns.RR
+ )
+
+ for _, rr := range z {
+ if rr.Header().Rrtype == dns.TypeSOA {
+ soa = rr
+ }
+ // Match function in Go DNS?
+ if strings.ToLower(rr.Header().Name) == qname {
+ if rr.Header().Rrtype == qtype {
+ rep = append(rep, rr)
+ nodata = false
+ }
+
+ }
+ }
+ if nodata {
+ return []dns.RR{soa}, true
+ }
+ return rep, false
+}
diff --git a/middleware/file/file_test.go b/middleware/file/file_test.go
new file mode 100644
index 000000000..54584b5cc
--- /dev/null
+++ b/middleware/file/file_test.go
@@ -0,0 +1,325 @@
+package file
+
+import (
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
+var ErrCustom = errors.New("Custom Error")
+
+// testFiles is a map with relative paths to test files as keys and file content as values.
+// The map represents the following structure:
+// - $TEMP/caddy_testdir/
+// '-- file1.html
+// '-- dirwithindex/
+// '---- index.html
+// '-- dir/
+// '---- file2.html
+// '---- hidden.html
+var testFiles = map[string]string{
+ "file1.html": "<h1>file1.html</h1>",
+ filepath.Join("dirwithindex", "index.html"): "<h1>dirwithindex/index.html</h1>",
+ filepath.Join("dir", "file2.html"): "<h1>dir/file2.html</h1>",
+ filepath.Join("dir", "hidden.html"): "<h1>dir/hidden.html</h1>",
+}
+
+// TestServeHTTP covers positive scenarios when serving files.
+func TestServeHTTP(t *testing.T) {
+
+ beforeServeHTTPTest(t)
+ defer afterServeHTTPTest(t)
+
+ fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"})
+
+ movedPermanently := "Moved Permanently"
+
+ tests := []struct {
+ url string
+
+ expectedStatus int
+ expectedBodyContent string
+ }{
+ // Test 0 - access without any path
+ {
+ url: "https://foo",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 1 - access root (without index.html)
+ {
+ url: "https://foo/",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 2 - access existing file
+ {
+ url: "https://foo/file1.html",
+ expectedStatus: http.StatusOK,
+ expectedBodyContent: testFiles["file1.html"],
+ },
+ // Test 3 - access folder with index file with trailing slash
+ {
+ url: "https://foo/dirwithindex/",
+ expectedStatus: http.StatusOK,
+ expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
+ },
+ // Test 4 - access folder with index file without trailing slash
+ {
+ url: "https://foo/dirwithindex",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ // Test 5 - access folder without index file
+ {
+ url: "https://foo/dir/",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 6 - access folder without trailing slash
+ {
+ url: "https://foo/dir",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ // Test 6 - access file with trailing slash
+ {
+ url: "https://foo/file1.html/",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ // Test 7 - access not existing path
+ {
+ url: "https://foo/not_existing",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 8 - access a file, marked as hidden
+ {
+ url: "https://foo/dir/hidden.html",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 9 - access a index file directly
+ {
+ url: "https://foo/dirwithindex/index.html",
+ expectedStatus: http.StatusOK,
+ expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
+ },
+ // Test 10 - send a request with query params
+ {
+ url: "https://foo/dir?param1=val",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ }
+
+ for i, test := range tests {
+ responseRecorder := httptest.NewRecorder()
+ request, err := http.NewRequest("GET", test.url, strings.NewReader(""))
+ status, err := fileserver.ServeHTTP(responseRecorder, request)
+
+ // check if error matches expectations
+ if err != nil {
+ t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err)
+ }
+
+ // check status code
+ if test.expectedStatus != status {
+ t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
+ }
+
+ // check body content
+ if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) {
+ t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String())
+ }
+ }
+
+}
+
+// beforeServeHTTPTest creates a test directory with the structure, defined in the variable testFiles
+func beforeServeHTTPTest(t *testing.T) {
+ // make the root test dir
+ err := os.Mkdir(testDir, os.ModePerm)
+ if err != nil {
+ if !os.IsExist(err) {
+ t.Fatalf("Failed to create test dir. Error was: %v", err)
+ return
+ }
+ }
+
+ for relFile, fileContent := range testFiles {
+ absFile := filepath.Join(testDir, relFile)
+
+ // make sure the parent directories exist
+ parentDir := filepath.Dir(absFile)
+ _, err = os.Stat(parentDir)
+ if err != nil {
+ os.MkdirAll(parentDir, os.ModePerm)
+ }
+
+ // now create the test files
+ f, err := os.Create(absFile)
+ if err != nil {
+ t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err)
+ return
+ }
+
+ // and fill them with content
+ _, err = f.WriteString(fileContent)
+ if err != nil {
+ t.Fatalf("Failed to write to %s. Error was: %v", absFile, err)
+ return
+ }
+ f.Close()
+ }
+
+}
+
+// afterServeHTTPTest removes the test dir and all its content
+func afterServeHTTPTest(t *testing.T) {
+ // cleans up everything under the test dir. No need to clean the individual files.
+ err := os.RemoveAll(testDir)
+ if err != nil {
+ t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err)
+ }
+}
+
+// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err
+type failingFS struct {
+ err error // the error to return when Open is called
+ fileImpl http.File // inject the file implementation
+}
+
+// Open returns the assigned failingFile and error
+func (f failingFS) Open(path string) (http.File, error) {
+ return f.fileImpl, f.err
+}
+
+// failingFile implements http.File but returns a predefined error on every Stat() method call.
+type failingFile struct {
+ http.File
+ err error
+}
+
+// Stat returns nil FileInfo and the provided error on every call
+func (ff failingFile) Stat() (os.FileInfo, error) {
+ return nil, ff.err
+}
+
+// Close is noop and returns no error
+func (ff failingFile) Close() error {
+ return nil
+}
+
+// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors.
+func TestServeHTTPFailingFS(t *testing.T) {
+
+ tests := []struct {
+ fsErr error
+ expectedStatus int
+ expectedErr error
+ expectedHeaders map[string]string
+ }{
+ {
+ fsErr: os.ErrNotExist,
+ expectedStatus: http.StatusNotFound,
+ expectedErr: nil,
+ },
+ {
+ fsErr: os.ErrPermission,
+ expectedStatus: http.StatusForbidden,
+ expectedErr: os.ErrPermission,
+ },
+ {
+ fsErr: ErrCustom,
+ expectedStatus: http.StatusServiceUnavailable,
+ expectedErr: ErrCustom,
+ expectedHeaders: map[string]string{"Retry-After": "5"},
+ },
+ }
+
+ for i, test := range tests {
+ // initialize a file server with the failing FileSystem
+ fileserver := FileServer(failingFS{err: test.fsErr}, nil)
+
+ // prepare the request and response
+ request, err := http.NewRequest("GET", "https://foo/", nil)
+ if err != nil {
+ t.Fatalf("Failed to build request. Error was: %v", err)
+ }
+ responseRecorder := httptest.NewRecorder()
+
+ status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
+
+ // check the status
+ if status != test.expectedStatus {
+ t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
+ }
+
+ // check the error
+ if actualErr != test.expectedErr {
+ t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
+ }
+
+ // check the headers - a special case for server under load
+ if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 {
+ for expectedKey, expectedVal := range test.expectedHeaders {
+ actualVal := responseRecorder.Header().Get(expectedKey)
+ if expectedVal != actualVal {
+ t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal)
+ }
+ }
+ }
+ }
+}
+
+// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails.
+func TestServeHTTPFailingStat(t *testing.T) {
+
+ tests := []struct {
+ statErr error
+ expectedStatus int
+ expectedErr error
+ }{
+ {
+ statErr: os.ErrNotExist,
+ expectedStatus: http.StatusNotFound,
+ expectedErr: nil,
+ },
+ {
+ statErr: os.ErrPermission,
+ expectedStatus: http.StatusForbidden,
+ expectedErr: os.ErrPermission,
+ },
+ {
+ statErr: ErrCustom,
+ expectedStatus: http.StatusInternalServerError,
+ expectedErr: ErrCustom,
+ },
+ }
+
+ for i, test := range tests {
+ // initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will
+ fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil)
+
+ // prepare the request and response
+ request, err := http.NewRequest("GET", "https://foo/", nil)
+ if err != nil {
+ t.Fatalf("Failed to build request. Error was: %v", err)
+ }
+ responseRecorder := httptest.NewRecorder()
+
+ status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
+
+ // check the status
+ if status != test.expectedStatus {
+ t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
+ }
+
+ // check the error
+ if actualErr != test.expectedErr {
+ t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
+ }
+ }
+}
diff --git a/middleware/host.go b/middleware/host.go
new file mode 100644
index 000000000..17ecedb5f
--- /dev/null
+++ b/middleware/host.go
@@ -0,0 +1,22 @@
+package middleware
+
+import (
+ "net"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+// Host represents a host from the Caddyfile, may contain port.
+type Host string
+
+// Standard host will return the host portion of host, stripping
+// of any port. The host will also be fully qualified and lowercased.
+func (h Host) StandardHost() string {
+ // separate host and port
+ host, _, err := net.SplitHostPort(string(h))
+ if err != nil {
+ host, _, _ = net.SplitHostPort(string(h) + ":")
+ }
+ return strings.ToLower(dns.Fqdn(host))
+}
diff --git a/middleware/log/log.go b/middleware/log/log.go
new file mode 100644
index 000000000..109add9f5
--- /dev/null
+++ b/middleware/log/log.go
@@ -0,0 +1,66 @@
+// Package log implements basic but useful request (access) logging middleware.
+package log
+
+import (
+ "log"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// Logger is a basic request logging middleware.
+type Logger struct {
+ Next middleware.Handler
+ Rules []Rule
+ ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
+}
+
+func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ for _, rule := range l.Rules {
+ /*
+ if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
+ responseRecorder := middleware.NewResponseRecorder(w)
+ status, err := l.Next.ServeHTTP(responseRecorder, r)
+ if status >= 400 {
+ // There was an error up the chain, but no response has been written yet.
+ // The error must be handled here so the log entry will record the response size.
+ if l.ErrorFunc != nil {
+ l.ErrorFunc(responseRecorder, r, status)
+ } else {
+ // Default failover error handler
+ responseRecorder.WriteHeader(status)
+ fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status))
+ }
+ status = 0
+ }
+ rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue)
+ rule.Log.Println(rep.Replace(rule.Format))
+ return status, err
+ }
+ */
+ rule = rule
+ }
+ return l.Next.ServeDNS(w, r)
+}
+
+// Rule configures the logging middleware.
+type Rule struct {
+ PathScope string
+ OutputFile string
+ Format string
+ Log *log.Logger
+ Roller *middleware.LogRoller
+}
+
+const (
+ // DefaultLogFilename is the default log filename.
+ DefaultLogFilename = "access.log"
+ // CommonLogFormat is the common log format.
+ CommonLogFormat = `{remote} ` + CommonLogEmptyValue + ` [{when}] "{type} {name} {proto}" {rcode} {size}`
+ // CommonLogEmptyValue is the common empty log value.
+ CommonLogEmptyValue = "-"
+ // CombinedLogFormat is the combined log format.
+ CombinedLogFormat = CommonLogFormat + ` "{>Referer}" "{>User-Agent}"` // Something here as well
+ // DefaultLogFormat is the default log format.
+ DefaultLogFormat = CommonLogFormat
+)
diff --git a/middleware/log/log_test.go b/middleware/log/log_test.go
new file mode 100644
index 000000000..40560e4c0
--- /dev/null
+++ b/middleware/log/log_test.go
@@ -0,0 +1,48 @@
+package log
+
+import (
+ "bytes"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+type erroringMiddleware struct{}
+
+func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
+ return http.StatusNotFound, nil
+}
+
+func TestLoggedStatus(t *testing.T) {
+ var f bytes.Buffer
+ var next erroringMiddleware
+ rule := Rule{
+ PathScope: "/",
+ Format: DefaultLogFormat,
+ Log: log.New(&f, "", 0),
+ }
+
+ logger := Logger{
+ Rules: []Rule{rule},
+ Next: next,
+ }
+
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rec := httptest.NewRecorder()
+
+ status, err := logger.ServeHTTP(rec, r)
+ if status != 0 {
+ t.Error("Expected status to be 0 - was", status)
+ }
+
+ logged := f.String()
+ if !strings.Contains(logged, "404 13") {
+ t.Error("Expected 404 to be logged. Logged string -", logged)
+ }
+}
diff --git a/middleware/middleware.go b/middleware/middleware.go
new file mode 100644
index 000000000..436ec86e9
--- /dev/null
+++ b/middleware/middleware.go
@@ -0,0 +1,105 @@
+// Package middleware provides some types and functions common among middleware.
+package middleware
+
+import (
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+type (
+ // Middleware is the middle layer which represents the traditional
+ // idea of middleware: it chains one Handler to the next by being
+ // passed the next Handler in the chain.
+ Middleware func(Handler) Handler
+
+ // Handler is like dns.Handler except ServeDNS may return an rcode
+ // and/or error.
+ //
+ // If ServeDNS writes to the response body, it should return a status
+ // code of 0. This signals to other handlers above it that the response
+ // body is already written, and that they should not write to it also.
+ //
+ // If ServeDNS encounters an error, it should return the error value
+ // so it can be logged by designated error-handling middleware.
+ //
+ // If writing a response after calling another ServeDNS method, the
+ // returned rcode SHOULD be used when writing the response.
+ //
+ // If handling errors after calling another ServeDNS method, the
+ // returned error value SHOULD be logged or handled accordingly.
+ //
+ // Otherwise, return values should be propagated down the middleware
+ // chain by returning them unchanged.
+ Handler interface {
+ ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error)
+ }
+
+ // HandlerFunc is a convenience type like dns.HandlerFunc, except
+ // ServeDNS returns an rcode and an error. See Handler
+ // documentation for more information.
+ HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error)
+)
+
+// ServeDNS implements the Handler interface.
+func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ return f(w, r)
+}
+
+// IndexFile looks for a file in /root/fpath/indexFile for each string
+// in indexFiles. If an index file is found, it returns the root-relative
+// path to the file and true. If no index file is found, empty string
+// and false is returned. fpath must end in a forward slash '/'
+// otherwise no index files will be tried (directory paths must end
+// in a forward slash according to HTTP).
+//
+// All paths passed into and returned from this function use '/' as the
+// path separator, just like URLs. IndexFle handles path manipulation
+// internally for systems that use different path separators.
+/*
+func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, bool) {
+ if fpath[len(fpath)-1] != '/' || root == nil {
+ return "", false
+ }
+ for _, indexFile := range indexFiles {
+ // func (http.FileSystem).Open wants all paths separated by "/",
+ // regardless of operating system convention, so use
+ // path.Join instead of filepath.Join
+ fp := path.Join(fpath, indexFile)
+ f, err := root.Open(fp)
+ if err == nil {
+ f.Close()
+ return fp, true
+ }
+ }
+ return "", false
+}
+
+// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it
+// as a Last-Modified header to the ResponseWriter. If the modTime is in the future
+// the current time is used instead.
+func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) {
+ if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) {
+ // the time does not appear to be valid. Don't put it in the response
+ return
+ }
+
+ // RFC 2616 - Section 14.29 - Last-Modified:
+ // An origin server MUST NOT send a Last-Modified date which is later than the
+ // server's time of message origination. In such cases, where the resource's last
+ // modification would indicate some time in the future, the server MUST replace
+ // that date with the message origination date.
+ now := currentTime()
+ if modTime.After(now) {
+ modTime = now
+ }
+
+ w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat))
+}
+*/
+
+// currentTime, as it is defined here, returns time.Now().
+// It's defined as a variable for mocking time in tests.
+var currentTime = func() time.Time {
+ return time.Now()
+}
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
new file mode 100644
index 000000000..62fa4e250
--- /dev/null
+++ b/middleware/middleware_test.go
@@ -0,0 +1,108 @@
+package middleware
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+)
+
+func TestIndexfile(t *testing.T) {
+ tests := []struct {
+ rootDir http.FileSystem
+ fpath string
+ indexFiles []string
+ shouldErr bool
+ expectedFilePath string //retun value
+ expectedBoolValue bool //return value
+ }{
+ {
+ http.Dir("./templates/testdata"),
+ "/images/",
+ []string{"img.htm"},
+ false,
+ "/images/img.htm",
+ true,
+ },
+ }
+ for i, test := range tests {
+ actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
+ if actualBoolValue == true && test.shouldErr {
+ t.Errorf("Test %d didn't error, but it should have", i)
+ } else if actualBoolValue != true && !test.shouldErr {
+ t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
+ }
+ if actualFilePath != test.expectedFilePath {
+ t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
+ i, test.expectedFilePath, actualFilePath)
+
+ }
+ if actualBoolValue != test.expectedBoolValue {
+ t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
+ i, test.expectedBoolValue, actualBoolValue)
+
+ }
+ }
+}
+
+func TestSetLastModified(t *testing.T) {
+ nowTime := time.Now()
+
+ // ovewrite the function to return reliable time
+ originalGetCurrentTimeFunc := currentTime
+ currentTime = func() time.Time {
+ return nowTime
+ }
+ defer func() {
+ currentTime = originalGetCurrentTimeFunc
+ }()
+
+ pastTime := nowTime.Truncate(1 * time.Hour)
+ futureTime := nowTime.Add(1 * time.Hour)
+
+ tests := []struct {
+ inputModTime time.Time
+ expectedIsHeaderSet bool
+ expectedLastModified string
+ }{
+ {
+ inputModTime: pastTime,
+ expectedIsHeaderSet: true,
+ expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
+ },
+ {
+ inputModTime: nowTime,
+ expectedIsHeaderSet: true,
+ expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
+ },
+ {
+ inputModTime: futureTime,
+ expectedIsHeaderSet: true,
+ expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
+ },
+ {
+ inputModTime: time.Time{},
+ expectedIsHeaderSet: false,
+ },
+ }
+
+ for i, test := range tests {
+ responseRecorder := httptest.NewRecorder()
+ errorPrefix := fmt.Sprintf("Test [%d]: ", i)
+ SetLastModifiedHeader(responseRecorder, test.inputModTime)
+ actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
+
+ if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
+ t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
+ }
+
+ if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
+ t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
+ }
+
+ if test.expectedLastModified != actualLastModifiedHeader {
+ t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
+ }
+ }
+}
diff --git a/middleware/path.go b/middleware/path.go
new file mode 100644
index 000000000..1ffb64b76
--- /dev/null
+++ b/middleware/path.go
@@ -0,0 +1,18 @@
+package middleware
+
+import "strings"
+
+
+// TODO(miek): matches for names.
+
+// Path represents a URI path, maybe with pattern characters.
+type Path string
+
+// Matches checks to see if other matches p.
+//
+// Path matching will probably not always be a direct
+// comparison; this method assures that paths can be
+// easily and consistently matched.
+func (p Path) Matches(other string) bool {
+ return strings.HasPrefix(string(p), other)
+}
diff --git a/middleware/prometheus/handler.go b/middleware/prometheus/handler.go
new file mode 100644
index 000000000..eb82b8aff
--- /dev/null
+++ b/middleware/prometheus/handler.go
@@ -0,0 +1,31 @@
+package metrics
+
+import (
+ "strconv"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ context := middleware.Context{W: w, Req: r}
+
+ qname := context.Name()
+ qtype := context.Type()
+ zone := middleware.Zones(m.ZoneNames).Matches(qname)
+ if zone == "" {
+ zone = "."
+ }
+
+ // Record response to get status code and size of the reply.
+ rw := middleware.NewResponseRecorder(w)
+ status, err := m.Next.ServeDNS(rw, r)
+
+ requestCount.WithLabelValues(zone, qtype).Inc()
+ requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
+ responseSize.WithLabelValues(zone).Observe(float64(rw.Size()))
+ responseRcode.WithLabelValues(zone, strconv.Itoa(rw.Rcode())).Inc()
+
+ return status, err
+}
diff --git a/middleware/prometheus/metrics.go b/middleware/prometheus/metrics.go
new file mode 100644
index 000000000..4c989f640
--- /dev/null
+++ b/middleware/prometheus/metrics.go
@@ -0,0 +1,80 @@
+package metrics
+
+import (
+ "fmt"
+ "net/http"
+ "sync"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/prometheus/client_golang/prometheus"
+)
+
+const namespace = "daddy"
+
+var (
+ requestCount *prometheus.CounterVec
+ requestDuration *prometheus.HistogramVec
+ responseSize *prometheus.HistogramVec
+ responseRcode *prometheus.CounterVec
+)
+
+const path = "/metrics"
+
+// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics
+type Metrics struct {
+ Next middleware.Handler
+ Addr string // where to we listen
+ Once sync.Once
+ ZoneNames []string
+}
+
+func (m *Metrics) Start() error {
+ m.Once.Do(func() {
+ define("")
+
+ prometheus.MustRegister(requestCount)
+ prometheus.MustRegister(requestDuration)
+ prometheus.MustRegister(responseSize)
+ prometheus.MustRegister(responseRcode)
+
+ http.Handle(path, prometheus.Handler())
+ go func() {
+ fmt.Errorf("%s", http.ListenAndServe(m.Addr, nil))
+ }()
+ })
+ return nil
+}
+
+func define(subsystem string) {
+ if subsystem == "" {
+ subsystem = "dns"
+ }
+ requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "request_count_total",
+ Help: "Counter of DNS requests made per zone and type.",
+ }, []string{"zone", "qtype"})
+
+ requestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "request_duration_seconds",
+ Help: "Histogram of the time (in seconds) each request took.",
+ }, []string{"zone"})
+
+ responseSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "response_size_bytes",
+ Help: "Size of the returns response in bytes.",
+ Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3},
+ }, []string{"zone"})
+
+ responseRcode = prometheus.NewCounterVec(prometheus.CounterOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "rcode_code_count_total",
+ Help: "Counter of response status codes.",
+ }, []string{"zone", "rcode"})
+}
diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go
new file mode 100644
index 000000000..a2522bcb1
--- /dev/null
+++ b/middleware/proxy/policy.go
@@ -0,0 +1,101 @@
+package proxy
+
+import (
+ "math/rand"
+ "sync/atomic"
+)
+
+// HostPool is a collection of UpstreamHosts.
+type HostPool []*UpstreamHost
+
+// Policy decides how a host will be selected from a pool.
+type Policy interface {
+ Select(pool HostPool) *UpstreamHost
+}
+
+func init() {
+ RegisterPolicy("random", func() Policy { return &Random{} })
+ RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
+ RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
+}
+
+// Random is a policy that selects up hosts from a pool at random.
+type Random struct{}
+
+// Select selects an up host at random from the specified pool.
+func (r *Random) Select(pool HostPool) *UpstreamHost {
+ // instead of just generating a random index
+ // this is done to prevent selecting a down host
+ var randHost *UpstreamHost
+ count := 0
+ for _, host := range pool {
+ if host.Down() {
+ continue
+ }
+ count++
+ if count == 1 {
+ randHost = host
+ } else {
+ r := rand.Int() % count
+ if r == (count - 1) {
+ randHost = host
+ }
+ }
+ }
+ return randHost
+}
+
+// LeastConn is a policy that selects the host with the least connections.
+type LeastConn struct{}
+
+// Select selects the up host with the least number of connections in the
+// pool. If more than one host has the same least number of connections,
+// one of the hosts is chosen at random.
+func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
+ var bestHost *UpstreamHost
+ count := 0
+ leastConn := int64(1<<63 - 1)
+ for _, host := range pool {
+ if host.Down() {
+ continue
+ }
+ hostConns := host.Conns
+ if hostConns < leastConn {
+ bestHost = host
+ leastConn = hostConns
+ count = 1
+ } else if hostConns == leastConn {
+ // randomly select host among hosts with least connections
+ count++
+ if count == 1 {
+ bestHost = host
+ } else {
+ r := rand.Int() % count
+ if r == (count - 1) {
+ bestHost = host
+ }
+ }
+ }
+ }
+ return bestHost
+}
+
+// RoundRobin is a policy that selects hosts based on round robin ordering.
+type RoundRobin struct {
+ Robin uint32
+}
+
+// Select selects an up host from the pool using a round robin ordering scheme.
+func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
+ poolLen := uint32(len(pool))
+ selection := atomic.AddUint32(&r.Robin, 1) % poolLen
+ host := pool[selection]
+ // if the currently selected host is down, just ffwd to up host
+ for i := uint32(1); host.Down() && i < poolLen; i++ {
+ host = pool[(selection+i)%poolLen]
+ }
+ if host.Down() {
+ return nil
+ }
+ return host
+}
diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go
new file mode 100644
index 000000000..8f4f1f792
--- /dev/null
+++ b/middleware/proxy/policy_test.go
@@ -0,0 +1,87 @@
+package proxy
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+)
+
+var workableServer *httptest.Server
+
+func TestMain(m *testing.M) {
+ workableServer = httptest.NewServer(http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ // do nothing
+ }))
+ r := m.Run()
+ workableServer.Close()
+ os.Exit(r)
+}
+
+type customPolicy struct{}
+
+func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
+ return pool[0]
+}
+
+func testPool() HostPool {
+ pool := []*UpstreamHost{
+ {
+ Name: workableServer.URL, // this should resolve (healthcheck test)
+ },
+ {
+ Name: "http://shouldnot.resolve", // this shouldn't
+ },
+ {
+ Name: "http://C",
+ },
+ }
+ return HostPool(pool)
+}
+
+func TestRoundRobinPolicy(t *testing.T) {
+ pool := testPool()
+ rrPolicy := &RoundRobin{}
+ h := rrPolicy.Select(pool)
+ // First selected host is 1, because counter starts at 0
+ // and increments before host is selected
+ if h != pool[1] {
+ t.Error("Expected first round robin host to be second host in the pool.")
+ }
+ h = rrPolicy.Select(pool)
+ if h != pool[2] {
+ t.Error("Expected second round robin host to be third host in the pool.")
+ }
+ // mark host as down
+ pool[0].Unhealthy = true
+ h = rrPolicy.Select(pool)
+ if h != pool[1] {
+ t.Error("Expected third round robin host to be first host in the pool.")
+ }
+}
+
+func TestLeastConnPolicy(t *testing.T) {
+ pool := testPool()
+ lcPolicy := &LeastConn{}
+ pool[0].Conns = 10
+ pool[1].Conns = 10
+ h := lcPolicy.Select(pool)
+ if h != pool[2] {
+ t.Error("Expected least connection host to be third host.")
+ }
+ pool[2].Conns = 100
+ h = lcPolicy.Select(pool)
+ if h != pool[0] && h != pool[1] {
+ t.Error("Expected least connection host to be first or second host.")
+ }
+}
+
+func TestCustomPolicy(t *testing.T) {
+ pool := testPool()
+ customPolicy := &customPolicy{}
+ h := customPolicy.Select(pool)
+ if h != pool[0] {
+ t.Error("Expected custom policy host to be the first host.")
+ }
+}
diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go
new file mode 100644
index 000000000..169e41b61
--- /dev/null
+++ b/middleware/proxy/proxy.go
@@ -0,0 +1,120 @@
+// Package proxy is middleware that proxies requests.
+package proxy
+
+import (
+ "errors"
+ "net/http"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+var errUnreachable = errors.New("unreachable backend")
+
+// Proxy represents a middleware instance that can proxy requests.
+type Proxy struct {
+ Next middleware.Handler
+ Client Client
+ Upstreams []Upstream
+}
+
+type Client struct {
+ UDP *dns.Client
+ TCP *dns.Client
+}
+
+// Upstream manages a pool of proxy upstream hosts. Select should return a
+// suitable upstream host, or nil if no such hosts are available.
+type Upstream interface {
+ // The domain name this upstream host should be routed on.
+ From() string
+ // Selects an upstream host to be routed to.
+ Select() *UpstreamHost
+ // Checks if subpdomain is not an ignored.
+ IsAllowedPath(string) bool
+}
+
+// UpstreamHostDownFunc can be used to customize how Down behaves.
+type UpstreamHostDownFunc func(*UpstreamHost) bool
+
+// UpstreamHost represents a single proxy upstream
+type UpstreamHost struct {
+ Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
+ Name string // IP address (and port) of this upstream host
+ Fails int32
+ FailTimeout time.Duration
+ Unhealthy bool
+ ExtraHeaders http.Header
+ CheckDown UpstreamHostDownFunc
+ WithoutPathPrefix string
+}
+
+// Down checks whether the upstream host is down or not.
+// Down will try to use uh.CheckDown first, and will fall
+// back to some default criteria if necessary.
+func (uh *UpstreamHost) Down() bool {
+ if uh.CheckDown == nil {
+ // Default settings
+ return uh.Unhealthy || uh.Fails > 0
+ }
+ return uh.CheckDown(uh)
+}
+
+// tryDuration is how long to try upstream hosts; failures result in
+// immediate retries until this duration ends or we get a nil host.
+var tryDuration = 60 * time.Second
+
+// ServeDNS satisfies the middleware.Handler interface.
+func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ for _, upstream := range p.Upstreams {
+ // allowed bla bla bla TODO(miek): fix full proxy spec from caddy
+ start := time.Now()
+
+ // Since Select() should give us "up" hosts, keep retrying
+ // hosts until timeout (or until we get a nil host).
+ for time.Now().Sub(start) < tryDuration {
+ host := upstream.Select()
+ if host == nil {
+ return dns.RcodeServerFailure, errUnreachable
+ }
+ // TODO(miek): PORT!
+ reverseproxy := ReverseProxy{Host: host.Name, Client: p.Client}
+
+ atomic.AddInt64(&host.Conns, 1)
+ backendErr := reverseproxy.ServeDNS(w, r, nil)
+ atomic.AddInt64(&host.Conns, -1)
+ if backendErr == nil {
+ return 0, nil
+ }
+ timeout := host.FailTimeout
+ if timeout == 0 {
+ timeout = 10 * time.Second
+ }
+ atomic.AddInt32(&host.Fails, 1)
+ go func(host *UpstreamHost, timeout time.Duration) {
+ time.Sleep(timeout)
+ atomic.AddInt32(&host.Fails, -1)
+ }(host, timeout)
+ }
+ return dns.RcodeServerFailure, errUnreachable
+ }
+ return p.Next.ServeDNS(w, r)
+}
+
+func Clients() Client {
+ udp := newClient("udp", defaultTimeout)
+ tcp := newClient("tcp", defaultTimeout)
+ return Client{UDP: udp, TCP: tcp}
+}
+
+// newClient returns a new client for proxy requests.
+func newClient(net string, timeout time.Duration) *dns.Client {
+ if timeout == 0 {
+ timeout = defaultTimeout
+ }
+ return &dns.Client{Net: net, ReadTimeout: timeout, WriteTimeout: timeout, SingleInflight: true}
+}
+
+const defaultTimeout = 5 * time.Second
diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go
new file mode 100644
index 000000000..8066874d2
--- /dev/null
+++ b/middleware/proxy/proxy_test.go
@@ -0,0 +1,317 @@
+package proxy
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/net/websocket"
+)
+
+func init() {
+ tryDuration = 50 * time.Millisecond // prevent tests from hanging
+}
+
+func TestReverseProxy(t *testing.T) {
+ log.SetOutput(ioutil.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ var requestReceived bool
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestReceived = true
+ w.Write([]byte("Hello, client"))
+ }))
+ defer backend.Close()
+
+ // set up proxy
+ p := &Proxy{
+ Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
+ }
+
+ // create request and response recorder
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
+ }
+ w := httptest.NewRecorder()
+
+ p.ServeHTTP(w, r)
+
+ if !requestReceived {
+ t.Error("Expected backend to receive request, but it didn't")
+ }
+}
+
+func TestReverseProxyInsecureSkipVerify(t *testing.T) {
+ log.SetOutput(ioutil.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ var requestReceived bool
+ backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestReceived = true
+ w.Write([]byte("Hello, client"))
+ }))
+ defer backend.Close()
+
+ // set up proxy
+ p := &Proxy{
+ Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
+ }
+
+ // create request and response recorder
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
+ }
+ w := httptest.NewRecorder()
+
+ p.ServeHTTP(w, r)
+
+ if !requestReceived {
+ t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't")
+ }
+}
+
+func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
+ // No-op websocket backend simply allows the WS connection to be
+ // accepted then it will be immediately closed. Perfect for testing.
+ wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
+ defer wsNop.Close()
+
+ // Get proxy to use for the test
+ p := newWebSocketTestProxy(wsNop.URL)
+
+ // Create client request
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
+ }
+ r.Header = http.Header{
+ "Connection": {"Upgrade"},
+ "Upgrade": {"websocket"},
+ "Origin": {wsNop.URL},
+ "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
+ "Sec-WebSocket-Version": {"13"},
+ }
+
+ // Capture the request
+ w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}
+
+ // Booya! Do the test.
+ p.ServeHTTP(w, r)
+
+ // Make sure the backend accepted the WS connection.
+ // Mostly interested in the Upgrade and Connection response headers
+ // and the 101 status code.
+ expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
+ actual := w.fakeConn.writeBuf.Bytes()
+ if !bytes.Equal(actual, expected) {
+ t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
+ }
+}
+
+func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
+ // Echo server allows us to test that socket bytes are properly
+ // being proxied.
+ wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
+ io.Copy(ws, ws)
+ }))
+ defer wsEcho.Close()
+
+ // Get proxy to use for the test
+ p := newWebSocketTestProxy(wsEcho.URL)
+
+ // This is a full end-end test, so the proxy handler
+ // has to be part of a server listening on a port. Our
+ // WS client will connect to this test server, not
+ // the echo client directly.
+ echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ p.ServeHTTP(w, r)
+ }))
+ defer echoProxy.Close()
+
+ // Set up WebSocket client
+ url := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
+ ws, err := websocket.Dial(url, "", echoProxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ws.Close()
+
+ // Send test message
+ trialMsg := "Is it working?"
+ websocket.Message.Send(ws, trialMsg)
+
+ // It should be echoed back to us
+ var actualMsg string
+ websocket.Message.Receive(ws, &actualMsg)
+ if actualMsg != trialMsg {
+ t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
+ }
+}
+
+func TestUnixSocketProxy(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ return
+ }
+
+ trialMsg := "Is it working?"
+
+ var proxySuccess bool
+
+ // This is our fake "application" we want to proxy to
+ ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Request was proxied when this is called
+ proxySuccess = true
+
+ fmt.Fprint(w, trialMsg)
+ }))
+
+ // Get absolute path for unix: socket
+ socketPath, err := filepath.Abs("./test_socket")
+ if err != nil {
+ t.Fatalf("Unable to get absolute path: %v", err)
+ }
+
+ // Change httptest.Server listener to listen to unix: socket
+ ln, err := net.Listen("unix", socketPath)
+ if err != nil {
+ t.Fatalf("Unable to listen: %v", err)
+ }
+ ts.Listener = ln
+
+ ts.Start()
+ defer ts.Close()
+
+ url := strings.Replace(ts.URL, "http://", "unix:", 1)
+ p := newWebSocketTestProxy(url)
+
+ echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ p.ServeHTTP(w, r)
+ }))
+ defer echoProxy.Close()
+
+ res, err := http.Get(echoProxy.URL)
+ if err != nil {
+ t.Fatalf("Unable to GET: %v", err)
+ }
+
+ greeting, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatalf("Unable to GET: %v", err)
+ }
+
+ actualMsg := fmt.Sprintf("%s", greeting)
+
+ if !proxySuccess {
+ t.Errorf("Expected request to be proxied, but it wasn't")
+ }
+
+ if actualMsg != trialMsg {
+ t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
+ }
+}
+
+func newFakeUpstream(name string, insecure bool) *fakeUpstream {
+ uri, _ := url.Parse(name)
+ u := &fakeUpstream{
+ name: name,
+ host: &UpstreamHost{
+ Name: name,
+ ReverseProxy: NewSingleHostReverseProxy(uri, ""),
+ },
+ }
+ if insecure {
+ u.host.ReverseProxy.Transport = InsecureTransport
+ }
+ return u
+}
+
+type fakeUpstream struct {
+ name string
+ host *UpstreamHost
+}
+
+func (u *fakeUpstream) From() string {
+ return "/"
+}
+
+func (u *fakeUpstream) Select() *UpstreamHost {
+ return u.host
+}
+
+func (u *fakeUpstream) IsAllowedPath(requestPath string) bool {
+ return true
+}
+
+// newWebSocketTestProxy returns a test proxy that will
+// redirect to the specified backendAddr. The function
+// also sets up the rules/environment for testing WebSocket
+// proxy.
+func newWebSocketTestProxy(backendAddr string) *Proxy {
+ return &Proxy{
+ Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}},
+ }
+}
+
+type fakeWsUpstream struct {
+ name string
+}
+
+func (u *fakeWsUpstream) From() string {
+ return "/"
+}
+
+func (u *fakeWsUpstream) Select() *UpstreamHost {
+ uri, _ := url.Parse(u.name)
+ return &UpstreamHost{
+ Name: u.name,
+ ReverseProxy: NewSingleHostReverseProxy(uri, ""),
+ ExtraHeaders: http.Header{
+ "Connection": {"{>Connection}"},
+ "Upgrade": {"{>Upgrade}"}},
+ }
+}
+
+func (u *fakeWsUpstream) IsAllowedPath(requestPath string) bool {
+ return true
+}
+
+// recorderHijacker is a ResponseRecorder that can
+// be hijacked.
+type recorderHijacker struct {
+ *httptest.ResponseRecorder
+ fakeConn *fakeConn
+}
+
+func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return rh.fakeConn, nil, nil
+}
+
+type fakeConn struct {
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+}
+
+func (c *fakeConn) LocalAddr() net.Addr { return nil }
+func (c *fakeConn) RemoteAddr() net.Addr { return nil }
+func (c *fakeConn) SetDeadline(t time.Time) error { return nil }
+func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
+func (c *fakeConn) Close() error { return nil }
+func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
+func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go
new file mode 100644
index 000000000..6d27da042
--- /dev/null
+++ b/middleware/proxy/reverseproxy.go
@@ -0,0 +1,36 @@
+// Package proxy is middleware that proxies requests.
+package proxy
+
+import (
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+type ReverseProxy struct {
+ Host string
+ Client Client
+}
+
+func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
+ // TODO(miek): use extra!
+ var (
+ reply *dns.Msg
+ err error
+ )
+ context := middleware.Context{W: w, Req: r}
+
+ // tls+tcp ?
+ if context.Proto() == "tcp" {
+ reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
+ } else {
+ reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
+ }
+
+ if err != nil {
+ return err
+ }
+ reply.Compress = true
+ reply.Id = r.Id
+ w.WriteMsg(reply)
+ return nil
+}
diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go
new file mode 100644
index 000000000..092e2351d
--- /dev/null
+++ b/middleware/proxy/upstream.go
@@ -0,0 +1,235 @@
+package proxy
+
+import (
+ "io"
+ "io/ioutil"
+ "net/http"
+ "path"
+ "strconv"
+ "time"
+
+ "github.com/miekg/coredns/core/parse"
+ "github.com/miekg/coredns/middleware"
+)
+
+var (
+ supportedPolicies = make(map[string]func() Policy)
+)
+
+type staticUpstream struct {
+ from string
+ // TODO(miek): allows use to added headers
+ proxyHeaders http.Header // TODO(miek): kill
+ Hosts HostPool
+ Policy Policy
+
+ FailTimeout time.Duration
+ MaxFails int32
+ HealthCheck struct {
+ Path string
+ Interval time.Duration
+ }
+ WithoutPathPrefix string
+ IgnoredSubPaths []string
+}
+
+// NewStaticUpstreams parses the configuration input and sets up
+// static upstreams for the proxy middleware.
+func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
+ var upstreams []Upstream
+ for c.Next() {
+ upstream := &staticUpstream{
+ from: "",
+ proxyHeaders: make(http.Header),
+ Hosts: nil,
+ Policy: &Random{},
+ FailTimeout: 10 * time.Second,
+ MaxFails: 1,
+ }
+
+ if !c.Args(&upstream.from) {
+ return upstreams, c.ArgErr()
+ }
+ to := c.RemainingArgs()
+ if len(to) == 0 {
+ return upstreams, c.ArgErr()
+ }
+
+ for c.NextBlock() {
+ if err := parseBlock(&c, upstream); err != nil {
+ return upstreams, err
+ }
+ }
+
+ upstream.Hosts = make([]*UpstreamHost, len(to))
+ for i, host := range to {
+ uh := &UpstreamHost{
+ Name: host,
+ Conns: 0,
+ Fails: 0,
+ FailTimeout: upstream.FailTimeout,
+ Unhealthy: false,
+ ExtraHeaders: upstream.proxyHeaders,
+ CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
+ return func(uh *UpstreamHost) bool {
+ if uh.Unhealthy {
+ return true
+ }
+ if uh.Fails >= upstream.MaxFails &&
+ upstream.MaxFails != 0 {
+ return true
+ }
+ return false
+ }
+ }(upstream),
+ WithoutPathPrefix: upstream.WithoutPathPrefix,
+ }
+ upstream.Hosts[i] = uh
+ }
+
+ if upstream.HealthCheck.Path != "" {
+ go upstream.HealthCheckWorker(nil)
+ }
+ upstreams = append(upstreams, upstream)
+ }
+ return upstreams, nil
+}
+
+// RegisterPolicy adds a custom policy to the proxy.
+func RegisterPolicy(name string, policy func() Policy) {
+ supportedPolicies[name] = policy
+}
+
+func (u *staticUpstream) From() string {
+ return u.from
+}
+
+func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
+ switch c.Val() {
+ case "policy":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ policyCreateFunc, ok := supportedPolicies[c.Val()]
+ if !ok {
+ return c.ArgErr()
+ }
+ u.Policy = policyCreateFunc()
+ case "fail_timeout":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ dur, err := time.ParseDuration(c.Val())
+ if err != nil {
+ return err
+ }
+ u.FailTimeout = dur
+ case "max_fails":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ n, err := strconv.Atoi(c.Val())
+ if err != nil {
+ return err
+ }
+ u.MaxFails = int32(n)
+ case "health_check":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ u.HealthCheck.Path = c.Val()
+ u.HealthCheck.Interval = 30 * time.Second
+ if c.NextArg() {
+ dur, err := time.ParseDuration(c.Val())
+ if err != nil {
+ return err
+ }
+ u.HealthCheck.Interval = dur
+ }
+ case "proxy_header":
+ var header, value string
+ if !c.Args(&header, &value) {
+ return c.ArgErr()
+ }
+ u.proxyHeaders.Add(header, value)
+ case "websocket":
+ u.proxyHeaders.Add("Connection", "{>Connection}")
+ u.proxyHeaders.Add("Upgrade", "{>Upgrade}")
+ case "without":
+ if !c.NextArg() {
+ return c.ArgErr()
+ }
+ u.WithoutPathPrefix = c.Val()
+ case "except":
+ ignoredPaths := c.RemainingArgs()
+ if len(ignoredPaths) == 0 {
+ return c.ArgErr()
+ }
+ u.IgnoredSubPaths = ignoredPaths
+ default:
+ return c.Errf("unknown property '%s'", c.Val())
+ }
+ return nil
+}
+
+func (u *staticUpstream) healthCheck() {
+ for _, host := range u.Hosts {
+ hostURL := host.Name + u.HealthCheck.Path
+ if r, err := http.Get(hostURL); err == nil {
+ io.Copy(ioutil.Discard, r.Body)
+ r.Body.Close()
+ host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
+ } else {
+ host.Unhealthy = true
+ }
+ }
+}
+
+func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
+ ticker := time.NewTicker(u.HealthCheck.Interval)
+ u.healthCheck()
+ for {
+ select {
+ case <-ticker.C:
+ u.healthCheck()
+ case <-stop:
+ // TODO: the library should provide a stop channel and global
+ // waitgroup to allow goroutines started by plugins a chance
+ // to clean themselves up.
+ }
+ }
+}
+
+func (u *staticUpstream) Select() *UpstreamHost {
+ pool := u.Hosts
+ if len(pool) == 1 {
+ if pool[0].Down() {
+ return nil
+ }
+ return pool[0]
+ }
+ allDown := true
+ for _, host := range pool {
+ if !host.Down() {
+ allDown = false
+ break
+ }
+ }
+ if allDown {
+ return nil
+ }
+
+ if u.Policy == nil {
+ return (&Random{}).Select(pool)
+ }
+ return u.Policy.Select(pool)
+}
+
+func (u *staticUpstream) IsAllowedPath(requestPath string) bool {
+ for _, ignoredSubPath := range u.IgnoredSubPaths {
+ if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go
new file mode 100644
index 000000000..5b2fdb1da
--- /dev/null
+++ b/middleware/proxy/upstream_test.go
@@ -0,0 +1,83 @@
+package proxy
+
+import (
+ "testing"
+ "time"
+)
+
+func TestHealthCheck(t *testing.T) {
+ upstream := &staticUpstream{
+ from: "",
+ Hosts: testPool(),
+ Policy: &Random{},
+ FailTimeout: 10 * time.Second,
+ MaxFails: 1,
+ }
+ upstream.healthCheck()
+ if upstream.Hosts[0].Down() {
+ t.Error("Expected first host in testpool to not fail healthcheck.")
+ }
+ if !upstream.Hosts[1].Down() {
+ t.Error("Expected second host in testpool to fail healthcheck.")
+ }
+}
+
+func TestSelect(t *testing.T) {
+ upstream := &staticUpstream{
+ from: "",
+ Hosts: testPool()[:3],
+ Policy: &Random{},
+ FailTimeout: 10 * time.Second,
+ MaxFails: 1,
+ }
+ upstream.Hosts[0].Unhealthy = true
+ upstream.Hosts[1].Unhealthy = true
+ upstream.Hosts[2].Unhealthy = true
+ if h := upstream.Select(); h != nil {
+ t.Error("Expected select to return nil as all host are down")
+ }
+ upstream.Hosts[2].Unhealthy = false
+ if h := upstream.Select(); h == nil {
+ t.Error("Expected select to not return nil")
+ }
+}
+
+func TestRegisterPolicy(t *testing.T) {
+ name := "custom"
+ customPolicy := &customPolicy{}
+ RegisterPolicy(name, func() Policy { return customPolicy })
+ if _, ok := supportedPolicies[name]; !ok {
+ t.Error("Expected supportedPolicies to have a custom policy.")
+ }
+
+}
+
+func TestAllowedPaths(t *testing.T) {
+ upstream := &staticUpstream{
+ from: "/proxy",
+ IgnoredSubPaths: []string{"/download", "/static"},
+ }
+ tests := []struct {
+ url string
+ expected bool
+ }{
+ {"/proxy", true},
+ {"/proxy/dl", true},
+ {"/proxy/download", false},
+ {"/proxy/download/static", false},
+ {"/proxy/static", false},
+ {"/proxy/static/download", false},
+ {"/proxy/something/download", true},
+ {"/proxy/something/static", true},
+ {"/proxy//static", false},
+ {"/proxy//static//download", false},
+ {"/proxy//download", false},
+ }
+
+ for i, test := range tests {
+ isAllowed := upstream.IsAllowedPath(test.url)
+ if test.expected != isAllowed {
+ t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed)
+ }
+ }
+}
diff --git a/middleware/recorder.go b/middleware/recorder.go
new file mode 100644
index 000000000..38a7e0e82
--- /dev/null
+++ b/middleware/recorder.go
@@ -0,0 +1,70 @@
+package middleware
+
+import (
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// ResponseRecorder is a type of ResponseWriter that captures
+// the rcode code written to it and also the size of the message
+// written in the response. A rcode code does not have
+// to be written, however, in which case 0 must be assumed.
+// It is best to have the constructor initialize this type
+// with that default status code.
+type ResponseRecorder struct {
+ dns.ResponseWriter
+ rcode int
+ size int
+ start time.Time
+}
+
+// NewResponseRecorder makes and returns a new responseRecorder,
+// which captures the DNS rcode from the ResponseWriter
+// and also the length of the response message written through it.
+func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder {
+ return &ResponseRecorder{
+ ResponseWriter: w,
+ rcode: 0,
+ start: time.Now(),
+ }
+}
+
+// WriteMsg records the status code and calls the
+// underlying ResponseWriter's WriteMsg method.
+func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error {
+ r.rcode = res.Rcode
+ r.size = res.Len()
+ return r.ResponseWriter.WriteMsg(res)
+}
+
+// Write is a wrapper that records the size of the message that gets written.
+func (r *ResponseRecorder) Write(buf []byte) (int, error) {
+ n, err := r.ResponseWriter.Write(buf)
+ if err == nil {
+ r.size += n
+ }
+ return n, err
+}
+
+// Size returns the size.
+func (r *ResponseRecorder) Size() int {
+ return r.size
+}
+
+// Rcode returns the rcode.
+func (r *ResponseRecorder) Rcode() int {
+ return r.rcode
+}
+
+// Start returns the start time of the ResponseRecorder.
+func (r *ResponseRecorder) Start() time.Time {
+ return r.start
+}
+
+// Hijack implements dns.Hijacker. It simply wraps the underlying
+// ResponseWriter's Hijack method if there is one, or returns an error.
+func (r *ResponseRecorder) Hijack() {
+ r.ResponseWriter.Hijack()
+ return
+}
diff --git a/middleware/recorder_test.go b/middleware/recorder_test.go
new file mode 100644
index 000000000..a8c8a5d04
--- /dev/null
+++ b/middleware/recorder_test.go
@@ -0,0 +1,32 @@
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestNewResponseRecorder(t *testing.T) {
+ w := httptest.NewRecorder()
+ recordRequest := NewResponseRecorder(w)
+ if !(recordRequest.ResponseWriter == w) {
+ t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
+ }
+ if recordRequest.status != http.StatusOK {
+ t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
+ }
+}
+
+func TestWrite(t *testing.T) {
+ w := httptest.NewRecorder()
+ responseTestString := "test"
+ recordRequest := NewResponseRecorder(w)
+ buf := []byte(responseTestString)
+ recordRequest.Write(buf)
+ if recordRequest.size != len(buf) {
+ t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
+ }
+ if w.Body.String() != responseTestString {
+ t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
+ }
+}
diff --git a/middleware/reflect/reflect.go b/middleware/reflect/reflect.go
new file mode 100644
index 000000000..6d5847b81
--- /dev/null
+++ b/middleware/reflect/reflect.go
@@ -0,0 +1,84 @@
+// Reflect provides middleware that reflects back some client properties.
+// This is the default middleware when Caddy is run without configuration.
+//
+// The left-most label must be `who`.
+// When queried for type A (resp. AAAA), it sends back the IPv4 (resp. v6) address.
+// In the additional section the port number and transport are shown.
+// Basic use pattern:
+//
+// dig @localhost -p 1053 who.miek.nl A
+//
+// ;; ANSWER SECTION:
+// who.miek.nl. 0 IN A 127.0.0.1
+//
+// ;; ADDITIONAL SECTION:
+// who.miek.nl. 0 IN TXT "Port: 56195 (udp)"
+package reflect
+
+import (
+ "errors"
+ "net"
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+type Reflect struct {
+ Next middleware.Handler
+}
+
+func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ context := middleware.Context{Req: r, W: w}
+
+ class := r.Question[0].Qclass
+ qname := r.Question[0].Name
+ i, ok := dns.NextLabel(qname, 0)
+
+ if strings.ToLower(qname[:i]) != who || ok {
+ err := context.ErrorMessage(dns.RcodeFormatError)
+ w.WriteMsg(err)
+ return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
+ }
+
+ answer := new(dns.Msg)
+ answer.SetReply(r)
+ answer.Compress = true
+ answer.Authoritative = true
+
+ ip := context.IP()
+ proto := context.Proto()
+ port, _ := context.Port()
+ family := context.Family()
+ var rr dns.RR
+
+ switch family {
+ case 1:
+ rr = new(dns.A)
+ rr.(*dns.A).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeA, Class: class, Ttl: 0}
+ rr.(*dns.A).A = net.ParseIP(ip).To4()
+ case 2:
+ rr = new(dns.AAAA)
+ rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeAAAA, Class: class, Ttl: 0}
+ rr.(*dns.AAAA).AAAA = net.ParseIP(ip)
+ }
+
+ t := new(dns.TXT)
+ t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
+ t.Txt = []string{"Port: " + port + " (" + proto + ")"}
+
+ switch context.Type() {
+ case "TXT":
+ answer.Answer = append(answer.Answer, t)
+ answer.Extra = append(answer.Extra, rr)
+ default:
+ fallthrough
+ case "AAAA", "A":
+ answer.Answer = append(answer.Answer, rr)
+ answer.Extra = append(answer.Extra, t)
+ }
+ w.WriteMsg(answer)
+ return 0, nil
+}
+
+const who = "who."
diff --git a/middleware/reflect/reflect_test.go b/middleware/reflect/reflect_test.go
new file mode 100644
index 000000000..477a3a573
--- /dev/null
+++ b/middleware/reflect/reflect_test.go
@@ -0,0 +1 @@
+package reflect
diff --git a/middleware/replacer.go b/middleware/replacer.go
new file mode 100644
index 000000000..133da74c5
--- /dev/null
+++ b/middleware/replacer.go
@@ -0,0 +1,98 @@
+package middleware
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// Replacer is a type which can replace placeholder
+// substrings in a string with actual values from a
+// http.Request and responseRecorder. Always use
+// NewReplacer to get one of these.
+type Replacer interface {
+ Replace(string) string
+ Set(key, value string)
+}
+
+type replacer struct {
+ replacements map[string]string
+ emptyValue string
+}
+
+// NewReplacer makes a new replacer based on r and rr.
+// Do not create a new replacer until r and rr have all
+// the needed values, because this function copies those
+// values into the replacer. rr may be nil if it is not
+// available. emptyValue should be the string that is used
+// in place of empty string (can still be empty string).
+func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
+ context := Context{W: rr, Req: r}
+ rep := replacer{
+ replacements: map[string]string{
+ "{type}": context.Type(),
+ "{name}": context.Name(),
+ "{class}": context.Class(),
+ "{proto}": context.Proto(),
+ "{when}": func() string {
+ return time.Now().Format(timeFormat)
+ }(),
+ "{remote}": context.IP(),
+ "{port}": func() string {
+ p, _ := context.Port()
+ return p
+ }(),
+ },
+ emptyValue: emptyValue,
+ }
+ if rr != nil {
+ rep.replacements["{rcode}"] = strconv.Itoa(rr.rcode)
+ rep.replacements["{size}"] = strconv.Itoa(rr.size)
+ rep.replacements["{latency}"] = time.Since(rr.start).String()
+ }
+
+ return rep
+}
+
+// Replace performs a replacement of values on s and returns
+// the string with the replaced values.
+func (r replacer) Replace(s string) string {
+ // Header replacements - these are case-insensitive, so we can't just use strings.Replace()
+ for strings.Contains(s, headerReplacer) {
+ idxStart := strings.Index(s, headerReplacer)
+ endOffset := idxStart + len(headerReplacer)
+ idxEnd := strings.Index(s[endOffset:], "}")
+ if idxEnd > -1 {
+ placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
+ replacement := r.replacements[placeholder]
+ if replacement == "" {
+ replacement = r.emptyValue
+ }
+ s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
+ } else {
+ break
+ }
+ }
+
+ // Regular replacements - these are easier because they're case-sensitive
+ for placeholder, replacement := range r.replacements {
+ if replacement == "" {
+ replacement = r.emptyValue
+ }
+ s = strings.Replace(s, placeholder, replacement, -1)
+ }
+
+ return s
+}
+
+// Set sets key to value in the replacements map.
+func (r replacer) Set(key, value string) {
+ r.replacements["{"+key+"}"] = value
+}
+
+const (
+ timeFormat = "02/Jan/2006:15:04:05 -0700"
+ headerReplacer = "{>"
+)
diff --git a/middleware/replacer_test.go b/middleware/replacer_test.go
new file mode 100644
index 000000000..d98bd2de1
--- /dev/null
+++ b/middleware/replacer_test.go
@@ -0,0 +1,124 @@
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestNewReplacer(t *testing.T) {
+ w := httptest.NewRecorder()
+ recordRequest := NewResponseRecorder(w)
+ reader := strings.NewReader(`{"username": "dennis"}`)
+
+ request, err := http.NewRequest("POST", "http://localhost", reader)
+ if err != nil {
+ t.Fatal("Request Formation Failed\n")
+ }
+ replaceValues := NewReplacer(request, recordRequest, "")
+
+ switch v := replaceValues.(type) {
+ case replacer:
+
+ if v.replacements["{host}"] != "localhost" {
+ t.Error("Expected host to be localhost")
+ }
+ if v.replacements["{method}"] != "POST" {
+ t.Error("Expected request method to be POST")
+ }
+ if v.replacements["{status}"] != "200" {
+ t.Error("Expected status to be 200")
+ }
+
+ default:
+ t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
+ }
+}
+
+func TestReplace(t *testing.T) {
+ w := httptest.NewRecorder()
+ recordRequest := NewResponseRecorder(w)
+ reader := strings.NewReader(`{"username": "dennis"}`)
+
+ request, err := http.NewRequest("POST", "http://localhost", reader)
+ if err != nil {
+ t.Fatal("Request Formation Failed\n")
+ }
+ request.Header.Set("Custom", "foobarbaz")
+ request.Header.Set("ShorterVal", "1")
+ repl := NewReplacer(request, recordRequest, "-")
+
+ if expected, actual := "This host is localhost.", repl.Replace("This host is {host}."); expected != actual {
+ t.Errorf("{host} replacement: expected '%s', got '%s'", expected, actual)
+ }
+ if expected, actual := "This request method is POST.", repl.Replace("This request method is {method}."); expected != actual {
+ t.Errorf("{method} replacement: expected '%s', got '%s'", expected, actual)
+ }
+ if expected, actual := "The response status is 200.", repl.Replace("The response status is {status}."); expected != actual {
+ t.Errorf("{status} replacement: expected '%s', got '%s'", expected, actual)
+ }
+ if expected, actual := "The Custom header is foobarbaz.", repl.Replace("The Custom header is {>Custom}."); expected != actual {
+ t.Errorf("{>Custom} replacement: expected '%s', got '%s'", expected, actual)
+ }
+
+ // Test header case-insensitivity
+ if expected, actual := "The cUsToM header is foobarbaz...", repl.Replace("The cUsToM header is {>cUsToM}..."); expected != actual {
+ t.Errorf("{>cUsToM} replacement: expected '%s', got '%s'", expected, actual)
+ }
+
+ // Test non-existent header/value
+ if expected, actual := "The Non-Existent header is -.", repl.Replace("The Non-Existent header is {>Non-Existent}."); expected != actual {
+ t.Errorf("{>Non-Existent} replacement: expected '%s', got '%s'", expected, actual)
+ }
+
+ // Test bad placeholder
+ if expected, actual := "Bad {host placeholder...", repl.Replace("Bad {host placeholder..."); expected != actual {
+ t.Errorf("bad placeholder: expected '%s', got '%s'", expected, actual)
+ }
+
+ // Test bad header placeholder
+ if expected, actual := "Bad {>Custom placeholder", repl.Replace("Bad {>Custom placeholder"); expected != actual {
+ t.Errorf("bad header placeholder: expected '%s', got '%s'", expected, actual)
+ }
+
+ // Test bad header placeholder with valid one later
+ if expected, actual := "Bad -", repl.Replace("Bad {>Custom placeholder {>ShorterVal}"); expected != actual {
+ t.Errorf("bad header placeholders: expected '%s', got '%s'", expected, actual)
+ }
+
+ // Test shorter header value with multiple placeholders
+ if expected, actual := "Short value 1 then foobarbaz.", repl.Replace("Short value {>ShorterVal} then {>Custom}."); expected != actual {
+ t.Errorf("short value: expected '%s', got '%s'", expected, actual)
+ }
+}
+
+func TestSet(t *testing.T) {
+ w := httptest.NewRecorder()
+ recordRequest := NewResponseRecorder(w)
+ reader := strings.NewReader(`{"username": "dennis"}`)
+
+ request, err := http.NewRequest("POST", "http://localhost", reader)
+ if err != nil {
+ t.Fatalf("Request Formation Failed \n")
+ }
+ repl := NewReplacer(request, recordRequest, "")
+
+ repl.Set("host", "getcaddy.com")
+ repl.Set("method", "GET")
+ repl.Set("status", "201")
+ repl.Set("variable", "value")
+
+ if repl.Replace("This host is {host}") != "This host is getcaddy.com" {
+ t.Error("Expected host replacement failed")
+ }
+ if repl.Replace("This request method is {method}") != "This request method is GET" {
+ t.Error("Expected method replacement failed")
+ }
+ if repl.Replace("The response status is {status}") != "The response status is 201" {
+ t.Error("Expected status replacement failed")
+ }
+ if repl.Replace("The value of variable is {variable}") != "The value of variable is value" {
+ t.Error("Expected variable replacement failed")
+ }
+}
diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go
new file mode 100644
index 000000000..ddd4c38b1
--- /dev/null
+++ b/middleware/rewrite/condition.go
@@ -0,0 +1,130 @@
+package rewrite
+
+import (
+ "fmt"
+ "regexp"
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// Operators
+const (
+ Is = "is"
+ Not = "not"
+ Has = "has"
+ NotHas = "not_has"
+ StartsWith = "starts_with"
+ EndsWith = "ends_with"
+ Match = "match"
+ NotMatch = "not_match"
+)
+
+func operatorError(operator string) error {
+ return fmt.Errorf("Invalid operator %v", operator)
+}
+
+func newReplacer(r *dns.Msg) middleware.Replacer {
+ return middleware.NewReplacer(r, nil, "")
+}
+
+// condition is a rewrite condition.
+type condition func(string, string) bool
+
+var conditions = map[string]condition{
+ Is: isFunc,
+ Not: notFunc,
+ Has: hasFunc,
+ NotHas: notHasFunc,
+ StartsWith: startsWithFunc,
+ EndsWith: endsWithFunc,
+ Match: matchFunc,
+ NotMatch: notMatchFunc,
+}
+
+// isFunc is condition for Is operator.
+// It checks for equality.
+func isFunc(a, b string) bool {
+ return a == b
+}
+
+// notFunc is condition for Not operator.
+// It checks for inequality.
+func notFunc(a, b string) bool {
+ return a != b
+}
+
+// hasFunc is condition for Has operator.
+// It checks if b is a substring of a.
+func hasFunc(a, b string) bool {
+ return strings.Contains(a, b)
+}
+
+// notHasFunc is condition for NotHas operator.
+// It checks if b is not a substring of a.
+func notHasFunc(a, b string) bool {
+ return !strings.Contains(a, b)
+}
+
+// startsWithFunc is condition for StartsWith operator.
+// It checks if b is a prefix of a.
+func startsWithFunc(a, b string) bool {
+ return strings.HasPrefix(a, b)
+}
+
+// endsWithFunc is condition for EndsWith operator.
+// It checks if b is a suffix of a.
+func endsWithFunc(a, b string) bool {
+ return strings.HasSuffix(a, b)
+}
+
+// matchFunc is condition for Match operator.
+// It does regexp matching of a against pattern in b
+// and returns if they match.
+func matchFunc(a, b string) bool {
+ matched, _ := regexp.MatchString(b, a)
+ return matched
+}
+
+// notMatchFunc is condition for NotMatch operator.
+// It does regexp matching of a against pattern in b
+// and returns if they do not match.
+func notMatchFunc(a, b string) bool {
+ matched, _ := regexp.MatchString(b, a)
+ return !matched
+}
+
+// If is statement for a rewrite condition.
+type If struct {
+ A string
+ Operator string
+ B string
+}
+
+// True returns true if the condition is true and false otherwise.
+// If r is not nil, it replaces placeholders before comparison.
+func (i If) True(r *dns.Msg) bool {
+ if c, ok := conditions[i.Operator]; ok {
+ a, b := i.A, i.B
+ if r != nil {
+ replacer := newReplacer(r)
+ a = replacer.Replace(i.A)
+ b = replacer.Replace(i.B)
+ }
+ return c(a, b)
+ }
+ return false
+}
+
+// NewIf creates a new If condition.
+func NewIf(a, operator, b string) (If, error) {
+ if _, ok := conditions[operator]; !ok {
+ return If{}, operatorError(operator)
+ }
+ return If{
+ A: a,
+ Operator: operator,
+ B: b,
+ }, nil
+}
diff --git a/middleware/rewrite/condition_test.go b/middleware/rewrite/condition_test.go
new file mode 100644
index 000000000..3c3b6053a
--- /dev/null
+++ b/middleware/rewrite/condition_test.go
@@ -0,0 +1,106 @@
+package rewrite
+
+import (
+ "net/http"
+ "strings"
+ "testing"
+)
+
+func TestConditions(t *testing.T) {
+ tests := []struct {
+ condition string
+ isTrue bool
+ }{
+ {"a is b", false},
+ {"a is a", true},
+ {"a not b", true},
+ {"a not a", false},
+ {"a has a", true},
+ {"a has b", false},
+ {"ba has b", true},
+ {"bab has b", true},
+ {"bab has bb", false},
+ {"a not_has a", false},
+ {"a not_has b", true},
+ {"ba not_has b", false},
+ {"bab not_has b", false},
+ {"bab not_has bb", true},
+ {"bab starts_with bb", false},
+ {"bab starts_with ba", true},
+ {"bab starts_with bab", true},
+ {"bab ends_with bb", false},
+ {"bab ends_with bab", true},
+ {"bab ends_with ab", true},
+ {"a match *", false},
+ {"a match a", true},
+ {"a match .*", true},
+ {"a match a.*", true},
+ {"a match b.*", false},
+ {"ba match b.*", true},
+ {"ba match b[a-z]", true},
+ {"b0 match b[a-z]", false},
+ {"b0a match b[a-z]", false},
+ {"b0a match b[a-z]+", false},
+ {"b0a match b[a-z0-9]+", true},
+ {"a not_match *", true},
+ {"a not_match a", false},
+ {"a not_match .*", false},
+ {"a not_match a.*", false},
+ {"a not_match b.*", true},
+ {"ba not_match b.*", false},
+ {"ba not_match b[a-z]", false},
+ {"b0 not_match b[a-z]", true},
+ {"b0a not_match b[a-z]", true},
+ {"b0a not_match b[a-z]+", true},
+ {"b0a not_match b[a-z0-9]+", false},
+ }
+
+ for i, test := range tests {
+ str := strings.Fields(test.condition)
+ ifCond, err := NewIf(str[0], str[1], str[2])
+ if err != nil {
+ t.Error(err)
+ }
+ isTrue := ifCond.True(nil)
+ if isTrue != test.isTrue {
+ t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
+ }
+ }
+
+ invalidOperators := []string{"ss", "and", "if"}
+ for _, op := range invalidOperators {
+ _, err := NewIf("a", op, "b")
+ if err == nil {
+ t.Errorf("Invalid operator %v used, expected error.", op)
+ }
+ }
+
+ replaceTests := []struct {
+ url string
+ condition string
+ isTrue bool
+ }{
+ {"/home", "{uri} match /home", true},
+ {"/hom", "{uri} match /home", false},
+ {"/hom", "{uri} starts_with /home", false},
+ {"/hom", "{uri} starts_with /h", true},
+ {"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
+ {"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
+ }
+
+ for i, test := range replaceTests {
+ r, err := http.NewRequest("GET", test.url, nil)
+ if err != nil {
+ t.Error(err)
+ }
+ str := strings.Fields(test.condition)
+ ifCond, err := NewIf(str[0], str[1], str[2])
+ if err != nil {
+ t.Error(err)
+ }
+ isTrue := ifCond.True(r)
+ if isTrue != test.isTrue {
+ t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
+ }
+ }
+}
diff --git a/middleware/rewrite/reverter.go b/middleware/rewrite/reverter.go
new file mode 100644
index 000000000..c3425866e
--- /dev/null
+++ b/middleware/rewrite/reverter.go
@@ -0,0 +1,38 @@
+package rewrite
+
+import "github.com/miekg/dns"
+
+// ResponseRevert reverses the operations done on the question section of a packet.
+// This is need because the client will otherwise disregards the response, i.e.
+// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN'
+type ResponseReverter struct {
+ dns.ResponseWriter
+ original dns.Question
+}
+
+func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter {
+ return &ResponseReverter{
+ ResponseWriter: w,
+ original: r.Question[0],
+ }
+}
+
+// WriteMsg records the status code and calls the
+// underlying ResponseWriter's WriteMsg method.
+func (r *ResponseReverter) WriteMsg(res *dns.Msg) error {
+ res.Question[0] = r.original
+ return r.ResponseWriter.WriteMsg(res)
+}
+
+// Write is a wrapper that records the size of the message that gets written.
+func (r *ResponseReverter) Write(buf []byte) (int, error) {
+ n, err := r.ResponseWriter.Write(buf)
+ return n, err
+}
+
+// Hijack implements dns.Hijacker. It simply wraps the underlying
+// ResponseWriter's Hijack method if there is one, or returns an error.
+func (r *ResponseReverter) Hijack() {
+ r.ResponseWriter.Hijack()
+ return
+}
diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go
new file mode 100644
index 000000000..b3039615b
--- /dev/null
+++ b/middleware/rewrite/rewrite.go
@@ -0,0 +1,223 @@
+// Package rewrite is middleware for rewriting requests internally to
+// something different.
+package rewrite
+
+import (
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// Result is the result of a rewrite
+type Result int
+
+const (
+ // RewriteIgnored is returned when rewrite is not done on request.
+ RewriteIgnored Result = iota
+ // RewriteDone is returned when rewrite is done on request.
+ RewriteDone
+ // RewriteStatus is returned when rewrite is not needed and status code should be set
+ // for the request.
+ RewriteStatus
+)
+
+// Rewrite is middleware to rewrite requests internally before being handled.
+type Rewrite struct {
+ Next middleware.Handler
+ Rules []Rule
+}
+
+// ServeHTTP implements the middleware.Handler interface.
+func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ wr := NewResponseReverter(w, r)
+ for _, rule := range rw.Rules {
+ switch result := rule.Rewrite(r); result {
+ case RewriteDone:
+ return rw.Next.ServeDNS(wr, r)
+ case RewriteIgnored:
+ break
+ case RewriteStatus:
+ // only valid for complex rules.
+ // if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 {
+ // return cRule.Status, nil
+ // }
+ }
+ }
+ return rw.Next.ServeDNS(w, r)
+}
+
+// Rule describes an internal location rewrite rule.
+type Rule interface {
+ // Rewrite rewrites the internal location of the current request.
+ Rewrite(*dns.Msg) Result
+}
+
+// SimpleRule is a simple rewrite rule. If the From and To look like a type
+// the type of the request is rewritten, otherwise the name is.
+// Note: TSIG signed requests will be invalid.
+type SimpleRule struct {
+ From, To string
+ fromType, toType uint16
+}
+
+// NewSimpleRule creates a new Simple Rule
+func NewSimpleRule(from, to string) SimpleRule {
+ tpf := dns.StringToType[from]
+ tpt := dns.StringToType[to]
+
+ return SimpleRule{From: from, To: to, fromType: tpf, toType: tpt}
+}
+
+// Rewrite rewrites the the current request.
+func (s SimpleRule) Rewrite(r *dns.Msg) Result {
+ if s.fromType > 0 && s.toType > 0 {
+ if r.Question[0].Qtype == s.fromType {
+ r.Question[0].Qtype = s.toType
+ return RewriteDone
+ }
+
+ }
+
+ // if the question name matches the full name, or subset rewrite that
+ // s.Question[0].Name
+ return RewriteIgnored
+}
+
+/*
+// ComplexRule is a rewrite rule based on a regular expression
+type ComplexRule struct {
+ // Path base. Request to this path and subpaths will be rewritten
+ Base string
+
+ // Path to rewrite to
+ To string
+
+ // If set, neither performs rewrite nor proceeds
+ // with request. Only returns code.
+ Status int
+
+ // Extensions to filter by
+ Exts []string
+
+ // Rewrite conditions
+ Ifs []If
+
+ *regexp.Regexp
+}
+
+// NewComplexRule creates a new RegexpRule. It returns an error if regexp
+// pattern (pattern) or extensions (ext) are invalid.
+func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
+ // validate regexp if present
+ var r *regexp.Regexp
+ if pattern != "" {
+ var err error
+ r, err = regexp.Compile(pattern)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // validate extensions if present
+ for _, v := range ext {
+ if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
+ // check if no extension is specified
+ if v != "/" && v != "!/" {
+ return nil, fmt.Errorf("invalid extension %v", v)
+ }
+ }
+ }
+
+ return &ComplexRule{
+ Base: base,
+ To: to,
+ Status: status,
+ Exts: ext,
+ Ifs: ifs,
+ Regexp: r,
+ }, nil
+}
+
+// Rewrite rewrites the internal location of the current request.
+func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) {
+ rPath := req.URL.Path
+ replacer := newReplacer(req)
+
+ // validate base
+ if !middleware.Path(rPath).Matches(r.Base) {
+ return
+ }
+
+ // validate extensions
+ if !r.matchExt(rPath) {
+ return
+ }
+
+ // validate regexp if present
+ if r.Regexp != nil {
+ // include trailing slash in regexp if present
+ start := len(r.Base)
+ if strings.HasSuffix(r.Base, "/") {
+ start--
+ }
+
+ matches := r.FindStringSubmatch(rPath[start:])
+ switch len(matches) {
+ case 0:
+ // no match
+ return
+ default:
+ // set regexp match variables {1}, {2} ...
+ for i := 1; i < len(matches); i++ {
+ replacer.Set(fmt.Sprint(i), matches[i])
+ }
+ }
+ }
+
+ // validate rewrite conditions
+ for _, i := range r.Ifs {
+ if !i.True(req) {
+ return
+ }
+ }
+
+ // if status is present, stop rewrite and return it.
+ if r.Status != 0 {
+ return RewriteStatus
+ }
+
+ // attempt rewrite
+ return To(fs, req, r.To, replacer)
+}
+
+// matchExt matches rPath against registered file extensions.
+// Returns true if a match is found and false otherwise.
+func (r *ComplexRule) matchExt(rPath string) bool {
+ f := filepath.Base(rPath)
+ ext := path.Ext(f)
+ if ext == "" {
+ ext = "/"
+ }
+
+ mustUse := false
+ for _, v := range r.Exts {
+ use := true
+ if v[0] == '!' {
+ use = false
+ v = v[1:]
+ }
+
+ if use {
+ mustUse = true
+ }
+
+ if ext == v {
+ return use
+ }
+ }
+
+ if mustUse {
+ return false
+ }
+ return true
+}
+*/
diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go
new file mode 100644
index 000000000..f57dfd602
--- /dev/null
+++ b/middleware/rewrite/rewrite_test.go
@@ -0,0 +1,159 @@
+package rewrite
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+)
+
+func TestRewrite(t *testing.T) {
+ rw := Rewrite{
+ Next: middleware.HandlerFunc(urlPrinter),
+ Rules: []Rule{
+ NewSimpleRule("/from", "/to"),
+ NewSimpleRule("/a", "/b"),
+ NewSimpleRule("/b", "/b{uri}"),
+ },
+ FileSys: http.Dir("."),
+ }
+
+ regexps := [][]string{
+ {"/reg/", ".*", "/to", ""},
+ {"/r/", "[a-z]+", "/toaz", "!.html|"},
+ {"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
+ {"/ab/", "ab", "/ab?{query}", ".txt|"},
+ {"/ab/", "ab", "/ab?type=html&{query}", ".html|"},
+ {"/abc/", "ab", "/abc/{file}", ".html|"},
+ {"/abcd/", "ab", "/a/{dir}/{file}", ".html|"},
+ {"/abcde/", "ab", "/a#{fragment}", ".html|"},
+ {"/ab/", `.*\.jpg`, "/ajpg", ""},
+ {"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""},
+ {"/reg2grp", `(.*)`, "/{1}", ""},
+ {"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""},
+ }
+
+ for _, regexpRule := range regexps {
+ var ext []string
+ if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
+ ext = s[:len(s)-1]
+ }
+ rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rw.Rules = append(rw.Rules, rule)
+ }
+
+ tests := []struct {
+ from string
+ expectedTo string
+ }{
+ {"/from", "/to"},
+ {"/a", "/b"},
+ {"/b", "/b/b"},
+ {"/aa", "/aa"},
+ {"/", "/"},
+ {"/a?foo=bar", "/b?foo=bar"},
+ {"/asdf?foo=bar", "/asdf?foo=bar"},
+ {"/foo#bar", "/foo#bar"},
+ {"/a#foo", "/b#foo"},
+ {"/reg/foo", "/to"},
+ {"/re", "/re"},
+ {"/r/", "/r/"},
+ {"/r/123", "/r/123"},
+ {"/r/a123", "/toaz"},
+ {"/r/abcz", "/toaz"},
+ {"/r/z", "/toaz"},
+ {"/r/z.html", "/r/z.html"},
+ {"/r/z.js", "/toaz"},
+ {"/url/asAB", "/to/url/asAB"},
+ {"/url/aBsAB", "/url/aBsAB"},
+ {"/url/a00sAB", "/to/url/a00sAB"},
+ {"/url/a0z0sAB", "/to/url/a0z0sAB"},
+ {"/ab/aa", "/ab/aa"},
+ {"/ab/ab", "/ab/ab"},
+ {"/ab/ab.txt", "/ab"},
+ {"/ab/ab.txt?name=name", "/ab?name=name"},
+ {"/ab/ab.html?name=name", "/ab?type=html&name=name"},
+ {"/abc/ab.html", "/abc/ab.html"},
+ {"/abcd/abcd.html", "/a/abcd/abcd.html"},
+ {"/abcde/abcde.html", "/a"},
+ {"/abcde/abcde.html#1234", "/a#1234"},
+ {"/ab/ab.jpg", "/ajpg"},
+ {"/reggrp/ad/12", "/a12"},
+ {"/reggrp/ad/124a", "/a124/a"},
+ {"/reggrp/ad/124abc", "/a124/abc"},
+ {"/reg2grp/ad/124abc", "/ad/124abc"},
+ {"/reg3grp/ad/aa/66", "/adaa66"},
+ {"/reg3grp/ad612/n1n/ab", "/ad612n1nab"},
+ }
+
+ for i, test := range tests {
+ req, err := http.NewRequest("GET", test.from, nil)
+ if err != nil {
+ t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
+ }
+
+ rec := httptest.NewRecorder()
+ rw.ServeHTTP(rec, req)
+
+ if rec.Body.String() != test.expectedTo {
+ t.Errorf("Test %d: Expected URL to be '%s' but was '%s'",
+ i, test.expectedTo, rec.Body.String())
+ }
+ }
+
+ statusTests := []struct {
+ status int
+ base string
+ to string
+ regexp string
+ statusExpected bool
+ }{
+ {400, "/status", "", "", true},
+ {400, "/ignore", "", "", false},
+ {400, "/", "", "^/ignore", false},
+ {400, "/", "", "(.*)", true},
+ {400, "/status", "", "", true},
+ }
+
+ for i, s := range statusTests {
+ urlPath := fmt.Sprintf("/status%d", i)
+ rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil)
+ if err != nil {
+ t.Fatalf("Test %d: No error expected for rule but found %v", i, err)
+ }
+ rw.Rules = []Rule{rule}
+ req, err := http.NewRequest("GET", urlPath, nil)
+ if err != nil {
+ t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
+ }
+
+ rec := httptest.NewRecorder()
+ code, err := rw.ServeHTTP(rec, req)
+ if err != nil {
+ t.Fatalf("Test %d: No error expected for handler but found %v", i, err)
+ }
+ if s.statusExpected {
+ if rec.Body.String() != "" {
+ t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String())
+ }
+ if code != s.status {
+ t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code)
+ }
+ } else {
+ if code != 0 {
+ t.Errorf("Test %d: Expected no status code found %d", i, code)
+ }
+ }
+ }
+}
+
+func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
+ fmt.Fprintf(w, r.URL.String())
+ return 0, nil
+}
diff --git a/middleware/rewrite/testdata/testdir/empty b/middleware/rewrite/testdata/testdir/empty
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/middleware/rewrite/testdata/testdir/empty
diff --git a/middleware/rewrite/testdata/testfile b/middleware/rewrite/testdata/testfile
new file mode 100644
index 000000000..7b4d68d70
--- /dev/null
+++ b/middleware/rewrite/testdata/testfile
@@ -0,0 +1 @@
+empty \ No newline at end of file
diff --git a/middleware/roller.go b/middleware/roller.go
new file mode 100644
index 000000000..995cabf91
--- /dev/null
+++ b/middleware/roller.go
@@ -0,0 +1,27 @@
+package middleware
+
+import (
+ "io"
+
+ "gopkg.in/natefinch/lumberjack.v2"
+)
+
+// LogRoller implements a middleware that provides a rolling logger.
+type LogRoller struct {
+ Filename string
+ MaxSize int
+ MaxAge int
+ MaxBackups int
+ LocalTime bool
+}
+
+// GetLogWriter returns an io.Writer that writes to a rolling logger.
+func (l LogRoller) GetLogWriter() io.Writer {
+ return &lumberjack.Logger{
+ Filename: l.Filename,
+ MaxSize: l.MaxSize,
+ MaxAge: l.MaxAge,
+ MaxBackups: l.MaxBackups,
+ LocalTime: l.LocalTime,
+ }
+}
diff --git a/middleware/zone.go b/middleware/zone.go
new file mode 100644
index 000000000..6798bca8e
--- /dev/null
+++ b/middleware/zone.go
@@ -0,0 +1,21 @@
+package middleware
+
+import "strings"
+
+type Zones []string
+
+// Matches checks to see if other matches p.
+// The match will return the most specific zones
+// that matches other. The empty string signals a not found
+// condition.
+func (z Zones) Matches(qname string) string {
+ zone := ""
+ for _, zname := range z {
+ if strings.HasSuffix(qname, zname) {
+ if len(zname) > len(zone) {
+ zone = zname
+ }
+ }
+ }
+ return zone
+}