diff options
-rw-r--r-- | go.mod | 1 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | handler.go | 41 | ||||
-rw-r--r-- | handler_test.go | 46 |
4 files changed, 75 insertions, 15 deletions
@@ -3,7 +3,6 @@ module go.uber.org/sally go 1.18 require ( - github.com/julienschmidt/httprouter v1.3.0 github.com/stretchr/testify v1.8.1 golang.org/x/net v0.5.0 gopkg.in/yaml.v2 v2.4.0 @@ -1,8 +1,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= -github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -4,8 +4,8 @@ import ( "fmt" "html/template" "net/http" + "strings" - "github.com/julienschmidt/httprouter" "go.uber.org/sally/templates" ) @@ -18,29 +18,36 @@ var ( // CreateHandler creates a Sally http.Handler func CreateHandler(config *Config) http.Handler { - router := httprouter.New() - router.RedirectTrailingSlash = false - - router.GET("/", indexHandler{config: config}.Handle) + mux := http.NewServeMux() + mux.Handle("/", &indexHandler{config: config}) for name, pkg := range config.Packages { handle := packageHandler{ pkgName: name, pkg: pkg, config: config, - }.Handle - router.GET(fmt.Sprintf("/%s", name), handle) - router.GET(fmt.Sprintf("/%s/*path", name), handle) + } + // Double-register so that "/foo" + // does not redirect to "/foo/" with a 300. + mux.Handle("/"+name, &handle) + mux.Handle("/"+name+"/", &handle) } - return router + return mux } type indexHandler struct { config *Config } -func (h indexHandler) Handle(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Index handler only supports '/'. + // ServeMux will call us for any '/foo' that is not a known package. + if r.Method != http.MethodGet || r.URL.Path != "/" { + http.NotFound(w, r) + return + } + if err := indexTemplate.Execute(w, h.config); err != nil { http.Error(w, err.Error(), 500) } @@ -52,7 +59,17 @@ type packageHandler struct { config *Config } -func (h packageHandler) Handle(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *packageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + + // Extract the relative path to subpackages, if any. + // "/foo/bar" => "/bar" + // "/foo" => "" + relPath := strings.TrimPrefix(r.URL.Path, "/"+h.pkgName) + baseURL := h.config.URL if h.pkg.URL != "" { baseURL = h.pkg.URL @@ -67,7 +84,7 @@ func (h packageHandler) Handle(w http.ResponseWriter, r *http.Request, ps httpro Repo: h.pkg.Repo, Branch: h.pkg.Branch, CanonicalURL: canonicalURL, - GodocURL: fmt.Sprintf("https://%s/%s%s", h.config.Godoc.Host, canonicalURL, ps.ByName("path")), + GodocURL: fmt.Sprintf("https://%s/%s%s", h.config.Godoc.Host, canonicalURL, relPath), } if err := packageTemplate.Execute(w, data); err != nil { http.Error(w, err.Error(), 500) diff --git a/handler_test.go b/handler_test.go index e300d19..b1d5d2e 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,9 +1,14 @@ package main import ( + "io" + "net/http" + "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var config = ` @@ -120,3 +125,44 @@ func TestPackageLevelURL(t *testing.T) { </html> `) } + +func TestPostRejected(t *testing.T) { + t.Parallel() + + h := CreateHandler(&Config{ + URL: "go.uberalt.org", + Packages: map[string]Package{ + "zap": { + Repo: "github.com/uber-go/zap", + }, + }, + }) + srv := httptest.NewServer(h) + t.Cleanup(srv.Close) + + tests := []struct { + desc string + path string + }{ + {desc: "index", path: "/"}, + {desc: "package", path: "/zap"}, + {desc: "subpackage", path: "/zap/zapcore"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + res, err := http.Post(srv.URL+tt.path, "text/plain", strings.NewReader("foo")) + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusNotFound, res.StatusCode, + "expected 404, got:\n%s", string(body)) + }) + } +} |