package middleware

import (
	"context"
	"crypto/subtle"
	"net/http"
	"time"

	"git.kirsle.net/apps/gophertype/pkg/constants"
	"git.kirsle.net/apps/gophertype/pkg/responses"
	"git.kirsle.net/apps/gophertype/pkg/session"
	uuid "github.com/satori/go.uuid"
)

// CSRF prevents Cross-Site Request Forgery.
//
// On every request it makes sure a CSRF token exists. The value of the
// constants.CSRFCookieName cookie is reused when present and at least 8 bytes
// long; otherwise a new random UUID is generated and sent back as a
// site-wide cookie that expires in 24 hours. The token is stored in the
// request context under session.CSRFKey, so downstream handlers can read it
// even on the first page load, before the browser has echoed the cookie back.
//
// All "POST" requests are required to have an "_csrf" variable (the form field
// named by constants.CSRFFormName) whose value matches the token. A mismatch
// is answered with 403 Forbidden via responses.Panic and the wrapped handler
// is not called. Other HTTP methods are not checked.
func CSRF(next http.Handler) http.Handler {
	middleware := func(w http.ResponseWriter, r *http.Request) {
		// All requests: verify they have a CSRF cookie, create one if not.
		var token string
		cookie, err := r.Cookie(constants.CSRFCookieName)
		if err == nil {
			token = cookie.Value
		}

		// Generate a token cookie if not found (or if it is implausibly short).
		if err != nil || len(token) < 8 {
			token = uuid.NewV4().String()
			http.SetCookie(w, &http.Cookie{
				Name:  constants.CSRFCookieName,
				Value: token,
				// Without an explicit Path the browser scopes the cookie to the
				// directory of the current URL, so a token issued on one page
				// would not be sent with POSTs to other paths and the check
				// below would spuriously fail.
				Path:    "/",
				Expires: time.Now().Add(24 * time.Hour),
			})
		}

		// Add the CSRF token to the request context. This makes it immediately
		// available on FIRST page load, when the cookie hasn't been sent back
		// from the browser yet.
		ctx := context.WithValue(r.Context(), session.CSRFKey, token)

		// POST requests: verify token from form parameter.
		if r.Method == http.MethodPost {
			compare := r.FormValue(constants.CSRFFormName)
			// Constant-time comparison so response timing does not reveal how
			// much of a guessed token was correct.
			if subtle.ConstantTimeCompare([]byte(compare), []byte(token)) != 1 {
				responses.Panic(w, http.StatusForbidden, "CSRF token failure.")
				return
			}
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	}

	return http.HandlerFunc(middleware)
}