diff --git a/core/admin.go b/core/admin.go index 8b142db..91173ea 100644 --- a/core/admin.go +++ b/core/admin.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/mux" "github.com/kirsle/blog/core/internal/forms" + "github.com/kirsle/blog/core/internal/middleware/auth" "github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/render" "github.com/urfave/negroni" @@ -24,7 +25,7 @@ func (b *Blog) AdminRoutes(r *mux.Router) { adminRouter.HandleFunc("/editor", b.EditorHandler) // r.HandleFunc("/admin", b.AdminHandler) r.PathPrefix("/admin").Handler(negroni.New( - negroni.HandlerFunc(b.LoginRequired), + negroni.HandlerFunc(auth.LoginRequired(b.MustLogin)), negroni.Wrap(adminRouter), )) } @@ -48,8 +49,8 @@ func (b *Blog) EditorHandler(w http.ResponseWriter, r *http.Request) { var ( fp string fromCore = r.FormValue("from") == "core" - saving = r.FormValue("action") == "save" - deleting = r.FormValue("action") == "delete" + saving = r.FormValue("action") == ActionSave + deleting = r.FormValue("action") == ActionDelete body = []byte{} ) diff --git a/core/auth.go b/core/auth.go index d601480..97428fa 100644 --- a/core/auth.go +++ b/core/auth.go @@ -7,7 +7,9 @@ import ( "github.com/gorilla/mux" "github.com/kirsle/blog/core/internal/forms" "github.com/kirsle/blog/core/internal/log" + "github.com/kirsle/blog/core/internal/middleware/auth" "github.com/kirsle/blog/core/internal/models/users" + "github.com/kirsle/blog/core/internal/sessions" ) // AuthRoutes attaches the auth routes to the app. @@ -17,9 +19,16 @@ func (b *Blog) AuthRoutes(r *mux.Router) { r.HandleFunc("/account", b.AccountHandler) } +// MustLogin handles errors from the LoginRequired middleware by redirecting +// the user to the login page. +func (b *Blog) MustLogin(w http.ResponseWriter, r *http.Request) { + log.Info("MustLogin for %s", r.URL.Path) + b.Redirect(w, "/login?next="+r.URL.Path) +} + // Login logs the browser in as the given user. func (b *Blog) Login(w http.ResponseWriter, r *http.Request, u *users.User) error { - session, err := b.store.Get(r, "session") // TODO session name + session, err := sessions.Store.Get(r, "session") // TODO session name if err != nil { return err } @@ -78,7 +87,7 @@ func (b *Blog) LoginHandler(w http.ResponseWriter, r *http.Request) { // LogoutHandler logs the user out and redirects to the home page. func (b *Blog) LogoutHandler(w http.ResponseWriter, r *http.Request) { - session, _ := b.store.Get(r, "session") + session, _ := sessions.Store.Get(r, "session") delete(session.Values, "logged-in") delete(session.Values, "user-id") session.Save(r, w) @@ -87,11 +96,11 @@ func (b *Blog) LogoutHandler(w http.ResponseWriter, r *http.Request) { // AccountHandler shows the account settings page. func (b *Blog) AccountHandler(w http.ResponseWriter, r *http.Request) { - if !b.LoggedIn(r) { + if !auth.LoggedIn(r) { b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!") return } - currentUser, err := b.CurrentUser(r) + currentUser, err := auth.CurrentUser(r) if err != nil { b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!!") return diff --git a/core/blog.go b/core/blog.go index 8028596..b5ceb85 100644 --- a/core/blog.go +++ b/core/blog.go @@ -15,6 +15,7 @@ import ( "github.com/gorilla/mux" "github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/markdown" + "github.com/kirsle/blog/core/internal/middleware/auth" "github.com/kirsle/blog/core/internal/models/comments" "github.com/kirsle/blog/core/internal/models/posts" "github.com/kirsle/blog/core/internal/models/settings" @@ -76,19 +77,10 @@ func (b *Blog) BlogRoutes(r *mux.Router) { loginRouter.HandleFunc("/blog/private", b.PrivatePosts) r.PathPrefix("/blog").Handler( negroni.New( - negroni.HandlerFunc(b.LoginRequired), + negroni.HandlerFunc(auth.LoginRequired(b.MustLogin)), negroni.Wrap(loginRouter), ), ) - - adminRouter := mux.NewRouter().PathPrefix("/admin").Subrouter().StrictSlash(false) - r.HandleFunc("/admin", b.AdminHandler) // so as to not be "/admin/" - adminRouter.HandleFunc("/settings", b.SettingsHandler) - adminRouter.PathPrefix("/").HandlerFunc(b.PageHandler) - r.PathPrefix("/admin").Handler(negroni.New( - negroni.HandlerFunc(b.LoginRequired), - negroni.Wrap(adminRouter), - )) } // RSSHandler renders an RSS feed from the blog. @@ -214,7 +206,7 @@ func (b *Blog) RecentPosts(r *http.Request, tag, privacy string) []posts.Post { } } else { // Exclude certain posts in generic index views. - if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !b.LoggedIn(r) { + if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !auth.LoggedIn(r) { continue } else if post.Privacy == DRAFT { continue @@ -370,7 +362,7 @@ func (b *Blog) BlogArchive(w http.ResponseWriter, r *http.Request) { byMonth := map[string]*Archive{} for _, post := range idx.Posts { // Exclude certain posts - if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !b.LoggedIn(r) { + if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !auth.LoggedIn(r) { continue } else if post.Privacy == DRAFT { continue @@ -416,8 +408,8 @@ func (b *Blog) viewPost(w http.ResponseWriter, r *http.Request, fragment string) // Handle post privacy. if post.Privacy == PRIVATE || post.Privacy == DRAFT { - if !b.LoggedIn(r) { - b.NotFound(w, r) + if !auth.LoggedIn(r) { + b.NotFound(w, r, "That post is not public.") return nil } } @@ -517,7 +509,7 @@ func (b *Blog) EditBlog(w http.ResponseWriter, r *http.Request) { if err := post.Validate(); err != nil { v.Error = err } else { - author, _ := b.CurrentUser(r) + author, _ := auth.CurrentUser(r) post.AuthorID = author.ID post.Updated = time.Now().UTC() diff --git a/core/comments.go b/core/comments.go index 40066e8..9e613d4 100644 --- a/core/comments.go +++ b/core/comments.go @@ -10,12 +10,14 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" - "github.com/gorilla/sessions" + gorilla "github.com/gorilla/sessions" "github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/markdown" + "github.com/kirsle/blog/core/internal/middleware/auth" "github.com/kirsle/blog/core/internal/models/comments" "github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/render" + "github.com/kirsle/blog/core/internal/sessions" ) // CommentRoutes attaches the comment routes to the app. @@ -37,7 +39,7 @@ type CommentMeta struct { } // RenderComments renders a comment form partial and returns the HTML. -func (b *Blog) RenderComments(session *sessions.Session, csrfToken, url, subject string, ids ...string) template.HTML { +func (b *Blog) RenderComments(session *gorilla.Session, csrfToken, url, subject string, ids ...string) template.HTML { id := strings.Join(ids, "-") // Load their cached name and email if they posted a comment before. @@ -142,7 +144,7 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) { return } v := NewVars() - currentUser, _ := b.CurrentUser(r) + currentUser, _ := auth.CurrentUser(r) editToken := b.GetEditToken(w, r) submit := r.FormValue("submit") @@ -193,21 +195,21 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) { } // Cache their name and email in their session. - session := b.Session(r) + session := sessions.Get(r) session.Values["c.name"] = c.Name session.Values["c.email"] = c.Email session.Save(r, w) // Previewing, deleting, or posting? switch submit { - case "preview", "delete": + case ActionPreview, ActionDelete: if !c.Editing && currentUser.IsAuthenticated { c.Name = currentUser.Name c.Email = currentUser.Email c.LoadAvatar() } c.HTML = template.HTML(markdown.RenderMarkdown(c.Body)) - case "post": + case ActionPost: if err := c.Validate(); err != nil { v.Error = err } else { @@ -251,7 +253,7 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) { v.Data["Thread"] = t v.Data["Comment"] = c v.Data["Editing"] = c.Editing - v.Data["Deleting"] = submit == "delete" + v.Data["Deleting"] = submit == ActionDelete b.RenderTemplate(w, r, "comments/index.gohtml", v) } @@ -295,7 +297,7 @@ func (b *Blog) QuickDeleteHandler(w http.ResponseWriter, r *http.Request) { thread := r.URL.Query().Get("t") token := r.URL.Query().Get("d") if thread == "" || token == "" { - b.BadRequest(w, r) + b.BadRequest(w, r, "Bad Request") return } @@ -315,7 +317,7 @@ func (b *Blog) QuickDeleteHandler(w http.ResponseWriter, r *http.Request) { // GetEditToken gets or generates an edit token from the user's session, which // allows a user to edit their comment for a short while after they post it. func (b *Blog) GetEditToken(w http.ResponseWriter, r *http.Request) string { - session := b.Session(r) + session := sessions.Get(r) if token, ok := session.Values["c.token"].(string); ok && len(token) > 0 { return token } diff --git a/core/constants.go b/core/constants.go index ba1bcb0..a2bd76e 100644 --- a/core/constants.go +++ b/core/constants.go @@ -19,3 +19,11 @@ const ( MARKDOWN ContentType = "markdown" HTML ContentType = "html" ) + +// Common form actions. +const ( + ActionSave = "save" + ActionDelete = "delete" + ActionPreview = "preview" + ActionPost = "post" +) diff --git a/core/core.go b/core/core.go index 60da7e7..3324f16 100644 --- a/core/core.go +++ b/core/core.go @@ -7,14 +7,16 @@ import ( "path/filepath" "github.com/gorilla/mux" - "github.com/gorilla/sessions" "github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/markdown" + "github.com/kirsle/blog/core/internal/middleware" + "github.com/kirsle/blog/core/internal/middleware/auth" "github.com/kirsle/blog/core/internal/models/comments" "github.com/kirsle/blog/core/internal/models/posts" "github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/render" + "github.com/kirsle/blog/core/internal/sessions" "github.com/kirsle/blog/jsondb" "github.com/kirsle/blog/jsondb/caches" "github.com/kirsle/blog/jsondb/caches/null" @@ -36,9 +38,8 @@ type Blog struct { Cache caches.Cacher // Web app objects. - n *negroni.Negroni // Negroni middleware manager - r *mux.Router // Router - store sessions.Store + n *negroni.Negroni // Negroni middleware manager + r *mux.Router // Router } // New initializes the Blog application. @@ -73,7 +74,7 @@ func (b *Blog) Configure() { render.DocumentRoot = &b.DocumentRoot // Initialize the session cookie store. - b.store = sessions.NewCookieStore([]byte(config.Security.SecretKey)) + sessions.SetSecretKey([]byte(config.Security.SecretKey)) users.HashCost = config.Security.HashCost // Initialize the rest of the models. @@ -120,9 +121,9 @@ func (b *Blog) SetupHTTP() { n := negroni.New( negroni.NewRecovery(), negroni.NewLogger(), - negroni.HandlerFunc(b.SessionLoader), - negroni.HandlerFunc(b.CSRFMiddleware), - negroni.HandlerFunc(b.AuthMiddleware), + negroni.HandlerFunc(sessions.Middleware), + negroni.HandlerFunc(middleware.CSRF(b.Forbidden)), + negroni.HandlerFunc(auth.Middleware), ) n.UseHandler(r) diff --git a/core/initial-setup.go b/core/initial-setup.go index a99e656..6857365 100644 --- a/core/initial-setup.go +++ b/core/initial-setup.go @@ -3,12 +3,12 @@ package core import ( "net/http" - "github.com/gorilla/sessions" "github.com/kirsle/blog/core/internal/forms" "github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/render" + "github.com/kirsle/blog/core/internal/sessions" ) // SetupHandler is the initial blog setup route. @@ -41,7 +41,7 @@ func (b *Blog) SetupHandler(w http.ResponseWriter, r *http.Request) { s.Save() // Re-initialize the cookie store with the new secret key. - b.store = sessions.NewCookieStore([]byte(s.Security.SecretKey)) + sessions.SetSecretKey([]byte(s.Security.SecretKey)) log.Info("Creating admin account %s", form.Username) user := &users.User{ diff --git a/core/internal/middleware/auth/auth.go b/core/internal/middleware/auth/auth.go new file mode 100644 index 0000000..be3d87f --- /dev/null +++ b/core/internal/middleware/auth/auth.go @@ -0,0 +1,67 @@ +package auth + +import ( + "context" + "errors" + "net/http" + + "github.com/kirsle/blog/core/internal/log" + "github.com/kirsle/blog/core/internal/models/users" + "github.com/kirsle/blog/core/internal/sessions" + "github.com/kirsle/blog/core/internal/types" + "github.com/urfave/negroni" +) + +// CurrentUser returns the current user's object. +func CurrentUser(r *http.Request) (*users.User, error) { + session := sessions.Get(r) + if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn { + id := session.Values["user-id"].(int) + u, err := users.LoadReadonly(id) + u.IsAuthenticated = true + return u, err + } + + return &users.User{ + Admin: false, + }, errors.New("not authenticated") +} + +// LoggedIn returns whether the current user is logged in to an account. +func LoggedIn(r *http.Request) bool { + session := sessions.Get(r) + if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn { + return true + } + return false +} + +// LoginRequired is a middleware that requires a logged-in user. +func LoginRequired(onError http.HandlerFunc) negroni.HandlerFunc { + middleware := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + ctx := r.Context() + if user, ok := ctx.Value(types.UserKey).(*users.User); ok { + if user.ID > 0 { + next(w, r) + return + } + } + + log.Info("Redirect away!") + onError(w, r) + } + + return middleware +} + +// Middleware loads the user's authentication state from their session cookie. +func Middleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + u, err := CurrentUser(r) + if err != nil { + next(w, r) + return + } + + ctx := context.WithValue(r.Context(), types.UserKey, u) + next(w, r.WithContext(ctx)) +} diff --git a/core/internal/middleware/csrf.go b/core/internal/middleware/csrf.go new file mode 100644 index 0000000..37d68b3 --- /dev/null +++ b/core/internal/middleware/csrf.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "net/http" + + "github.com/google/uuid" + gorilla "github.com/gorilla/sessions" + "github.com/kirsle/blog/core/internal/log" + "github.com/kirsle/blog/core/internal/sessions" + "github.com/urfave/negroni" +) + +// CSRF is a middleware generator that enforces CSRF tokens on all POST requests. +func CSRF(onError func(http.ResponseWriter, *http.Request, string)) negroni.HandlerFunc { + middleware := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if r.Method == "POST" { + session := sessions.Get(r) + token := GenerateCSRFToken(w, r, session) + if token != r.FormValue("_csrf") { + log.Error("CSRF Mismatch: expected %s, got %s", r.FormValue("_csrf"), token) + onError(w, r, "Failed to validate CSRF token. Please try your request again.") + return + } + } + next(w, r) + } + + return middleware +} + +// ExampleCSRF shows how to use the CSRF handler. +func ExampleCSRF() { + // Your error handling for CSRF failures. + onError := func(w http.ResponseWriter, r *http.Request, message string) { + w.Write([]byte("CSRF Error: " + message)) + } + + // Attach the middleware. + _ = negroni.New( + negroni.NewRecovery(), + negroni.NewLogger(), + negroni.HandlerFunc(CSRF(onError)), + ) +} + +// GenerateCSRFToken generates a CSRF token for the user and puts it in their session. +func GenerateCSRFToken(w http.ResponseWriter, r *http.Request, session *gorilla.Session) string { + token, ok := session.Values["csrf"].(string) + if !ok { + token := uuid.New() + session.Values["csrf"] = token.String() + session.Save(r, w) + } + return token +} diff --git a/core/internal/responses/responses.go b/core/internal/responses/responses.go deleted file mode 100644 index 1c0e193..0000000 --- a/core/internal/responses/responses.go +++ /dev/null @@ -1,66 +0,0 @@ -package responses - -import ( - "net/http" - - "github.com/kirsle/blog/core/internal/log" - "github.com/kirsle/blog/core/internal/render" -) - -// Redirect sends an HTTP redirect response. -func Redirect(w http.ResponseWriter, location string) { - w.Header().Set("Location", location) - w.WriteHeader(http.StatusFound) -} - -// NotFound sends a 404 response. -func NotFound(w http.ResponseWriter, r *http.Request, message ...string) { - if len(message) == 0 { - message = []string{"The page you were looking for was not found."} - } - - w.WriteHeader(http.StatusNotFound) - err := render.RenderTemplate(w, r, ".errors/404", &render.Vars{ - Message: message[0], - }) - if err != nil { - log.Error(err.Error()) - w.Write([]byte("Unrecoverable template error for NotFound()")) - } -} - -// Forbidden sends an HTTP 403 Forbidden response. -func Forbidden(w http.ResponseWriter, r *http.Request, message ...string) { - w.WriteHeader(http.StatusForbidden) - err := render.RenderTemplate(w, r, ".errors/403", &render.Vars{ - Message: message[0], - }) - if err != nil { - log.Error(err.Error()) - w.Write([]byte("Unrecoverable template error for Forbidden()")) - } -} - -// Error sends an HTTP 500 Internal Server Error response. -func Error(w http.ResponseWriter, r *http.Request, message ...string) { - w.WriteHeader(http.StatusInternalServerError) - err := render.RenderTemplate(w, r, ".errors/500", &render.Vars{ - Message: message[0], - }) - if err != nil { - log.Error(err.Error()) - w.Write([]byte("Unrecoverable template error for Error()")) - } -} - -// BadRequest sends an HTTP 400 Bad Request. -func BadRequest(w http.ResponseWriter, r *http.Request, message ...string) { - w.WriteHeader(http.StatusBadRequest) - err := render.RenderTemplate(w, r, ".errors/400", &render.Vars{ - Message: message[0], - }) - if err != nil { - log.Error(err.Error()) - w.Write([]byte("Unrecoverable template error for BadRequest()")) - } -} diff --git a/core/internal/sessions/sessions.go b/core/internal/sessions/sessions.go new file mode 100644 index 0000000..3c9ec3e --- /dev/null +++ b/core/internal/sessions/sessions.go @@ -0,0 +1,56 @@ +package sessions + +import ( + "context" + "net/http" + "time" + + "github.com/gorilla/sessions" + "github.com/kirsle/blog/core/internal/log" + "github.com/kirsle/blog/core/internal/types" +) + +// Store holds your cookie store information. +var Store sessions.Store + +// SetSecretKey initializes a session cookie store with the secret key. +func SetSecretKey(keyPairs ...[]byte) { + Store = sessions.NewCookieStore(keyPairs...) +} + +// Middleware gets the Gorilla session store and makes it available on the +// Request context. +// +// Middleware is the first custom middleware applied, so it takes the current +// datetime to make available later in the request and stores it on the request +// context. +func Middleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + // Store the current datetime on the request context. + ctx := context.WithValue(r.Context(), types.StartTimeKey, time.Now()) + + // Get the Gorilla session and make it available in the request context. + session, _ := Store.Get(r, "session") + ctx = context.WithValue(ctx, types.SessionKey, session) + + next(w, r.WithContext(ctx)) +} + +// Get returns the current request's session. +func Get(r *http.Request) *sessions.Session { + if r == nil { + panic("Session(*http.Request) with a nil argument!?") + } + + ctx := r.Context() + if session, ok := ctx.Value(types.SessionKey).(*sessions.Session); ok { + return session + } + + // If the session wasn't on the request, it means I broke something. + log.Error( + "Session(): didn't find session in request context! Getting it " + + "from the session store instead.", + ) + session, _ := Store.Get(r, "session") + return session +} diff --git a/core/internal/types/context.go b/core/internal/types/context.go new file mode 100644 index 0000000..67bd6a7 --- /dev/null +++ b/core/internal/types/context.go @@ -0,0 +1,11 @@ +package types + +// Key is an integer enum for context.Context keys. +type Key int + +// Key definitions. +const ( + SessionKey Key = iota // The request's cookie session object. + UserKey // The request's user data for logged-in users. + StartTimeKey // HTTP request start time. +) diff --git a/core/middleware.go b/core/middleware.go deleted file mode 100644 index 5c86749..0000000 --- a/core/middleware.go +++ /dev/null @@ -1,133 +0,0 @@ -package core - -import ( - "context" - "errors" - "net/http" - "time" - - "github.com/google/uuid" - "github.com/gorilla/sessions" - "github.com/kirsle/blog/core/internal/log" - "github.com/kirsle/blog/core/internal/models/users" -) - -type key int - -const ( - sessionKey key = iota - userKey - requestTimeKey -) - -// SessionLoader gets the Gorilla session store and makes it available on the -// Request context. -// -// SessionLoader is the first custom middleware applied, so it takes the current -// datetime to make available later in the request and stores it on the request -// context. -func (b *Blog) SessionLoader(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - // Store the current datetime on the request context. - ctx := context.WithValue(r.Context(), requestTimeKey, time.Now()) - - // Get the Gorilla session and make it available in the request context. - session, _ := b.store.Get(r, "session") - ctx = context.WithValue(ctx, sessionKey, session) - - next(w, r.WithContext(ctx)) -} - -// Session returns the current request's session. -func (b *Blog) Session(r *http.Request) *sessions.Session { - if r == nil { - panic("Session(*http.Request) with a nil argument!?") - } - - ctx := r.Context() - if session, ok := ctx.Value(sessionKey).(*sessions.Session); ok { - return session - } - - log.Error( - "Session(): didn't find session in request context! Getting it " + - "from the session store instead.", - ) - session, _ := b.store.Get(r, "session") - return session -} - -// CSRFMiddleware enforces CSRF tokens on all POST requests. -func (b *Blog) CSRFMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - if r.Method == "POST" { - session := b.Session(r) - token := b.GenerateCSRFToken(w, r, session) - if token != r.FormValue("_csrf") { - log.Error("CSRF Mismatch: expected %s, got %s", r.FormValue("_csrf"), token) - b.Forbidden(w, r, "Failed to validate CSRF token. Please try your request again.") - return - } - } - - next(w, r) -} - -// GenerateCSRFToken generates a CSRF token for the user and puts it in their session. -func (b *Blog) GenerateCSRFToken(w http.ResponseWriter, r *http.Request, session *sessions.Session) string { - token, ok := session.Values["csrf"].(string) - if !ok { - token := uuid.New() - session.Values["csrf"] = token.String() - session.Save(r, w) - } - return token -} - -// CurrentUser returns the current user's object. -func (b *Blog) CurrentUser(r *http.Request) (*users.User, error) { - session := b.Session(r) - if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn { - id := session.Values["user-id"].(int) - u, err := users.LoadReadonly(id) - u.IsAuthenticated = true - return u, err - } - - return &users.User{ - Admin: false, - }, errors.New("not authenticated") -} - -// LoggedIn returns whether the current user is logged in to an account. -func (b *Blog) LoggedIn(r *http.Request) bool { - session := b.Session(r) - if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn { - return true - } - return false -} - -// AuthMiddleware loads the user's authentication state. -func (b *Blog) AuthMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - u, err := b.CurrentUser(r) - if err != nil { - next(w, r) - return - } - - ctx := context.WithValue(r.Context(), userKey, u) - next(w, r.WithContext(ctx)) -} - -// LoginRequired is a middleware that requires a logged-in user. -func (b *Blog) LoginRequired(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - ctx := r.Context() - if user, ok := ctx.Value(userKey).(*users.User); ok { - if user.ID > 0 { - next(w, r) - return - } - } - - log.Info("Redirect away!") - b.Redirect(w, "/login?next="+r.URL.Path) -} diff --git a/core/pages.go b/core/pages.go index 1b7444b..a6f20b2 100644 --- a/core/pages.go +++ b/core/pages.go @@ -23,7 +23,7 @@ func (b *Blog) PageHandler(w http.ResponseWriter, r *http.Request) { // Restrict special paths. if strings.HasPrefix(strings.ToLower(path), "/.") { - b.Forbidden(w, r) + b.Forbidden(w, r, "Forbidden") return } diff --git a/core/responses.go b/core/responses.go index 02f7720..79e85b2 100644 --- a/core/responses.go +++ b/core/responses.go @@ -6,11 +6,12 @@ import ( "github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/render" + "github.com/kirsle/blog/core/internal/sessions" ) // Flash adds a flash message to the user's session. func (b *Blog) Flash(w http.ResponseWriter, r *http.Request, message string, args ...interface{}) { - session := b.Session(r) + session := sessions.Get(r) session.AddFlash(fmt.Sprintf(message, args...)) session.Save(r, w) } @@ -34,14 +35,14 @@ func (b *Blog) Redirect(w http.ResponseWriter, location string) { } // NotFound sends a 404 response. -func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message ...string) { - if len(message) == 0 { - message = []string{"The page you were looking for was not found."} +func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message string) { + if message == "" { + message = "The page you were looking for was not found." } w.WriteHeader(http.StatusNotFound) err := b.RenderTemplate(w, r, ".errors/404", render.Vars{ - Message: message[0], + Message: message, }) if err != nil { log.Error(err.Error()) @@ -50,10 +51,10 @@ func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message ...strin } // Forbidden sends an HTTP 403 Forbidden response. -func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message ...string) { +func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message string) { w.WriteHeader(http.StatusForbidden) err := b.RenderTemplate(w, r, ".errors/403", render.Vars{ - Message: message[0], + Message: message, }) if err != nil { log.Error(err.Error()) @@ -62,10 +63,10 @@ func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message ...stri } // Error sends an HTTP 500 Internal Server Error response. -func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message ...string) { +func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message string) { w.WriteHeader(http.StatusInternalServerError) err := b.RenderTemplate(w, r, ".errors/500", render.Vars{ - Message: message[0], + Message: message, }) if err != nil { log.Error(err.Error()) @@ -74,10 +75,10 @@ func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message ...string) } // BadRequest sends an HTTP 400 Bad Request. -func (b *Blog) BadRequest(w http.ResponseWriter, r *http.Request, message ...string) { +func (b *Blog) BadRequest(w http.ResponseWriter, r *http.Request, message string) { w.WriteHeader(http.StatusBadRequest) err := b.RenderTemplate(w, r, ".errors/400", render.Vars{ - Message: message[0], + Message: message, }) if err != nil { log.Error(err.Error()) diff --git a/core/templates.go b/core/templates.go index 3c36eea..b8759ec 100644 --- a/core/templates.go +++ b/core/templates.go @@ -8,9 +8,13 @@ import ( "time" "github.com/kirsle/blog/core/internal/forms" + "github.com/kirsle/blog/core/internal/middleware" + "github.com/kirsle/blog/core/internal/middleware/auth" "github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/render" + "github.com/kirsle/blog/core/internal/sessions" + "github.com/kirsle/blog/core/internal/types" ) // Vars is an interface to implement by the templates to pass their own custom @@ -66,11 +70,11 @@ func (b *Blog) LoadDefaults(v render.Vars, r *http.Request) render.Vars { v.SetupNeeded = true } v.Request = r - v.RequestTime = r.Context().Value(requestTimeKey).(time.Time) + v.RequestTime = r.Context().Value(types.StartTimeKey).(time.Time) v.Title = s.Site.Title v.Path = r.URL.Path - user, err := b.CurrentUser(r) + user, err := auth.CurrentUser(r) v.CurrentUser = user v.LoggedIn = err == nil @@ -109,7 +113,7 @@ func (b *Blog) RenderTemplate(w http.ResponseWriter, r *http.Request, path strin vars = b.LoadDefaults(vars, r) // Add any flashed messages from the endpoint controllers. - session := b.Session(r) + session := sessions.Get(r) if flashes := session.Flashes(); len(flashes) > 0 { for _, flash := range flashes { _ = flash @@ -119,7 +123,7 @@ func (b *Blog) RenderTemplate(w http.ResponseWriter, r *http.Request, path strin } vars.RequestDuration = time.Now().Sub(vars.RequestTime) - vars.CSRF = b.GenerateCSRFToken(w, r, session) + vars.CSRF = middleware.GenerateCSRFToken(w, r, session) vars.Editable = !strings.HasPrefix(path, "admin/") return render.Template(w, path, render.Config{ @@ -140,8 +144,8 @@ func (b *Blog) TemplateFuncs(w http.ResponseWriter, r *http.Request, inject map[ return template.HTML("[RenderComments Error: need both http.ResponseWriter and http.Request]") } - session := b.Session(r) - csrf := b.GenerateCSRFToken(w, r, session) + session := sessions.Get(r) + csrf := middleware.GenerateCSRFToken(w, r, session) return b.RenderComments(session, csrf, r.URL.Path, subject, ids...) }, }