From 967e149875e8c1b9aaa47dadd7b6ee55d7fe5407 Mon Sep 17 00:00:00 2001 From: Noah Petherbridge Date: Sun, 21 Aug 2022 14:17:52 -0700 Subject: [PATCH] Optimize CurrentUser to read from DB only once per request --- pkg/middleware/authentication.go | 20 +++++++++++++++----- pkg/session/current_user.go | 6 ++++++ pkg/session/session.go | 5 +++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/pkg/middleware/authentication.go b/pkg/middleware/authentication.go index bf6e9b2..44d02fe 100644 --- a/pkg/middleware/authentication.go +++ b/pkg/middleware/authentication.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "net/http" "time" @@ -46,7 +47,9 @@ func LoginRequired(handler http.Handler) http.Handler { } } - handler.ServeHTTP(w, r) + // Stick the CurrentUser in the request context so future calls to session.CurrentUser can read it. + ctx := context.WithValue(r.Context(), session.CurrentUserKey, user) + handler.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -55,19 +58,26 @@ func AdminRequired(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // User must be logged in. - if currentUser, err := session.CurrentUser(r); err != nil { + currentUser, err := session.CurrentUser(r) + if err != nil { log.Error("AdminRequired: %s", err) errhandler := templates.MakeErrorPage("Login Required", "You must be signed in to view this page.", http.StatusForbidden) errhandler.ServeHTTP(w, r) return - } else if !currentUser.IsAdmin { + } + + // Stick the CurrentUser in the request context so future calls to session.CurrentUser can read it. + ctx := context.WithValue(r.Context(), session.CurrentUserKey, currentUser) + + // Admin required. + if !currentUser.IsAdmin { log.Error("AdminRequired: %s", err) errhandler := templates.MakeErrorPage("Admin Required", "You do not have permission for this page.", http.StatusForbidden) - errhandler.ServeHTTP(w, r) + errhandler.ServeHTTP(w, r.WithContext(ctx)) return } - handler.ServeHTTP(w, r) + handler.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/pkg/session/current_user.go b/pkg/session/current_user.go index 5105a0f..a66fd2a 100644 --- a/pkg/session/current_user.go +++ b/pkg/session/current_user.go @@ -11,6 +11,12 @@ import ( func CurrentUser(r *http.Request) (*models.User, error) { sess := Get(r) if sess.LoggedIn { + // Did we already get the CurrentUser once before? + ctx := r.Context() + if user, ok := ctx.Value(CurrentUserKey).(*models.User); ok { + return user, nil + } + // Load the associated user ID. return models.GetUser(sess.UserID) } diff --git a/pkg/session/session.go b/pkg/session/session.go index 1603cf2..a98ac7c 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -26,8 +26,9 @@ type Session struct { } const ( - ContextKey = "session" - CSRFKey = "csrf" + ContextKey = "session" + CurrentUserKey = "current_user" + CSRFKey = "csrf" ) // New creates a blank session object.