Eviscerate the last of the middleware into sub-packages

This commit is contained in:
Noah 2018-02-10 11:14:42 -08:00
parent e393b1880f
commit 60ccaf7b35
16 changed files with 266 additions and 258 deletions

View File

@ -11,6 +11,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/kirsle/blog/core/internal/forms" "github.com/kirsle/blog/core/internal/forms"
"github.com/kirsle/blog/core/internal/middleware/auth"
"github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/settings"
"github.com/kirsle/blog/core/internal/render" "github.com/kirsle/blog/core/internal/render"
"github.com/urfave/negroni" "github.com/urfave/negroni"
@ -24,7 +25,7 @@ func (b *Blog) AdminRoutes(r *mux.Router) {
adminRouter.HandleFunc("/editor", b.EditorHandler) adminRouter.HandleFunc("/editor", b.EditorHandler)
// r.HandleFunc("/admin", b.AdminHandler) // r.HandleFunc("/admin", b.AdminHandler)
r.PathPrefix("/admin").Handler(negroni.New( r.PathPrefix("/admin").Handler(negroni.New(
negroni.HandlerFunc(b.LoginRequired), negroni.HandlerFunc(auth.LoginRequired(b.MustLogin)),
negroni.Wrap(adminRouter), negroni.Wrap(adminRouter),
)) ))
} }
@ -48,8 +49,8 @@ func (b *Blog) EditorHandler(w http.ResponseWriter, r *http.Request) {
var ( var (
fp string fp string
fromCore = r.FormValue("from") == "core" fromCore = r.FormValue("from") == "core"
saving = r.FormValue("action") == "save" saving = r.FormValue("action") == ActionSave
deleting = r.FormValue("action") == "delete" deleting = r.FormValue("action") == ActionDelete
body = []byte{} body = []byte{}
) )

View File

@ -7,7 +7,9 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/kirsle/blog/core/internal/forms" "github.com/kirsle/blog/core/internal/forms"
"github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/middleware/auth"
"github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/models/users"
"github.com/kirsle/blog/core/internal/sessions"
) )
// AuthRoutes attaches the auth routes to the app. // AuthRoutes attaches the auth routes to the app.
@ -17,9 +19,16 @@ func (b *Blog) AuthRoutes(r *mux.Router) {
r.HandleFunc("/account", b.AccountHandler) r.HandleFunc("/account", b.AccountHandler)
} }
// MustLogin handles errors from the LoginRequired middleware by redirecting
// the user to the login page.
func (b *Blog) MustLogin(w http.ResponseWriter, r *http.Request) {
log.Info("MustLogin for %s", r.URL.Path)
b.Redirect(w, "/login?next="+r.URL.Path)
}
// Login logs the browser in as the given user. // Login logs the browser in as the given user.
func (b *Blog) Login(w http.ResponseWriter, r *http.Request, u *users.User) error { func (b *Blog) Login(w http.ResponseWriter, r *http.Request, u *users.User) error {
session, err := b.store.Get(r, "session") // TODO session name session, err := sessions.Store.Get(r, "session") // TODO session name
if err != nil { if err != nil {
return err return err
} }
@ -78,7 +87,7 @@ func (b *Blog) LoginHandler(w http.ResponseWriter, r *http.Request) {
// LogoutHandler logs the user out and redirects to the home page. // LogoutHandler logs the user out and redirects to the home page.
func (b *Blog) LogoutHandler(w http.ResponseWriter, r *http.Request) { func (b *Blog) LogoutHandler(w http.ResponseWriter, r *http.Request) {
session, _ := b.store.Get(r, "session") session, _ := sessions.Store.Get(r, "session")
delete(session.Values, "logged-in") delete(session.Values, "logged-in")
delete(session.Values, "user-id") delete(session.Values, "user-id")
session.Save(r, w) session.Save(r, w)
@ -87,11 +96,11 @@ func (b *Blog) LogoutHandler(w http.ResponseWriter, r *http.Request) {
// AccountHandler shows the account settings page. // AccountHandler shows the account settings page.
func (b *Blog) AccountHandler(w http.ResponseWriter, r *http.Request) { func (b *Blog) AccountHandler(w http.ResponseWriter, r *http.Request) {
if !b.LoggedIn(r) { if !auth.LoggedIn(r) {
b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!") b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!")
return return
} }
currentUser, err := b.CurrentUser(r) currentUser, err := auth.CurrentUser(r)
if err != nil { if err != nil {
b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!!") b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!!")
return return

View File

@ -15,6 +15,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/markdown" "github.com/kirsle/blog/core/internal/markdown"
"github.com/kirsle/blog/core/internal/middleware/auth"
"github.com/kirsle/blog/core/internal/models/comments" "github.com/kirsle/blog/core/internal/models/comments"
"github.com/kirsle/blog/core/internal/models/posts" "github.com/kirsle/blog/core/internal/models/posts"
"github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/settings"
@ -76,19 +77,10 @@ func (b *Blog) BlogRoutes(r *mux.Router) {
loginRouter.HandleFunc("/blog/private", b.PrivatePosts) loginRouter.HandleFunc("/blog/private", b.PrivatePosts)
r.PathPrefix("/blog").Handler( r.PathPrefix("/blog").Handler(
negroni.New( negroni.New(
negroni.HandlerFunc(b.LoginRequired), negroni.HandlerFunc(auth.LoginRequired(b.MustLogin)),
negroni.Wrap(loginRouter), negroni.Wrap(loginRouter),
), ),
) )
adminRouter := mux.NewRouter().PathPrefix("/admin").Subrouter().StrictSlash(false)
r.HandleFunc("/admin", b.AdminHandler) // so as to not be "/admin/"
adminRouter.HandleFunc("/settings", b.SettingsHandler)
adminRouter.PathPrefix("/").HandlerFunc(b.PageHandler)
r.PathPrefix("/admin").Handler(negroni.New(
negroni.HandlerFunc(b.LoginRequired),
negroni.Wrap(adminRouter),
))
} }
// RSSHandler renders an RSS feed from the blog. // RSSHandler renders an RSS feed from the blog.
@ -214,7 +206,7 @@ func (b *Blog) RecentPosts(r *http.Request, tag, privacy string) []posts.Post {
} }
} else { } else {
// Exclude certain posts in generic index views. // Exclude certain posts in generic index views.
if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !b.LoggedIn(r) { if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !auth.LoggedIn(r) {
continue continue
} else if post.Privacy == DRAFT { } else if post.Privacy == DRAFT {
continue continue
@ -370,7 +362,7 @@ func (b *Blog) BlogArchive(w http.ResponseWriter, r *http.Request) {
byMonth := map[string]*Archive{} byMonth := map[string]*Archive{}
for _, post := range idx.Posts { for _, post := range idx.Posts {
// Exclude certain posts // Exclude certain posts
if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !b.LoggedIn(r) { if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !auth.LoggedIn(r) {
continue continue
} else if post.Privacy == DRAFT { } else if post.Privacy == DRAFT {
continue continue
@ -416,8 +408,8 @@ func (b *Blog) viewPost(w http.ResponseWriter, r *http.Request, fragment string)
// Handle post privacy. // Handle post privacy.
if post.Privacy == PRIVATE || post.Privacy == DRAFT { if post.Privacy == PRIVATE || post.Privacy == DRAFT {
if !b.LoggedIn(r) { if !auth.LoggedIn(r) {
b.NotFound(w, r) b.NotFound(w, r, "That post is not public.")
return nil return nil
} }
} }
@ -517,7 +509,7 @@ func (b *Blog) EditBlog(w http.ResponseWriter, r *http.Request) {
if err := post.Validate(); err != nil { if err := post.Validate(); err != nil {
v.Error = err v.Error = err
} else { } else {
author, _ := b.CurrentUser(r) author, _ := auth.CurrentUser(r)
post.AuthorID = author.ID post.AuthorID = author.ID
post.Updated = time.Now().UTC() post.Updated = time.Now().UTC()

View File

@ -10,12 +10,14 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/sessions" gorilla "github.com/gorilla/sessions"
"github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/markdown" "github.com/kirsle/blog/core/internal/markdown"
"github.com/kirsle/blog/core/internal/middleware/auth"
"github.com/kirsle/blog/core/internal/models/comments" "github.com/kirsle/blog/core/internal/models/comments"
"github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/models/users"
"github.com/kirsle/blog/core/internal/render" "github.com/kirsle/blog/core/internal/render"
"github.com/kirsle/blog/core/internal/sessions"
) )
// CommentRoutes attaches the comment routes to the app. // CommentRoutes attaches the comment routes to the app.
@ -37,7 +39,7 @@ type CommentMeta struct {
} }
// RenderComments renders a comment form partial and returns the HTML. // RenderComments renders a comment form partial and returns the HTML.
func (b *Blog) RenderComments(session *sessions.Session, csrfToken, url, subject string, ids ...string) template.HTML { func (b *Blog) RenderComments(session *gorilla.Session, csrfToken, url, subject string, ids ...string) template.HTML {
id := strings.Join(ids, "-") id := strings.Join(ids, "-")
// Load their cached name and email if they posted a comment before. // Load their cached name and email if they posted a comment before.
@ -142,7 +144,7 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
v := NewVars() v := NewVars()
currentUser, _ := b.CurrentUser(r) currentUser, _ := auth.CurrentUser(r)
editToken := b.GetEditToken(w, r) editToken := b.GetEditToken(w, r)
submit := r.FormValue("submit") submit := r.FormValue("submit")
@ -193,21 +195,21 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) {
} }
// Cache their name and email in their session. // Cache their name and email in their session.
session := b.Session(r) session := sessions.Get(r)
session.Values["c.name"] = c.Name session.Values["c.name"] = c.Name
session.Values["c.email"] = c.Email session.Values["c.email"] = c.Email
session.Save(r, w) session.Save(r, w)
// Previewing, deleting, or posting? // Previewing, deleting, or posting?
switch submit { switch submit {
case "preview", "delete": case ActionPreview, ActionDelete:
if !c.Editing && currentUser.IsAuthenticated { if !c.Editing && currentUser.IsAuthenticated {
c.Name = currentUser.Name c.Name = currentUser.Name
c.Email = currentUser.Email c.Email = currentUser.Email
c.LoadAvatar() c.LoadAvatar()
} }
c.HTML = template.HTML(markdown.RenderMarkdown(c.Body)) c.HTML = template.HTML(markdown.RenderMarkdown(c.Body))
case "post": case ActionPost:
if err := c.Validate(); err != nil { if err := c.Validate(); err != nil {
v.Error = err v.Error = err
} else { } else {
@ -251,7 +253,7 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) {
v.Data["Thread"] = t v.Data["Thread"] = t
v.Data["Comment"] = c v.Data["Comment"] = c
v.Data["Editing"] = c.Editing v.Data["Editing"] = c.Editing
v.Data["Deleting"] = submit == "delete" v.Data["Deleting"] = submit == ActionDelete
b.RenderTemplate(w, r, "comments/index.gohtml", v) b.RenderTemplate(w, r, "comments/index.gohtml", v)
} }
@ -295,7 +297,7 @@ func (b *Blog) QuickDeleteHandler(w http.ResponseWriter, r *http.Request) {
thread := r.URL.Query().Get("t") thread := r.URL.Query().Get("t")
token := r.URL.Query().Get("d") token := r.URL.Query().Get("d")
if thread == "" || token == "" { if thread == "" || token == "" {
b.BadRequest(w, r) b.BadRequest(w, r, "Bad Request")
return return
} }
@ -315,7 +317,7 @@ func (b *Blog) QuickDeleteHandler(w http.ResponseWriter, r *http.Request) {
// GetEditToken gets or generates an edit token from the user's session, which // GetEditToken gets or generates an edit token from the user's session, which
// allows a user to edit their comment for a short while after they post it. // allows a user to edit their comment for a short while after they post it.
func (b *Blog) GetEditToken(w http.ResponseWriter, r *http.Request) string { func (b *Blog) GetEditToken(w http.ResponseWriter, r *http.Request) string {
session := b.Session(r) session := sessions.Get(r)
if token, ok := session.Values["c.token"].(string); ok && len(token) > 0 { if token, ok := session.Values["c.token"].(string); ok && len(token) > 0 {
return token return token
} }

View File

@ -19,3 +19,11 @@ const (
MARKDOWN ContentType = "markdown" MARKDOWN ContentType = "markdown"
HTML ContentType = "html" HTML ContentType = "html"
) )
// Common form actions.
const (
ActionSave = "save"
ActionDelete = "delete"
ActionPreview = "preview"
ActionPost = "post"
)

View File

@ -7,14 +7,16 @@ import (
"path/filepath" "path/filepath"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/markdown" "github.com/kirsle/blog/core/internal/markdown"
"github.com/kirsle/blog/core/internal/middleware"
"github.com/kirsle/blog/core/internal/middleware/auth"
"github.com/kirsle/blog/core/internal/models/comments" "github.com/kirsle/blog/core/internal/models/comments"
"github.com/kirsle/blog/core/internal/models/posts" "github.com/kirsle/blog/core/internal/models/posts"
"github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/settings"
"github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/models/users"
"github.com/kirsle/blog/core/internal/render" "github.com/kirsle/blog/core/internal/render"
"github.com/kirsle/blog/core/internal/sessions"
"github.com/kirsle/blog/jsondb" "github.com/kirsle/blog/jsondb"
"github.com/kirsle/blog/jsondb/caches" "github.com/kirsle/blog/jsondb/caches"
"github.com/kirsle/blog/jsondb/caches/null" "github.com/kirsle/blog/jsondb/caches/null"
@ -36,9 +38,8 @@ type Blog struct {
Cache caches.Cacher Cache caches.Cacher
// Web app objects. // Web app objects.
n *negroni.Negroni // Negroni middleware manager n *negroni.Negroni // Negroni middleware manager
r *mux.Router // Router r *mux.Router // Router
store sessions.Store
} }
// New initializes the Blog application. // New initializes the Blog application.
@ -73,7 +74,7 @@ func (b *Blog) Configure() {
render.DocumentRoot = &b.DocumentRoot render.DocumentRoot = &b.DocumentRoot
// Initialize the session cookie store. // Initialize the session cookie store.
b.store = sessions.NewCookieStore([]byte(config.Security.SecretKey)) sessions.SetSecretKey([]byte(config.Security.SecretKey))
users.HashCost = config.Security.HashCost users.HashCost = config.Security.HashCost
// Initialize the rest of the models. // Initialize the rest of the models.
@ -120,9 +121,9 @@ func (b *Blog) SetupHTTP() {
n := negroni.New( n := negroni.New(
negroni.NewRecovery(), negroni.NewRecovery(),
negroni.NewLogger(), negroni.NewLogger(),
negroni.HandlerFunc(b.SessionLoader), negroni.HandlerFunc(sessions.Middleware),
negroni.HandlerFunc(b.CSRFMiddleware), negroni.HandlerFunc(middleware.CSRF(b.Forbidden)),
negroni.HandlerFunc(b.AuthMiddleware), negroni.HandlerFunc(auth.Middleware),
) )
n.UseHandler(r) n.UseHandler(r)

View File

@ -3,12 +3,12 @@ package core
import ( import (
"net/http" "net/http"
"github.com/gorilla/sessions"
"github.com/kirsle/blog/core/internal/forms" "github.com/kirsle/blog/core/internal/forms"
"github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/settings"
"github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/models/users"
"github.com/kirsle/blog/core/internal/render" "github.com/kirsle/blog/core/internal/render"
"github.com/kirsle/blog/core/internal/sessions"
) )
// SetupHandler is the initial blog setup route. // SetupHandler is the initial blog setup route.
@ -41,7 +41,7 @@ func (b *Blog) SetupHandler(w http.ResponseWriter, r *http.Request) {
s.Save() s.Save()
// Re-initialize the cookie store with the new secret key. // Re-initialize the cookie store with the new secret key.
b.store = sessions.NewCookieStore([]byte(s.Security.SecretKey)) sessions.SetSecretKey([]byte(s.Security.SecretKey))
log.Info("Creating admin account %s", form.Username) log.Info("Creating admin account %s", form.Username)
user := &users.User{ user := &users.User{

View File

@ -0,0 +1,67 @@
package auth
import (
"context"
"errors"
"net/http"
"github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/models/users"
"github.com/kirsle/blog/core/internal/sessions"
"github.com/kirsle/blog/core/internal/types"
"github.com/urfave/negroni"
)
// CurrentUser returns the current user's object.
func CurrentUser(r *http.Request) (*users.User, error) {
session := sessions.Get(r)
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
id := session.Values["user-id"].(int)
u, err := users.LoadReadonly(id)
u.IsAuthenticated = true
return u, err
}
return &users.User{
Admin: false,
}, errors.New("not authenticated")
}
// LoggedIn returns whether the current user is logged in to an account.
func LoggedIn(r *http.Request) bool {
session := sessions.Get(r)
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
return true
}
return false
}
// LoginRequired is a middleware that requires a logged-in user.
func LoginRequired(onError http.HandlerFunc) negroni.HandlerFunc {
middleware := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
ctx := r.Context()
if user, ok := ctx.Value(types.UserKey).(*users.User); ok {
if user.ID > 0 {
next(w, r)
return
}
}
log.Info("Redirect away!")
onError(w, r)
}
return middleware
}
// Middleware loads the user's authentication state from their session cookie.
func Middleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
u, err := CurrentUser(r)
if err != nil {
next(w, r)
return
}
ctx := context.WithValue(r.Context(), types.UserKey, u)
next(w, r.WithContext(ctx))
}

View File

@ -0,0 +1,55 @@
package middleware
import (
"net/http"
"github.com/google/uuid"
gorilla "github.com/gorilla/sessions"
"github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/sessions"
"github.com/urfave/negroni"
)
// CSRF is a middleware generator that enforces CSRF tokens on all POST requests.
func CSRF(onError func(http.ResponseWriter, *http.Request, string)) negroni.HandlerFunc {
middleware := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if r.Method == "POST" {
session := sessions.Get(r)
token := GenerateCSRFToken(w, r, session)
if token != r.FormValue("_csrf") {
log.Error("CSRF Mismatch: expected %s, got %s", r.FormValue("_csrf"), token)
onError(w, r, "Failed to validate CSRF token. Please try your request again.")
return
}
}
next(w, r)
}
return middleware
}
// ExampleCSRF shows how to use the CSRF handler.
func ExampleCSRF() {
// Your error handling for CSRF failures.
onError := func(w http.ResponseWriter, r *http.Request, message string) {
w.Write([]byte("CSRF Error: " + message))
}
// Attach the middleware.
_ = negroni.New(
negroni.NewRecovery(),
negroni.NewLogger(),
negroni.HandlerFunc(CSRF(onError)),
)
}
// GenerateCSRFToken generates a CSRF token for the user and puts it in their session.
func GenerateCSRFToken(w http.ResponseWriter, r *http.Request, session *gorilla.Session) string {
token, ok := session.Values["csrf"].(string)
if !ok {
token := uuid.New()
session.Values["csrf"] = token.String()
session.Save(r, w)
}
return token
}

View File

@ -1,66 +0,0 @@
package responses
import (
"net/http"
"github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/render"
)
// Redirect sends an HTTP redirect response.
func Redirect(w http.ResponseWriter, location string) {
w.Header().Set("Location", location)
w.WriteHeader(http.StatusFound)
}
// NotFound sends a 404 response.
func NotFound(w http.ResponseWriter, r *http.Request, message ...string) {
if len(message) == 0 {
message = []string{"The page you were looking for was not found."}
}
w.WriteHeader(http.StatusNotFound)
err := render.RenderTemplate(w, r, ".errors/404", &render.Vars{
Message: message[0],
})
if err != nil {
log.Error(err.Error())
w.Write([]byte("Unrecoverable template error for NotFound()"))
}
}
// Forbidden sends an HTTP 403 Forbidden response.
func Forbidden(w http.ResponseWriter, r *http.Request, message ...string) {
w.WriteHeader(http.StatusForbidden)
err := render.RenderTemplate(w, r, ".errors/403", &render.Vars{
Message: message[0],
})
if err != nil {
log.Error(err.Error())
w.Write([]byte("Unrecoverable template error for Forbidden()"))
}
}
// Error sends an HTTP 500 Internal Server Error response.
func Error(w http.ResponseWriter, r *http.Request, message ...string) {
w.WriteHeader(http.StatusInternalServerError)
err := render.RenderTemplate(w, r, ".errors/500", &render.Vars{
Message: message[0],
})
if err != nil {
log.Error(err.Error())
w.Write([]byte("Unrecoverable template error for Error()"))
}
}
// BadRequest sends an HTTP 400 Bad Request.
func BadRequest(w http.ResponseWriter, r *http.Request, message ...string) {
w.WriteHeader(http.StatusBadRequest)
err := render.RenderTemplate(w, r, ".errors/400", &render.Vars{
Message: message[0],
})
if err != nil {
log.Error(err.Error())
w.Write([]byte("Unrecoverable template error for BadRequest()"))
}
}

View File

@ -0,0 +1,56 @@
package sessions
import (
"context"
"net/http"
"time"
"github.com/gorilla/sessions"
"github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/types"
)
// Store holds your cookie store information.
var Store sessions.Store
// SetSecretKey initializes a session cookie store with the secret key.
func SetSecretKey(keyPairs ...[]byte) {
Store = sessions.NewCookieStore(keyPairs...)
}
// Middleware gets the Gorilla session store and makes it available on the
// Request context.
//
// Middleware is the first custom middleware applied, so it takes the current
// datetime to make available later in the request and stores it on the request
// context.
func Middleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// Store the current datetime on the request context.
ctx := context.WithValue(r.Context(), types.StartTimeKey, time.Now())
// Get the Gorilla session and make it available in the request context.
session, _ := Store.Get(r, "session")
ctx = context.WithValue(ctx, types.SessionKey, session)
next(w, r.WithContext(ctx))
}
// Get returns the current request's session.
func Get(r *http.Request) *sessions.Session {
if r == nil {
panic("Session(*http.Request) with a nil argument!?")
}
ctx := r.Context()
if session, ok := ctx.Value(types.SessionKey).(*sessions.Session); ok {
return session
}
// If the session wasn't on the request, it means I broke something.
log.Error(
"Session(): didn't find session in request context! Getting it " +
"from the session store instead.",
)
session, _ := Store.Get(r, "session")
return session
}

View File

@ -0,0 +1,11 @@
package types
// Key is an integer enum for context.Context keys.
type Key int
// Key definitions.
const (
SessionKey Key = iota // The request's cookie session object.
UserKey // The request's user data for logged-in users.
StartTimeKey // HTTP request start time.
)

View File

@ -1,133 +0,0 @@
package core
import (
"context"
"errors"
"net/http"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/models/users"
)
type key int
const (
sessionKey key = iota
userKey
requestTimeKey
)
// SessionLoader gets the Gorilla session store and makes it available on the
// Request context.
//
// SessionLoader is the first custom middleware applied, so it takes the current
// datetime to make available later in the request and stores it on the request
// context.
func (b *Blog) SessionLoader(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// Store the current datetime on the request context.
ctx := context.WithValue(r.Context(), requestTimeKey, time.Now())
// Get the Gorilla session and make it available in the request context.
session, _ := b.store.Get(r, "session")
ctx = context.WithValue(ctx, sessionKey, session)
next(w, r.WithContext(ctx))
}
// Session returns the current request's session.
func (b *Blog) Session(r *http.Request) *sessions.Session {
if r == nil {
panic("Session(*http.Request) with a nil argument!?")
}
ctx := r.Context()
if session, ok := ctx.Value(sessionKey).(*sessions.Session); ok {
return session
}
log.Error(
"Session(): didn't find session in request context! Getting it " +
"from the session store instead.",
)
session, _ := b.store.Get(r, "session")
return session
}
// CSRFMiddleware enforces CSRF tokens on all POST requests.
func (b *Blog) CSRFMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if r.Method == "POST" {
session := b.Session(r)
token := b.GenerateCSRFToken(w, r, session)
if token != r.FormValue("_csrf") {
log.Error("CSRF Mismatch: expected %s, got %s", r.FormValue("_csrf"), token)
b.Forbidden(w, r, "Failed to validate CSRF token. Please try your request again.")
return
}
}
next(w, r)
}
// GenerateCSRFToken generates a CSRF token for the user and puts it in their session.
func (b *Blog) GenerateCSRFToken(w http.ResponseWriter, r *http.Request, session *sessions.Session) string {
token, ok := session.Values["csrf"].(string)
if !ok {
token := uuid.New()
session.Values["csrf"] = token.String()
session.Save(r, w)
}
return token
}
// CurrentUser returns the current user's object.
func (b *Blog) CurrentUser(r *http.Request) (*users.User, error) {
session := b.Session(r)
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
id := session.Values["user-id"].(int)
u, err := users.LoadReadonly(id)
u.IsAuthenticated = true
return u, err
}
return &users.User{
Admin: false,
}, errors.New("not authenticated")
}
// LoggedIn returns whether the current user is logged in to an account.
func (b *Blog) LoggedIn(r *http.Request) bool {
session := b.Session(r)
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
return true
}
return false
}
// AuthMiddleware loads the user's authentication state.
func (b *Blog) AuthMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
u, err := b.CurrentUser(r)
if err != nil {
next(w, r)
return
}
ctx := context.WithValue(r.Context(), userKey, u)
next(w, r.WithContext(ctx))
}
// LoginRequired is a middleware that requires a logged-in user.
func (b *Blog) LoginRequired(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
ctx := r.Context()
if user, ok := ctx.Value(userKey).(*users.User); ok {
if user.ID > 0 {
next(w, r)
return
}
}
log.Info("Redirect away!")
b.Redirect(w, "/login?next="+r.URL.Path)
}

View File

@ -23,7 +23,7 @@ func (b *Blog) PageHandler(w http.ResponseWriter, r *http.Request) {
// Restrict special paths. // Restrict special paths.
if strings.HasPrefix(strings.ToLower(path), "/.") { if strings.HasPrefix(strings.ToLower(path), "/.") {
b.Forbidden(w, r) b.Forbidden(w, r, "Forbidden")
return return
} }

View File

@ -6,11 +6,12 @@ import (
"github.com/kirsle/blog/core/internal/log" "github.com/kirsle/blog/core/internal/log"
"github.com/kirsle/blog/core/internal/render" "github.com/kirsle/blog/core/internal/render"
"github.com/kirsle/blog/core/internal/sessions"
) )
// Flash adds a flash message to the user's session. // Flash adds a flash message to the user's session.
func (b *Blog) Flash(w http.ResponseWriter, r *http.Request, message string, args ...interface{}) { func (b *Blog) Flash(w http.ResponseWriter, r *http.Request, message string, args ...interface{}) {
session := b.Session(r) session := sessions.Get(r)
session.AddFlash(fmt.Sprintf(message, args...)) session.AddFlash(fmt.Sprintf(message, args...))
session.Save(r, w) session.Save(r, w)
} }
@ -34,14 +35,14 @@ func (b *Blog) Redirect(w http.ResponseWriter, location string) {
} }
// NotFound sends a 404 response. // NotFound sends a 404 response.
func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message ...string) { func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message string) {
if len(message) == 0 { if message == "" {
message = []string{"The page you were looking for was not found."} message = "The page you were looking for was not found."
} }
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
err := b.RenderTemplate(w, r, ".errors/404", render.Vars{ err := b.RenderTemplate(w, r, ".errors/404", render.Vars{
Message: message[0], Message: message,
}) })
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
@ -50,10 +51,10 @@ func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message ...strin
} }
// Forbidden sends an HTTP 403 Forbidden response. // Forbidden sends an HTTP 403 Forbidden response.
func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message ...string) { func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message string) {
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
err := b.RenderTemplate(w, r, ".errors/403", render.Vars{ err := b.RenderTemplate(w, r, ".errors/403", render.Vars{
Message: message[0], Message: message,
}) })
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
@ -62,10 +63,10 @@ func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message ...stri
} }
// Error sends an HTTP 500 Internal Server Error response. // Error sends an HTTP 500 Internal Server Error response.
func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message ...string) { func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message string) {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
err := b.RenderTemplate(w, r, ".errors/500", render.Vars{ err := b.RenderTemplate(w, r, ".errors/500", render.Vars{
Message: message[0], Message: message,
}) })
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
@ -74,10 +75,10 @@ func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message ...string)
} }
// BadRequest sends an HTTP 400 Bad Request. // BadRequest sends an HTTP 400 Bad Request.
func (b *Blog) BadRequest(w http.ResponseWriter, r *http.Request, message ...string) { func (b *Blog) BadRequest(w http.ResponseWriter, r *http.Request, message string) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
err := b.RenderTemplate(w, r, ".errors/400", render.Vars{ err := b.RenderTemplate(w, r, ".errors/400", render.Vars{
Message: message[0], Message: message,
}) })
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())

View File

@ -8,9 +8,13 @@ import (
"time" "time"
"github.com/kirsle/blog/core/internal/forms" "github.com/kirsle/blog/core/internal/forms"
"github.com/kirsle/blog/core/internal/middleware"
"github.com/kirsle/blog/core/internal/middleware/auth"
"github.com/kirsle/blog/core/internal/models/settings" "github.com/kirsle/blog/core/internal/models/settings"
"github.com/kirsle/blog/core/internal/models/users" "github.com/kirsle/blog/core/internal/models/users"
"github.com/kirsle/blog/core/internal/render" "github.com/kirsle/blog/core/internal/render"
"github.com/kirsle/blog/core/internal/sessions"
"github.com/kirsle/blog/core/internal/types"
) )
// Vars is an interface to implement by the templates to pass their own custom // Vars is an interface to implement by the templates to pass their own custom
@ -66,11 +70,11 @@ func (b *Blog) LoadDefaults(v render.Vars, r *http.Request) render.Vars {
v.SetupNeeded = true v.SetupNeeded = true
} }
v.Request = r v.Request = r
v.RequestTime = r.Context().Value(requestTimeKey).(time.Time) v.RequestTime = r.Context().Value(types.StartTimeKey).(time.Time)
v.Title = s.Site.Title v.Title = s.Site.Title
v.Path = r.URL.Path v.Path = r.URL.Path
user, err := b.CurrentUser(r) user, err := auth.CurrentUser(r)
v.CurrentUser = user v.CurrentUser = user
v.LoggedIn = err == nil v.LoggedIn = err == nil
@ -109,7 +113,7 @@ func (b *Blog) RenderTemplate(w http.ResponseWriter, r *http.Request, path strin
vars = b.LoadDefaults(vars, r) vars = b.LoadDefaults(vars, r)
// Add any flashed messages from the endpoint controllers. // Add any flashed messages from the endpoint controllers.
session := b.Session(r) session := sessions.Get(r)
if flashes := session.Flashes(); len(flashes) > 0 { if flashes := session.Flashes(); len(flashes) > 0 {
for _, flash := range flashes { for _, flash := range flashes {
_ = flash _ = flash
@ -119,7 +123,7 @@ func (b *Blog) RenderTemplate(w http.ResponseWriter, r *http.Request, path strin
} }
vars.RequestDuration = time.Now().Sub(vars.RequestTime) vars.RequestDuration = time.Now().Sub(vars.RequestTime)
vars.CSRF = b.GenerateCSRFToken(w, r, session) vars.CSRF = middleware.GenerateCSRFToken(w, r, session)
vars.Editable = !strings.HasPrefix(path, "admin/") vars.Editable = !strings.HasPrefix(path, "admin/")
return render.Template(w, path, render.Config{ return render.Template(w, path, render.Config{
@ -140,8 +144,8 @@ func (b *Blog) TemplateFuncs(w http.ResponseWriter, r *http.Request, inject map[
return template.HTML("[RenderComments Error: need both http.ResponseWriter and http.Request]") return template.HTML("[RenderComments Error: need both http.ResponseWriter and http.Request]")
} }
session := b.Session(r) session := sessions.Get(r)
csrf := b.GenerateCSRFToken(w, r, session) csrf := middleware.GenerateCSRFToken(w, r, session)
return b.RenderComments(session, csrf, r.URL.Path, subject, ids...) return b.RenderComments(session, csrf, r.URL.Path, subject, ids...)
}, },
} }