diff options
Diffstat (limited to 'server/middleware/session.go')
-rw-r--r-- | server/middleware/session.go | 68 |
1 files changed, 37 insertions, 31 deletions
diff --git a/server/middleware/session.go b/server/middleware/session.go index 3759565f..2891b68a 100644 --- a/server/middleware/session.go +++ b/server/middleware/session.go @@ -10,60 +10,66 @@ import ( "github.com/miniflux/miniflux/logger" "github.com/miniflux/miniflux/model" - "github.com/miniflux/miniflux/server/route" + "github.com/miniflux/miniflux/server/cookie" "github.com/miniflux/miniflux/storage" - - "github.com/gorilla/mux" ) // SessionMiddleware represents a session middleware. type SessionMiddleware struct { - store *storage.Storage - router *mux.Router + store *storage.Storage } // Handler execute the middleware. -func (s *SessionMiddleware) Handler(next http.Handler) http.Handler { +func (t *SessionMiddleware) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := s.getSessionFromCookie(r) + var err error + session := t.getSessionValueFromCookie(r) if session == nil { logger.Debug("[Middleware:Session] Session not found") - if s.isPublicRoute(r) { - next.ServeHTTP(w, r) - } else { - http.Redirect(w, r, route.Path(s.router, "login"), http.StatusFound) + session, err = t.store.CreateSession() + if err != nil { + logger.Error("[Middleware:Session] %v", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } + + http.SetCookie(w, cookie.New(cookie.CookieSessionID, session.ID, r.URL.Scheme == "https")) } else { logger.Debug("[Middleware:Session] %s", session) - ctx := r.Context() - ctx = context.WithValue(ctx, UserIDContextKey, session.UserID) - ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true) + } - next.ServeHTTP(w, r.WithContext(ctx)) + if r.Method == "POST" { + formValue := r.FormValue("csrf") + headerValue := r.Header.Get("X-Csrf-Token") + + if session.Data.CSRF != formValue && session.Data.CSRF != headerValue { + logger.Error(`[Middleware:Session] Invalid or missing CSRF token: Form="%s", Header="%s"`, formValue, headerValue) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Invalid or missing CSRF session!")) + return + } } - }) -} -func (s *SessionMiddleware) isPublicRoute(r *http.Request) bool { - route := mux.CurrentRoute(r) - switch route.GetName() { - case "login", "checkLogin", "stylesheet", "javascript", "oauth2Redirect", "oauth2Callback", "appIcon", "favicon": - return true - default: - return false - } + ctx := r.Context() + ctx = context.WithValue(ctx, SessionIDContextKey, session.ID) + ctx = context.WithValue(ctx, CSRFContextKey, session.Data.CSRF) + ctx = context.WithValue(ctx, OAuth2StateContextKey, session.Data.OAuth2State) + ctx = context.WithValue(ctx, FlashMessageContextKey, session.Data.FlashMessage) + ctx = context.WithValue(ctx, FlashErrorMessageContextKey, session.Data.FlashErrorMessage) + next.ServeHTTP(w, r.WithContext(ctx)) + }) } -func (s *SessionMiddleware) getSessionFromCookie(r *http.Request) *model.UserSession { - sessionCookie, err := r.Cookie("sessionID") +func (t *SessionMiddleware) getSessionValueFromCookie(r *http.Request) *model.Session { + sessionCookie, err := r.Cookie(cookie.CookieSessionID) if err == http.ErrNoCookie { return nil } - session, err := s.store.UserSessionByToken(sessionCookie.Value) + session, err := t.store.Session(sessionCookie.Value) if err != nil { - logger.Error("[SessionMiddleware] %v", err) + logger.Error("[Middleware:Session] %v", err) return nil } @@ -71,6 +77,6 @@ func (s *SessionMiddleware) getSessionFromCookie(r *http.Request) *model.UserSes } // NewSessionMiddleware returns a new SessionMiddleware. -func NewSessionMiddleware(s *storage.Storage, r *mux.Router) *SessionMiddleware { - return &SessionMiddleware{store: s, router: r} +func NewSessionMiddleware(s *storage.Storage) *SessionMiddleware { + return &SessionMiddleware{store: s} } |