aboutsummaryrefslogtreecommitdiff
path: root/middleware/errors
diff options
context:
space:
mode:
authorGravatar Miek Gieben <miek@miek.nl> 2016-03-18 20:57:35 +0000
committerGravatar Miek Gieben <miek@miek.nl> 2016-03-18 20:57:35 +0000
commit3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d (patch)
treefae74c33cfed05de603785294593275f1901c861 /middleware/errors
downloadcoredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.gz
coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.zst
coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.zip
First commit
Diffstat (limited to 'middleware/errors')
-rw-r--r--middleware/errors/errors.go100
-rw-r--r--middleware/errors/errors_test.go168
2 files changed, 268 insertions, 0 deletions
diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go
new file mode 100644
index 000000000..bf5bc7aae
--- /dev/null
+++ b/middleware/errors/errors.go
@@ -0,0 +1,100 @@
+// Package errors implements an HTTP error handling middleware.
+package errors
+
+import (
+ "fmt"
+ "log"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// ErrorHandler handles DNS errors (and errors from other middleware).
+type ErrorHandler struct {
+ Next middleware.Handler
+ LogFile string
+ Log *log.Logger
+ LogRoller *middleware.LogRoller
+ Debug bool // if true, errors are written out to client rather than to a log
+}
+
+func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ defer h.recovery(w, r)
+
+ rcode, err := h.Next.ServeDNS(w, r)
+
+ if err != nil {
+ errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
+
+ if h.Debug {
+ // Write error to response as a txt message instead of to log
+ answer := debugMsg(rcode, r)
+ txt, _ := dns.NewRR(". IN 0 TXT " + errMsg)
+ answer.Answer = append(answer.Answer, txt)
+ w.WriteMsg(answer)
+ return 0, err
+ }
+ h.Log.Println(errMsg)
+ }
+
+ return rcode, err
+}
+
+func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) {
+ rec := recover()
+ if rec == nil {
+ return
+ }
+
+ // Obtain source of panic
+ // From: https://gist.github.com/swdunlop/9629168
+ var name, file string // function name, file name
+ var line int
+ var pc [16]uintptr
+ n := runtime.Callers(3, pc[:])
+ for _, pc := range pc[:n] {
+ fn := runtime.FuncForPC(pc)
+ if fn == nil {
+ continue
+ }
+ file, line = fn.FileLine(pc)
+ name = fn.Name()
+ if !strings.HasPrefix(name, "runtime.") {
+ break
+ }
+ }
+
+ // Trim file path
+ delim := "/coredns/"
+ pkgPathPos := strings.Index(file, delim)
+ if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) {
+ file = file[pkgPathPos+len(delim):]
+ }
+
+ panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec)
+ if h.Debug {
+ // Write error and stack trace to the response rather than to a log
+ var stackBuf [4096]byte
+ stack := stackBuf[:runtime.Stack(stackBuf[:], false)]
+ answer := debugMsg(dns.RcodeServerFailure, r)
+ // add stack buf in TXT, limited to 255 chars for now.
+ txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255]))
+ answer.Answer = append(answer.Answer, txt)
+ w.WriteMsg(answer)
+ } else {
+ // Currently we don't use the function name, since file:line is more conventional
+ h.Log.Printf(panicMsg)
+ }
+}
+
+// debugMsg creates a debug message that gets send back to the client.
+func debugMsg(rcode int, r *dns.Msg) *dns.Msg {
+ answer := new(dns.Msg)
+ answer.SetRcode(r, rcode)
+ return answer
+}
+
+const timeFormat = "02/Jan/2006:15:04:05 -0700"
diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go
new file mode 100644
index 000000000..4434e835c
--- /dev/null
+++ b/middleware/errors/errors_test.go
@@ -0,0 +1,168 @@
+package errors
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+)
+
+func TestErrors(t *testing.T) {
+ // create a temporary page
+ path := filepath.Join(os.TempDir(), "errors_test.html")
+ f, err := os.Create(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(path)
+
+ const content = "This is a error page"
+ _, err = f.WriteString(content)
+ if err != nil {
+ t.Fatal(err)
+ }
+ f.Close()
+
+ buf := bytes.Buffer{}
+ em := ErrorHandler{
+ ErrorPages: map[int]string{
+ http.StatusNotFound: path,
+ http.StatusForbidden: "not_exist_file",
+ },
+ Log: log.New(&buf, "", 0),
+ }
+ _, notExistErr := os.Open("not_exist_file")
+
+ testErr := errors.New("test error")
+ tests := []struct {
+ next middleware.Handler
+ expectedCode int
+ expectedBody string
+ expectedLog string
+ expectedErr error
+ }{
+ {
+ next: genErrorHandler(http.StatusOK, nil, "normal"),
+ expectedCode: http.StatusOK,
+ expectedBody: "normal",
+ expectedLog: "",
+ expectedErr: nil,
+ },
+ {
+ next: genErrorHandler(http.StatusMovedPermanently, testErr, ""),
+ expectedCode: http.StatusMovedPermanently,
+ expectedBody: "",
+ expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr),
+ expectedErr: testErr,
+ },
+ {
+ next: genErrorHandler(http.StatusBadRequest, nil, ""),
+ expectedCode: 0,
+ expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest,
+ http.StatusText(http.StatusBadRequest)),
+ expectedLog: "",
+ expectedErr: nil,
+ },
+ {
+ next: genErrorHandler(http.StatusNotFound, nil, ""),
+ expectedCode: 0,
+ expectedBody: content,
+ expectedLog: "",
+ expectedErr: nil,
+ },
+ {
+ next: genErrorHandler(http.StatusForbidden, nil, ""),
+ expectedCode: 0,
+ expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden,
+ http.StatusText(http.StatusForbidden)),
+ expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n",
+ http.StatusForbidden, notExistErr),
+ expectedErr: nil,
+ },
+ }
+
+ req, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for i, test := range tests {
+ em.Next = test.next
+ buf.Reset()
+ rec := httptest.NewRecorder()
+ code, err := em.ServeHTTP(rec, req)
+
+ if err != test.expectedErr {
+ t.Errorf("Test %d: Expected error %v, but got %v",
+ i, test.expectedErr, err)
+ }
+ if code != test.expectedCode {
+ t.Errorf("Test %d: Expected status code %d, but got %d",
+ i, test.expectedCode, code)
+ }
+ if body := rec.Body.String(); body != test.expectedBody {
+ t.Errorf("Test %d: Expected body %q, but got %q",
+ i, test.expectedBody, body)
+ }
+ if log := buf.String(); !strings.Contains(log, test.expectedLog) {
+ t.Errorf("Test %d: Expected log %q, but got %q",
+ i, test.expectedLog, log)
+ }
+ }
+}
+
+func TestVisibleErrorWithPanic(t *testing.T) {
+ const panicMsg = "I'm a panic"
+ eh := ErrorHandler{
+ ErrorPages: make(map[int]string),
+ Debug: true,
+ Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
+ panic(panicMsg)
+ }),
+ }
+
+ req, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rec := httptest.NewRecorder()
+
+ code, err := eh.ServeHTTP(rec, req)
+
+ if code != 0 {
+ t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code)
+ }
+ if err != nil {
+ t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err)
+ }
+
+ body := rec.Body.String()
+
+ if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") {
+ t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body)
+ }
+ if !strings.Contains(body, panicMsg) {
+ t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body)
+ }
+ if len(body) < 500 {
+ t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body))
+ }
+}
+
+func genErrorHandler(status int, err error, body string) middleware.Handler {
+ return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
+ if len(body) > 0 {
+ w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+ fmt.Fprint(w, body)
+ }
+ return status, err
+ })
+}