// gophertype/pkg/middleware/csrf.go

package middleware
import (
	"context"
	"crypto/subtle"
	"net/http"

	"git.kirsle.net/apps/gophertype/pkg/constants"
	"git.kirsle.net/apps/gophertype/pkg/responses"
	"git.kirsle.net/apps/gophertype/pkg/session"
	uuid "github.com/satori/go.uuid"
)
// CSRF prevents Cross-Site Request Forgery.
// All "POST" requests are required to have an "_csrf" variable passed in which
// matches the "csrf_token" HTTP cookie with their request.
func CSRF(next http.Handler) http.Handler {
middleware := func(w http.ResponseWriter, r *http.Request) {
// All requests: verify they have a CSRF cookie, create one if not.
var token string
cookie, err := r.Cookie(constants.CSRFCookieName)
if err == nil {
token = cookie.Value
}
// Generate a token cookie if not found.
if len(token) < 8 || err != nil {
token = uuid.NewV4().String()
cookie = &http.Cookie{
Name: constants.CSRFCookieName,
Value: token,
}
http.SetCookie(w, cookie)
}
2019-11-27 00:54:02 +00:00
// Add the CSRF token to the request context. This makes it immediately
// available on FIRST page load, when the cookie hasn't been sent back
// from the browser yet.
ctx := context.WithValue(r.Context(), session.CSRFKey, token)
// POST requests: verify token from form parameter.
if r.Method == http.MethodPost {
compare := r.FormValue(constants.CSRFFormName)
if compare != token {
responses.Panic(w, http.StatusForbidden, "CSRF token failure.")
return
}
}
2019-11-27 00:54:02 +00:00
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(middleware)
}