aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--go.mod1
-rw-r--r--go.sum2
-rw-r--r--handler.go41
-rw-r--r--handler_test.go46
4 files changed, 75 insertions, 15 deletions
diff --git a/go.mod b/go.mod
index 58f7710..f118570 100644
--- a/go.mod
+++ b/go.mod
@@ -3,7 +3,6 @@ module go.uber.org/sally
go 1.18
require (
- github.com/julienschmidt/httprouter v1.3.0
github.com/stretchr/testify v1.8.1
golang.org/x/net v0.5.0
gopkg.in/yaml.v2 v2.4.0
diff --git a/go.sum b/go.sum
index 8261753..e8b209a 100644
--- a/go.sum
+++ b/go.sum
@@ -1,8 +1,6 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
-github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
diff --git a/handler.go b/handler.go
index 0490a83..bbefd30 100644
--- a/handler.go
+++ b/handler.go
@@ -4,8 +4,8 @@ import (
"fmt"
"html/template"
"net/http"
+ "strings"
- "github.com/julienschmidt/httprouter"
"go.uber.org/sally/templates"
)
@@ -18,29 +18,36 @@ var (
// CreateHandler creates a Sally http.Handler
func CreateHandler(config *Config) http.Handler {
- router := httprouter.New()
- router.RedirectTrailingSlash = false
-
- router.GET("/", indexHandler{config: config}.Handle)
+ mux := http.NewServeMux()
+ mux.Handle("/", &indexHandler{config: config})
for name, pkg := range config.Packages {
handle := packageHandler{
pkgName: name,
pkg: pkg,
config: config,
- }.Handle
- router.GET(fmt.Sprintf("/%s", name), handle)
- router.GET(fmt.Sprintf("/%s/*path", name), handle)
+ }
+ // Double-register so that "/foo"
+ // does not redirect to "/foo/" with a 300.
+ mux.Handle("/"+name, &handle)
+ mux.Handle("/"+name+"/", &handle)
}
- return router
+ return mux
}
type indexHandler struct {
config *Config
}
-func (h indexHandler) Handle(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // Index handler only supports '/'.
+ // ServeMux will call us for any '/foo' that is not a known package.
+ if r.Method != http.MethodGet || r.URL.Path != "/" {
+ http.NotFound(w, r)
+ return
+ }
+
if err := indexTemplate.Execute(w, h.config); err != nil {
http.Error(w, err.Error(), 500)
}
@@ -52,7 +59,17 @@ type packageHandler struct {
config *Config
}
-func (h packageHandler) Handle(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *packageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.NotFound(w, r)
+ return
+ }
+
+ // Extract the relative path to subpackages, if any.
+ // "/foo/bar" => "/bar"
+ // "/foo" => ""
+ relPath := strings.TrimPrefix(r.URL.Path, "/"+h.pkgName)
+
baseURL := h.config.URL
if h.pkg.URL != "" {
baseURL = h.pkg.URL
@@ -67,7 +84,7 @@ func (h packageHandler) Handle(w http.ResponseWriter, r *http.Request, ps httpro
Repo: h.pkg.Repo,
Branch: h.pkg.Branch,
CanonicalURL: canonicalURL,
- GodocURL: fmt.Sprintf("https://%s/%s%s", h.config.Godoc.Host, canonicalURL, ps.ByName("path")),
+ GodocURL: fmt.Sprintf("https://%s/%s%s", h.config.Godoc.Host, canonicalURL, relPath),
}
if err := packageTemplate.Execute(w, data); err != nil {
http.Error(w, err.Error(), 500)
diff --git a/handler_test.go b/handler_test.go
index e300d19..b1d5d2e 100644
--- a/handler_test.go
+++ b/handler_test.go
@@ -1,9 +1,14 @@
package main
import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
var config = `
@@ -120,3 +125,44 @@ func TestPackageLevelURL(t *testing.T) {
</html>
`)
}
+
+func TestPostRejected(t *testing.T) {
+ t.Parallel()
+
+ h := CreateHandler(&Config{
+ URL: "go.uberalt.org",
+ Packages: map[string]Package{
+ "zap": {
+ Repo: "github.com/uber-go/zap",
+ },
+ },
+ })
+ srv := httptest.NewServer(h)
+ t.Cleanup(srv.Close)
+
+ tests := []struct {
+ desc string
+ path string
+ }{
+ {desc: "index", path: "/"},
+ {desc: "package", path: "/zap"},
+ {desc: "subpackage", path: "/zap/zapcore"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.desc, func(t *testing.T) {
+ t.Parallel()
+
+ res, err := http.Post(srv.URL+tt.path, "text/plain", strings.NewReader("foo"))
+ require.NoError(t, err)
+ defer res.Body.Close()
+
+ body, err := io.ReadAll(res.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, http.StatusNotFound, res.StatusCode,
+ "expected 404, got:\n%s", string(body))
+ })
+ }
+}