blog/core/internal/middleware/csrf.go

56 lines
1.5 KiB
Go

package middleware
import (
	"crypto/subtle"
	"net/http"

	"github.com/google/uuid"
	gorilla "github.com/gorilla/sessions"
	"github.com/kirsle/blog/core/internal/log"
	"github.com/kirsle/blog/core/internal/sessions"
	"github.com/urfave/negroni"
)
// CSRF is a middleware generator that enforces CSRF tokens on all POST requests.
//
// For every POST request it loads the user's session, obtains (or creates)
// the session's CSRF token via GenerateCSRFToken, and compares it with the
// "_csrf" form value. On mismatch the rest of the chain is skipped and
// onError is called with a user-facing message instead. Requests with any
// other method are passed to next unchanged.
func CSRF(onError func(http.ResponseWriter, *http.Request, string)) negroni.HandlerFunc {
	middleware := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		if r.Method == http.MethodPost {
			session := sessions.Get(r)
			expected := GenerateCSRFToken(w, r, session)
			submitted := r.FormValue("_csrf")

			// An empty expected token must never validate; otherwise a POST that
			// omits "_csrf" would "match" a session that holds no token.
			// ConstantTimeCompare avoids leaking, via timing, how much of the
			// submitted token was correct.
			if expected == "" || subtle.ConstantTimeCompare([]byte(expected), []byte(submitted)) != 1 {
				log.Error("CSRF Mismatch: expected %s, got %s", expected, submitted)
				onError(w, r, "Failed to validate CSRF token. Please try your request again.")
				return
			}
		}
		next(w, r)
	}
	return middleware
}
// ExampleCSRF demonstrates wiring the CSRF middleware into a negroni stack.
func ExampleCSRF() {
	// reportFailure is what the middleware invokes when a POST request fails
	// CSRF validation; here it simply echoes the message into the response.
	reportFailure := func(w http.ResponseWriter, r *http.Request, message string) {
		body := []byte("CSRF Error: " + message)
		w.Write(body)
	}

	// CSRF already yields a negroni.HandlerFunc, so it can be listed directly
	// after the recovery and logging middleware.
	stack := negroni.New(
		negroni.NewRecovery(),
		negroni.NewLogger(),
		CSRF(reportFailure),
	)
	_ = stack
}
// GenerateCSRFToken generates a CSRF token for the user and puts it in their session.
//
// If session.Values["csrf"] already holds a non-empty string, that token is
// returned unchanged. Otherwise a new UUID-based token is generated, stored
// under "csrf", and the session is saved so the token is available on the
// next request. A save failure is logged and the new token is still returned.
func GenerateCSRFToken(w http.ResponseWriter, r *http.Request, session *gorilla.Session) string {
	token, ok := session.Values["csrf"].(string)
	if ok && token != "" {
		return token
	}

	// Assign to the outer token (plain =, not :=) so the caller receives the
	// freshly generated value instead of the empty string.
	token = uuid.New().String()
	session.Values["csrf"] = token
	if err := session.Save(r, w); err != nil {
		log.Error("GenerateCSRFToken: failed to save session: %s", err)
	}
	return token
}