diff options
author | 2016-03-18 20:57:35 +0000 | |
---|---|---|
committer | 2016-03-18 20:57:35 +0000 | |
commit | 3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d (patch) | |
tree | fae74c33cfed05de603785294593275f1901c861 /middleware/errors | |
download | coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.gz coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.zst coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.zip |
First commit
Diffstat (limited to 'middleware/errors')
-rw-r--r-- | middleware/errors/errors.go | 100 | ||||
-rw-r--r-- | middleware/errors/errors_test.go | 168 |
2 files changed, 268 insertions, 0 deletions
diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go new file mode 100644 index 000000000..bf5bc7aae --- /dev/null +++ b/middleware/errors/errors.go @@ -0,0 +1,100 @@ +// Package errors implements an HTTP error handling middleware. +package errors + +import ( + "fmt" + "log" + "runtime" + "strings" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// ErrorHandler handles DNS errors (and errors from other middleware). +type ErrorHandler struct { + Next middleware.Handler + LogFile string + Log *log.Logger + LogRoller *middleware.LogRoller + Debug bool // if true, errors are written out to client rather than to a log +} + +func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + defer h.recovery(w, r) + + rcode, err := h.Next.ServeDNS(w, r) + + if err != nil { + errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err) + + if h.Debug { + // Write error to response as a txt message instead of to log + answer := debugMsg(rcode, r) + txt, _ := dns.NewRR(". IN 0 TXT " + errMsg) + answer.Answer = append(answer.Answer, txt) + w.WriteMsg(answer) + return 0, err + } + h.Log.Println(errMsg) + } + + return rcode, err +} + +func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) { + rec := recover() + if rec == nil { + return + } + + // Obtain source of panic + // From: https://gist.github.com/swdunlop/9629168 + var name, file string // function name, file name + var line int + var pc [16]uintptr + n := runtime.Callers(3, pc[:]) + for _, pc := range pc[:n] { + fn := runtime.FuncForPC(pc) + if fn == nil { + continue + } + file, line = fn.FileLine(pc) + name = fn.Name() + if !strings.HasPrefix(name, "runtime.") { + break + } + } + + // Trim file path + delim := "/coredns/" + pkgPathPos := strings.Index(file, delim) + if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) { + file = file[pkgPathPos+len(delim):] + } + + panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec) + if h.Debug { + // Write error and stack trace to the response rather than to a log + var stackBuf [4096]byte + stack := stackBuf[:runtime.Stack(stackBuf[:], false)] + answer := debugMsg(dns.RcodeServerFailure, r) + // add stack buf in TXT, limited to 255 chars for now. + txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255])) + answer.Answer = append(answer.Answer, txt) + w.WriteMsg(answer) + } else { + // Currently we don't use the function name, since file:line is more conventional + h.Log.Printf(panicMsg) + } +} + +// debugMsg creates a debug message that gets send back to the client. +func debugMsg(rcode int, r *dns.Msg) *dns.Msg { + answer := new(dns.Msg) + answer.SetRcode(r, rcode) + return answer +} + +const timeFormat = "02/Jan/2006:15:04:05 -0700" diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go new file mode 100644 index 000000000..4434e835c --- /dev/null +++ b/middleware/errors/errors_test.go @@ -0,0 +1,168 @@ +package errors + +import ( + "bytes" + "errors" + "fmt" + "log" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/miekg/coredns/middleware" +) + +func TestErrors(t *testing.T) { + // create a temporary page + path := filepath.Join(os.TempDir(), "errors_test.html") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + defer os.Remove(path) + + const content = "This is a error page" + _, err = f.WriteString(content) + if err != nil { + t.Fatal(err) + } + f.Close() + + buf := bytes.Buffer{} + em := ErrorHandler{ + ErrorPages: map[int]string{ + http.StatusNotFound: path, + http.StatusForbidden: "not_exist_file", + }, + Log: log.New(&buf, "", 0), + } + _, notExistErr := os.Open("not_exist_file") + + testErr := errors.New("test error") + tests := []struct { + next middleware.Handler + expectedCode int + expectedBody string + expectedLog string + expectedErr error + }{ + { + next: genErrorHandler(http.StatusOK, nil, "normal"), + expectedCode: http.StatusOK, + expectedBody: "normal", + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusMovedPermanently, testErr, ""), + expectedCode: http.StatusMovedPermanently, + expectedBody: "", + expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr), + expectedErr: testErr, + }, + { + next: genErrorHandler(http.StatusBadRequest, nil, ""), + expectedCode: 0, + expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest, + http.StatusText(http.StatusBadRequest)), + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusNotFound, nil, ""), + expectedCode: 0, + expectedBody: content, + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusForbidden, nil, ""), + expectedCode: 0, + expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden, + http.StatusText(http.StatusForbidden)), + expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n", + http.StatusForbidden, notExistErr), + expectedErr: nil, + }, + } + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + for i, test := range tests { + em.Next = test.next + buf.Reset() + rec := httptest.NewRecorder() + code, err := em.ServeHTTP(rec, req) + + if err != test.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", + i, test.expectedErr, err) + } + if code != test.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", + i, test.expectedCode, code) + } + if body := rec.Body.String(); body != test.expectedBody { + t.Errorf("Test %d: Expected body %q, but got %q", + i, test.expectedBody, body) + } + if log := buf.String(); !strings.Contains(log, test.expectedLog) { + t.Errorf("Test %d: Expected log %q, but got %q", + i, test.expectedLog, log) + } + } +} + +func TestVisibleErrorWithPanic(t *testing.T) { + const panicMsg = "I'm a panic" + eh := ErrorHandler{ + ErrorPages: make(map[int]string), + Debug: true, + Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + panic(panicMsg) + }), + } + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + rec := httptest.NewRecorder() + + code, err := eh.ServeHTTP(rec, req) + + if code != 0 { + t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code) + } + if err != nil { + t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err) + } + + body := rec.Body.String() + + if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") { + t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body) + } + if !strings.Contains(body, panicMsg) { + t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body) + } + if len(body) < 500 { + t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body)) + } +} + +func genErrorHandler(status int, err error, body string) middleware.Handler { + return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + if len(body) > 0 { + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + fmt.Fprint(w, body) + } + return status, err + }) +} |