Eviscerate the last of the middleware into sub-packages
This commit is contained in:
parent
e393b1880f
commit
60ccaf7b35
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/kirsle/blog/core/internal/forms"
|
||||
"github.com/kirsle/blog/core/internal/middleware/auth"
|
||||
"github.com/kirsle/blog/core/internal/models/settings"
|
||||
"github.com/kirsle/blog/core/internal/render"
|
||||
"github.com/urfave/negroni"
|
||||
|
@ -24,7 +25,7 @@ func (b *Blog) AdminRoutes(r *mux.Router) {
|
|||
adminRouter.HandleFunc("/editor", b.EditorHandler)
|
||||
// r.HandleFunc("/admin", b.AdminHandler)
|
||||
r.PathPrefix("/admin").Handler(negroni.New(
|
||||
negroni.HandlerFunc(b.LoginRequired),
|
||||
negroni.HandlerFunc(auth.LoginRequired(b.MustLogin)),
|
||||
negroni.Wrap(adminRouter),
|
||||
))
|
||||
}
|
||||
|
@ -48,8 +49,8 @@ func (b *Blog) EditorHandler(w http.ResponseWriter, r *http.Request) {
|
|||
var (
|
||||
fp string
|
||||
fromCore = r.FormValue("from") == "core"
|
||||
saving = r.FormValue("action") == "save"
|
||||
deleting = r.FormValue("action") == "delete"
|
||||
saving = r.FormValue("action") == ActionSave
|
||||
deleting = r.FormValue("action") == ActionDelete
|
||||
body = []byte{}
|
||||
)
|
||||
|
||||
|
|
17
core/auth.go
17
core/auth.go
|
@ -7,7 +7,9 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/kirsle/blog/core/internal/forms"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/middleware/auth"
|
||||
"github.com/kirsle/blog/core/internal/models/users"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
)
|
||||
|
||||
// AuthRoutes attaches the auth routes to the app.
|
||||
|
@ -17,9 +19,16 @@ func (b *Blog) AuthRoutes(r *mux.Router) {
|
|||
r.HandleFunc("/account", b.AccountHandler)
|
||||
}
|
||||
|
||||
// MustLogin handles errors from the LoginRequired middleware by redirecting
|
||||
// the user to the login page.
|
||||
func (b *Blog) MustLogin(w http.ResponseWriter, r *http.Request) {
|
||||
log.Info("MustLogin for %s", r.URL.Path)
|
||||
b.Redirect(w, "/login?next="+r.URL.Path)
|
||||
}
|
||||
|
||||
// Login logs the browser in as the given user.
|
||||
func (b *Blog) Login(w http.ResponseWriter, r *http.Request, u *users.User) error {
|
||||
session, err := b.store.Get(r, "session") // TODO session name
|
||||
session, err := sessions.Store.Get(r, "session") // TODO session name
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -78,7 +87,7 @@ func (b *Blog) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// LogoutHandler logs the user out and redirects to the home page.
|
||||
func (b *Blog) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
session, _ := b.store.Get(r, "session")
|
||||
session, _ := sessions.Store.Get(r, "session")
|
||||
delete(session.Values, "logged-in")
|
||||
delete(session.Values, "user-id")
|
||||
session.Save(r, w)
|
||||
|
@ -87,11 +96,11 @@ func (b *Blog) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// AccountHandler shows the account settings page.
|
||||
func (b *Blog) AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if !b.LoggedIn(r) {
|
||||
if !auth.LoggedIn(r) {
|
||||
b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!")
|
||||
return
|
||||
}
|
||||
currentUser, err := b.CurrentUser(r)
|
||||
currentUser, err := auth.CurrentUser(r)
|
||||
if err != nil {
|
||||
b.FlashAndRedirect(w, r, "/login?next=/account", "You must be logged in to do that!!")
|
||||
return
|
||||
|
|
22
core/blog.go
22
core/blog.go
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/markdown"
|
||||
"github.com/kirsle/blog/core/internal/middleware/auth"
|
||||
"github.com/kirsle/blog/core/internal/models/comments"
|
||||
"github.com/kirsle/blog/core/internal/models/posts"
|
||||
"github.com/kirsle/blog/core/internal/models/settings"
|
||||
|
@ -76,19 +77,10 @@ func (b *Blog) BlogRoutes(r *mux.Router) {
|
|||
loginRouter.HandleFunc("/blog/private", b.PrivatePosts)
|
||||
r.PathPrefix("/blog").Handler(
|
||||
negroni.New(
|
||||
negroni.HandlerFunc(b.LoginRequired),
|
||||
negroni.HandlerFunc(auth.LoginRequired(b.MustLogin)),
|
||||
negroni.Wrap(loginRouter),
|
||||
),
|
||||
)
|
||||
|
||||
adminRouter := mux.NewRouter().PathPrefix("/admin").Subrouter().StrictSlash(false)
|
||||
r.HandleFunc("/admin", b.AdminHandler) // so as to not be "/admin/"
|
||||
adminRouter.HandleFunc("/settings", b.SettingsHandler)
|
||||
adminRouter.PathPrefix("/").HandlerFunc(b.PageHandler)
|
||||
r.PathPrefix("/admin").Handler(negroni.New(
|
||||
negroni.HandlerFunc(b.LoginRequired),
|
||||
negroni.Wrap(adminRouter),
|
||||
))
|
||||
}
|
||||
|
||||
// RSSHandler renders an RSS feed from the blog.
|
||||
|
@ -214,7 +206,7 @@ func (b *Blog) RecentPosts(r *http.Request, tag, privacy string) []posts.Post {
|
|||
}
|
||||
} else {
|
||||
// Exclude certain posts in generic index views.
|
||||
if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !b.LoggedIn(r) {
|
||||
if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !auth.LoggedIn(r) {
|
||||
continue
|
||||
} else if post.Privacy == DRAFT {
|
||||
continue
|
||||
|
@ -370,7 +362,7 @@ func (b *Blog) BlogArchive(w http.ResponseWriter, r *http.Request) {
|
|||
byMonth := map[string]*Archive{}
|
||||
for _, post := range idx.Posts {
|
||||
// Exclude certain posts
|
||||
if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !b.LoggedIn(r) {
|
||||
if (post.Privacy == PRIVATE || post.Privacy == UNLISTED) && !auth.LoggedIn(r) {
|
||||
continue
|
||||
} else if post.Privacy == DRAFT {
|
||||
continue
|
||||
|
@ -416,8 +408,8 @@ func (b *Blog) viewPost(w http.ResponseWriter, r *http.Request, fragment string)
|
|||
|
||||
// Handle post privacy.
|
||||
if post.Privacy == PRIVATE || post.Privacy == DRAFT {
|
||||
if !b.LoggedIn(r) {
|
||||
b.NotFound(w, r)
|
||||
if !auth.LoggedIn(r) {
|
||||
b.NotFound(w, r, "That post is not public.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -517,7 +509,7 @@ func (b *Blog) EditBlog(w http.ResponseWriter, r *http.Request) {
|
|||
if err := post.Validate(); err != nil {
|
||||
v.Error = err
|
||||
} else {
|
||||
author, _ := b.CurrentUser(r)
|
||||
author, _ := auth.CurrentUser(r)
|
||||
post.AuthorID = author.ID
|
||||
|
||||
post.Updated = time.Now().UTC()
|
||||
|
|
|
@ -10,12 +10,14 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
gorilla "github.com/gorilla/sessions"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/markdown"
|
||||
"github.com/kirsle/blog/core/internal/middleware/auth"
|
||||
"github.com/kirsle/blog/core/internal/models/comments"
|
||||
"github.com/kirsle/blog/core/internal/models/users"
|
||||
"github.com/kirsle/blog/core/internal/render"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
)
|
||||
|
||||
// CommentRoutes attaches the comment routes to the app.
|
||||
|
@ -37,7 +39,7 @@ type CommentMeta struct {
|
|||
}
|
||||
|
||||
// RenderComments renders a comment form partial and returns the HTML.
|
||||
func (b *Blog) RenderComments(session *sessions.Session, csrfToken, url, subject string, ids ...string) template.HTML {
|
||||
func (b *Blog) RenderComments(session *gorilla.Session, csrfToken, url, subject string, ids ...string) template.HTML {
|
||||
id := strings.Join(ids, "-")
|
||||
|
||||
// Load their cached name and email if they posted a comment before.
|
||||
|
@ -142,7 +144,7 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
v := NewVars()
|
||||
currentUser, _ := b.CurrentUser(r)
|
||||
currentUser, _ := auth.CurrentUser(r)
|
||||
editToken := b.GetEditToken(w, r)
|
||||
submit := r.FormValue("submit")
|
||||
|
||||
|
@ -193,21 +195,21 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Cache their name and email in their session.
|
||||
session := b.Session(r)
|
||||
session := sessions.Get(r)
|
||||
session.Values["c.name"] = c.Name
|
||||
session.Values["c.email"] = c.Email
|
||||
session.Save(r, w)
|
||||
|
||||
// Previewing, deleting, or posting?
|
||||
switch submit {
|
||||
case "preview", "delete":
|
||||
case ActionPreview, ActionDelete:
|
||||
if !c.Editing && currentUser.IsAuthenticated {
|
||||
c.Name = currentUser.Name
|
||||
c.Email = currentUser.Email
|
||||
c.LoadAvatar()
|
||||
}
|
||||
c.HTML = template.HTML(markdown.RenderMarkdown(c.Body))
|
||||
case "post":
|
||||
case ActionPost:
|
||||
if err := c.Validate(); err != nil {
|
||||
v.Error = err
|
||||
} else {
|
||||
|
@ -251,7 +253,7 @@ func (b *Blog) CommentHandler(w http.ResponseWriter, r *http.Request) {
|
|||
v.Data["Thread"] = t
|
||||
v.Data["Comment"] = c
|
||||
v.Data["Editing"] = c.Editing
|
||||
v.Data["Deleting"] = submit == "delete"
|
||||
v.Data["Deleting"] = submit == ActionDelete
|
||||
|
||||
b.RenderTemplate(w, r, "comments/index.gohtml", v)
|
||||
}
|
||||
|
@ -295,7 +297,7 @@ func (b *Blog) QuickDeleteHandler(w http.ResponseWriter, r *http.Request) {
|
|||
thread := r.URL.Query().Get("t")
|
||||
token := r.URL.Query().Get("d")
|
||||
if thread == "" || token == "" {
|
||||
b.BadRequest(w, r)
|
||||
b.BadRequest(w, r, "Bad Request")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -315,7 +317,7 @@ func (b *Blog) QuickDeleteHandler(w http.ResponseWriter, r *http.Request) {
|
|||
// GetEditToken gets or generates an edit token from the user's session, which
|
||||
// allows a user to edit their comment for a short while after they post it.
|
||||
func (b *Blog) GetEditToken(w http.ResponseWriter, r *http.Request) string {
|
||||
session := b.Session(r)
|
||||
session := sessions.Get(r)
|
||||
if token, ok := session.Values["c.token"].(string); ok && len(token) > 0 {
|
||||
return token
|
||||
}
|
||||
|
|
|
@ -19,3 +19,11 @@ const (
|
|||
MARKDOWN ContentType = "markdown"
|
||||
HTML ContentType = "html"
|
||||
)
|
||||
|
||||
// Common form actions.
|
||||
const (
|
||||
ActionSave = "save"
|
||||
ActionDelete = "delete"
|
||||
ActionPreview = "preview"
|
||||
ActionPost = "post"
|
||||
)
|
||||
|
|
13
core/core.go
13
core/core.go
|
@ -7,14 +7,16 @@ import (
|
|||
"path/filepath"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/markdown"
|
||||
"github.com/kirsle/blog/core/internal/middleware"
|
||||
"github.com/kirsle/blog/core/internal/middleware/auth"
|
||||
"github.com/kirsle/blog/core/internal/models/comments"
|
||||
"github.com/kirsle/blog/core/internal/models/posts"
|
||||
"github.com/kirsle/blog/core/internal/models/settings"
|
||||
"github.com/kirsle/blog/core/internal/models/users"
|
||||
"github.com/kirsle/blog/core/internal/render"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
"github.com/kirsle/blog/jsondb"
|
||||
"github.com/kirsle/blog/jsondb/caches"
|
||||
"github.com/kirsle/blog/jsondb/caches/null"
|
||||
|
@ -38,7 +40,6 @@ type Blog struct {
|
|||
// Web app objects.
|
||||
n *negroni.Negroni // Negroni middleware manager
|
||||
r *mux.Router // Router
|
||||
store sessions.Store
|
||||
}
|
||||
|
||||
// New initializes the Blog application.
|
||||
|
@ -73,7 +74,7 @@ func (b *Blog) Configure() {
|
|||
render.DocumentRoot = &b.DocumentRoot
|
||||
|
||||
// Initialize the session cookie store.
|
||||
b.store = sessions.NewCookieStore([]byte(config.Security.SecretKey))
|
||||
sessions.SetSecretKey([]byte(config.Security.SecretKey))
|
||||
users.HashCost = config.Security.HashCost
|
||||
|
||||
// Initialize the rest of the models.
|
||||
|
@ -120,9 +121,9 @@ func (b *Blog) SetupHTTP() {
|
|||
n := negroni.New(
|
||||
negroni.NewRecovery(),
|
||||
negroni.NewLogger(),
|
||||
negroni.HandlerFunc(b.SessionLoader),
|
||||
negroni.HandlerFunc(b.CSRFMiddleware),
|
||||
negroni.HandlerFunc(b.AuthMiddleware),
|
||||
negroni.HandlerFunc(sessions.Middleware),
|
||||
negroni.HandlerFunc(middleware.CSRF(b.Forbidden)),
|
||||
negroni.HandlerFunc(auth.Middleware),
|
||||
)
|
||||
n.UseHandler(r)
|
||||
|
||||
|
|
|
@ -3,12 +3,12 @@ package core
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/kirsle/blog/core/internal/forms"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/models/settings"
|
||||
"github.com/kirsle/blog/core/internal/models/users"
|
||||
"github.com/kirsle/blog/core/internal/render"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
)
|
||||
|
||||
// SetupHandler is the initial blog setup route.
|
||||
|
@ -41,7 +41,7 @@ func (b *Blog) SetupHandler(w http.ResponseWriter, r *http.Request) {
|
|||
s.Save()
|
||||
|
||||
// Re-initialize the cookie store with the new secret key.
|
||||
b.store = sessions.NewCookieStore([]byte(s.Security.SecretKey))
|
||||
sessions.SetSecretKey([]byte(s.Security.SecretKey))
|
||||
|
||||
log.Info("Creating admin account %s", form.Username)
|
||||
user := &users.User{
|
||||
|
|
67
core/internal/middleware/auth/auth.go
Normal file
67
core/internal/middleware/auth/auth.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/models/users"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
"github.com/kirsle/blog/core/internal/types"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// CurrentUser returns the current user's object.
|
||||
func CurrentUser(r *http.Request) (*users.User, error) {
|
||||
session := sessions.Get(r)
|
||||
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
|
||||
id := session.Values["user-id"].(int)
|
||||
u, err := users.LoadReadonly(id)
|
||||
u.IsAuthenticated = true
|
||||
return u, err
|
||||
}
|
||||
|
||||
return &users.User{
|
||||
Admin: false,
|
||||
}, errors.New("not authenticated")
|
||||
}
|
||||
|
||||
// LoggedIn returns whether the current user is logged in to an account.
|
||||
func LoggedIn(r *http.Request) bool {
|
||||
session := sessions.Get(r)
|
||||
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LoginRequired is a middleware that requires a logged-in user.
|
||||
func LoginRequired(onError http.HandlerFunc) negroni.HandlerFunc {
|
||||
middleware := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
ctx := r.Context()
|
||||
if user, ok := ctx.Value(types.UserKey).(*users.User); ok {
|
||||
if user.ID > 0 {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Redirect away!")
|
||||
onError(w, r)
|
||||
}
|
||||
|
||||
return middleware
|
||||
}
|
||||
|
||||
// Middleware loads the user's authentication state from their session cookie.
|
||||
func Middleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
u, err := CurrentUser(r)
|
||||
if err != nil {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), types.UserKey, u)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
55
core/internal/middleware/csrf.go
Normal file
55
core/internal/middleware/csrf.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
gorilla "github.com/gorilla/sessions"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// CSRF is a middleware generator that enforces CSRF tokens on all POST requests.
|
||||
func CSRF(onError func(http.ResponseWriter, *http.Request, string)) negroni.HandlerFunc {
|
||||
middleware := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if r.Method == "POST" {
|
||||
session := sessions.Get(r)
|
||||
token := GenerateCSRFToken(w, r, session)
|
||||
if token != r.FormValue("_csrf") {
|
||||
log.Error("CSRF Mismatch: expected %s, got %s", r.FormValue("_csrf"), token)
|
||||
onError(w, r, "Failed to validate CSRF token. Please try your request again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
return middleware
|
||||
}
|
||||
|
||||
// ExampleCSRF shows how to use the CSRF handler.
|
||||
func ExampleCSRF() {
|
||||
// Your error handling for CSRF failures.
|
||||
onError := func(w http.ResponseWriter, r *http.Request, message string) {
|
||||
w.Write([]byte("CSRF Error: " + message))
|
||||
}
|
||||
|
||||
// Attach the middleware.
|
||||
_ = negroni.New(
|
||||
negroni.NewRecovery(),
|
||||
negroni.NewLogger(),
|
||||
negroni.HandlerFunc(CSRF(onError)),
|
||||
)
|
||||
}
|
||||
|
||||
// GenerateCSRFToken generates a CSRF token for the user and puts it in their session.
|
||||
func GenerateCSRFToken(w http.ResponseWriter, r *http.Request, session *gorilla.Session) string {
|
||||
token, ok := session.Values["csrf"].(string)
|
||||
if !ok {
|
||||
token := uuid.New()
|
||||
session.Values["csrf"] = token.String()
|
||||
session.Save(r, w)
|
||||
}
|
||||
return token
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
package responses
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/render"
|
||||
)
|
||||
|
||||
// Redirect sends an HTTP redirect response.
|
||||
func Redirect(w http.ResponseWriter, location string) {
|
||||
w.Header().Set("Location", location)
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
// NotFound sends a 404 response.
|
||||
func NotFound(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
if len(message) == 0 {
|
||||
message = []string{"The page you were looking for was not found."}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
err := render.RenderTemplate(w, r, ".errors/404", &render.Vars{
|
||||
Message: message[0],
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
w.Write([]byte("Unrecoverable template error for NotFound()"))
|
||||
}
|
||||
}
|
||||
|
||||
// Forbidden sends an HTTP 403 Forbidden response.
|
||||
func Forbidden(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
err := render.RenderTemplate(w, r, ".errors/403", &render.Vars{
|
||||
Message: message[0],
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
w.Write([]byte("Unrecoverable template error for Forbidden()"))
|
||||
}
|
||||
}
|
||||
|
||||
// Error sends an HTTP 500 Internal Server Error response.
|
||||
func Error(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
err := render.RenderTemplate(w, r, ".errors/500", &render.Vars{
|
||||
Message: message[0],
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
w.Write([]byte("Unrecoverable template error for Error()"))
|
||||
}
|
||||
}
|
||||
|
||||
// BadRequest sends an HTTP 400 Bad Request.
|
||||
func BadRequest(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
err := render.RenderTemplate(w, r, ".errors/400", &render.Vars{
|
||||
Message: message[0],
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
w.Write([]byte("Unrecoverable template error for BadRequest()"))
|
||||
}
|
||||
}
|
56
core/internal/sessions/sessions.go
Normal file
56
core/internal/sessions/sessions.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/types"
|
||||
)
|
||||
|
||||
// Store holds your cookie store information.
|
||||
var Store sessions.Store
|
||||
|
||||
// SetSecretKey initializes a session cookie store with the secret key.
|
||||
func SetSecretKey(keyPairs ...[]byte) {
|
||||
Store = sessions.NewCookieStore(keyPairs...)
|
||||
}
|
||||
|
||||
// Middleware gets the Gorilla session store and makes it available on the
|
||||
// Request context.
|
||||
//
|
||||
// Middleware is the first custom middleware applied, so it takes the current
|
||||
// datetime to make available later in the request and stores it on the request
|
||||
// context.
|
||||
func Middleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
// Store the current datetime on the request context.
|
||||
ctx := context.WithValue(r.Context(), types.StartTimeKey, time.Now())
|
||||
|
||||
// Get the Gorilla session and make it available in the request context.
|
||||
session, _ := Store.Get(r, "session")
|
||||
ctx = context.WithValue(ctx, types.SessionKey, session)
|
||||
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
// Get returns the current request's session.
|
||||
func Get(r *http.Request) *sessions.Session {
|
||||
if r == nil {
|
||||
panic("Session(*http.Request) with a nil argument!?")
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
if session, ok := ctx.Value(types.SessionKey).(*sessions.Session); ok {
|
||||
return session
|
||||
}
|
||||
|
||||
// If the session wasn't on the request, it means I broke something.
|
||||
log.Error(
|
||||
"Session(): didn't find session in request context! Getting it " +
|
||||
"from the session store instead.",
|
||||
)
|
||||
session, _ := Store.Get(r, "session")
|
||||
return session
|
||||
}
|
11
core/internal/types/context.go
Normal file
11
core/internal/types/context.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package types
|
||||
|
||||
// Key is an integer enum for context.Context keys.
|
||||
type Key int
|
||||
|
||||
// Key definitions.
|
||||
const (
|
||||
SessionKey Key = iota // The request's cookie session object.
|
||||
UserKey // The request's user data for logged-in users.
|
||||
StartTimeKey // HTTP request start time.
|
||||
)
|
|
@ -1,133 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/models/users"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
sessionKey key = iota
|
||||
userKey
|
||||
requestTimeKey
|
||||
)
|
||||
|
||||
// SessionLoader gets the Gorilla session store and makes it available on the
|
||||
// Request context.
|
||||
//
|
||||
// SessionLoader is the first custom middleware applied, so it takes the current
|
||||
// datetime to make available later in the request and stores it on the request
|
||||
// context.
|
||||
func (b *Blog) SessionLoader(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
// Store the current datetime on the request context.
|
||||
ctx := context.WithValue(r.Context(), requestTimeKey, time.Now())
|
||||
|
||||
// Get the Gorilla session and make it available in the request context.
|
||||
session, _ := b.store.Get(r, "session")
|
||||
ctx = context.WithValue(ctx, sessionKey, session)
|
||||
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
// Session returns the current request's session.
|
||||
func (b *Blog) Session(r *http.Request) *sessions.Session {
|
||||
if r == nil {
|
||||
panic("Session(*http.Request) with a nil argument!?")
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
if session, ok := ctx.Value(sessionKey).(*sessions.Session); ok {
|
||||
return session
|
||||
}
|
||||
|
||||
log.Error(
|
||||
"Session(): didn't find session in request context! Getting it " +
|
||||
"from the session store instead.",
|
||||
)
|
||||
session, _ := b.store.Get(r, "session")
|
||||
return session
|
||||
}
|
||||
|
||||
// CSRFMiddleware enforces CSRF tokens on all POST requests.
|
||||
func (b *Blog) CSRFMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if r.Method == "POST" {
|
||||
session := b.Session(r)
|
||||
token := b.GenerateCSRFToken(w, r, session)
|
||||
if token != r.FormValue("_csrf") {
|
||||
log.Error("CSRF Mismatch: expected %s, got %s", r.FormValue("_csrf"), token)
|
||||
b.Forbidden(w, r, "Failed to validate CSRF token. Please try your request again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
// GenerateCSRFToken generates a CSRF token for the user and puts it in their session.
|
||||
func (b *Blog) GenerateCSRFToken(w http.ResponseWriter, r *http.Request, session *sessions.Session) string {
|
||||
token, ok := session.Values["csrf"].(string)
|
||||
if !ok {
|
||||
token := uuid.New()
|
||||
session.Values["csrf"] = token.String()
|
||||
session.Save(r, w)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// CurrentUser returns the current user's object.
|
||||
func (b *Blog) CurrentUser(r *http.Request) (*users.User, error) {
|
||||
session := b.Session(r)
|
||||
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
|
||||
id := session.Values["user-id"].(int)
|
||||
u, err := users.LoadReadonly(id)
|
||||
u.IsAuthenticated = true
|
||||
return u, err
|
||||
}
|
||||
|
||||
return &users.User{
|
||||
Admin: false,
|
||||
}, errors.New("not authenticated")
|
||||
}
|
||||
|
||||
// LoggedIn returns whether the current user is logged in to an account.
|
||||
func (b *Blog) LoggedIn(r *http.Request) bool {
|
||||
session := b.Session(r)
|
||||
if loggedIn, ok := session.Values["logged-in"].(bool); ok && loggedIn {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AuthMiddleware loads the user's authentication state.
|
||||
func (b *Blog) AuthMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
u, err := b.CurrentUser(r)
|
||||
if err != nil {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userKey, u)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
// LoginRequired is a middleware that requires a logged-in user.
|
||||
func (b *Blog) LoginRequired(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
ctx := r.Context()
|
||||
if user, ok := ctx.Value(userKey).(*users.User); ok {
|
||||
if user.ID > 0 {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Redirect away!")
|
||||
b.Redirect(w, "/login?next="+r.URL.Path)
|
||||
}
|
|
@ -23,7 +23,7 @@ func (b *Blog) PageHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Restrict special paths.
|
||||
if strings.HasPrefix(strings.ToLower(path), "/.") {
|
||||
b.Forbidden(w, r)
|
||||
b.Forbidden(w, r, "Forbidden")
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -6,11 +6,12 @@ import (
|
|||
|
||||
"github.com/kirsle/blog/core/internal/log"
|
||||
"github.com/kirsle/blog/core/internal/render"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
)
|
||||
|
||||
// Flash adds a flash message to the user's session.
|
||||
func (b *Blog) Flash(w http.ResponseWriter, r *http.Request, message string, args ...interface{}) {
|
||||
session := b.Session(r)
|
||||
session := sessions.Get(r)
|
||||
session.AddFlash(fmt.Sprintf(message, args...))
|
||||
session.Save(r, w)
|
||||
}
|
||||
|
@ -34,14 +35,14 @@ func (b *Blog) Redirect(w http.ResponseWriter, location string) {
|
|||
}
|
||||
|
||||
// NotFound sends a 404 response.
|
||||
func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
if len(message) == 0 {
|
||||
message = []string{"The page you were looking for was not found."}
|
||||
func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message string) {
|
||||
if message == "" {
|
||||
message = "The page you were looking for was not found."
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
err := b.RenderTemplate(w, r, ".errors/404", render.Vars{
|
||||
Message: message[0],
|
||||
Message: message,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
|
@ -50,10 +51,10 @@ func (b *Blog) NotFound(w http.ResponseWriter, r *http.Request, message ...strin
|
|||
}
|
||||
|
||||
// Forbidden sends an HTTP 403 Forbidden response.
|
||||
func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message string) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
err := b.RenderTemplate(w, r, ".errors/403", render.Vars{
|
||||
Message: message[0],
|
||||
Message: message,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
|
@ -62,10 +63,10 @@ func (b *Blog) Forbidden(w http.ResponseWriter, r *http.Request, message ...stri
|
|||
}
|
||||
|
||||
// Error sends an HTTP 500 Internal Server Error response.
|
||||
func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message string) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
err := b.RenderTemplate(w, r, ".errors/500", render.Vars{
|
||||
Message: message[0],
|
||||
Message: message,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
|
@ -74,10 +75,10 @@ func (b *Blog) Error(w http.ResponseWriter, r *http.Request, message ...string)
|
|||
}
|
||||
|
||||
// BadRequest sends an HTTP 400 Bad Request.
|
||||
func (b *Blog) BadRequest(w http.ResponseWriter, r *http.Request, message ...string) {
|
||||
func (b *Blog) BadRequest(w http.ResponseWriter, r *http.Request, message string) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
err := b.RenderTemplate(w, r, ".errors/400", render.Vars{
|
||||
Message: message[0],
|
||||
Message: message,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
|
|
|
@ -8,9 +8,13 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/kirsle/blog/core/internal/forms"
|
||||
"github.com/kirsle/blog/core/internal/middleware"
|
||||
"github.com/kirsle/blog/core/internal/middleware/auth"
|
||||
"github.com/kirsle/blog/core/internal/models/settings"
|
||||
"github.com/kirsle/blog/core/internal/models/users"
|
||||
"github.com/kirsle/blog/core/internal/render"
|
||||
"github.com/kirsle/blog/core/internal/sessions"
|
||||
"github.com/kirsle/blog/core/internal/types"
|
||||
)
|
||||
|
||||
// Vars is an interface to implement by the templates to pass their own custom
|
||||
|
@ -66,11 +70,11 @@ func (b *Blog) LoadDefaults(v render.Vars, r *http.Request) render.Vars {
|
|||
v.SetupNeeded = true
|
||||
}
|
||||
v.Request = r
|
||||
v.RequestTime = r.Context().Value(requestTimeKey).(time.Time)
|
||||
v.RequestTime = r.Context().Value(types.StartTimeKey).(time.Time)
|
||||
v.Title = s.Site.Title
|
||||
v.Path = r.URL.Path
|
||||
|
||||
user, err := b.CurrentUser(r)
|
||||
user, err := auth.CurrentUser(r)
|
||||
v.CurrentUser = user
|
||||
v.LoggedIn = err == nil
|
||||
|
||||
|
@ -109,7 +113,7 @@ func (b *Blog) RenderTemplate(w http.ResponseWriter, r *http.Request, path strin
|
|||
vars = b.LoadDefaults(vars, r)
|
||||
|
||||
// Add any flashed messages from the endpoint controllers.
|
||||
session := b.Session(r)
|
||||
session := sessions.Get(r)
|
||||
if flashes := session.Flashes(); len(flashes) > 0 {
|
||||
for _, flash := range flashes {
|
||||
_ = flash
|
||||
|
@ -119,7 +123,7 @@ func (b *Blog) RenderTemplate(w http.ResponseWriter, r *http.Request, path strin
|
|||
}
|
||||
|
||||
vars.RequestDuration = time.Now().Sub(vars.RequestTime)
|
||||
vars.CSRF = b.GenerateCSRFToken(w, r, session)
|
||||
vars.CSRF = middleware.GenerateCSRFToken(w, r, session)
|
||||
vars.Editable = !strings.HasPrefix(path, "admin/")
|
||||
|
||||
return render.Template(w, path, render.Config{
|
||||
|
@ -140,8 +144,8 @@ func (b *Blog) TemplateFuncs(w http.ResponseWriter, r *http.Request, inject map[
|
|||
return template.HTML("[RenderComments Error: need both http.ResponseWriter and http.Request]")
|
||||
}
|
||||
|
||||
session := b.Session(r)
|
||||
csrf := b.GenerateCSRFToken(w, r, session)
|
||||
session := sessions.Get(r)
|
||||
csrf := middleware.GenerateCSRFToken(w, r, session)
|
||||
return b.RenderComments(session, csrf, r.URL.Path, subject, ids...)
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user