gophertype/pkg/session/sessions.go

89 lines
2.4 KiB
Go
Raw Normal View History

2019-11-26 03:55:28 +00:00
package session
import (
"context"
"fmt"
"net/http"
"time"
"git.kirsle.net/apps/gophertype/pkg/console"
2019-11-26 03:55:28 +00:00
"github.com/gorilla/sessions"
)
// Store is the Gorilla session store used by every helper in this package.
// Its zero value is nil; SetSecretKey installs a cookie-backed store.
var (
	Store sessions.Store
)
// SetSecretKey replaces Store with a cookie-backed Gorilla session store.
// The key pairs are handed unchanged to sessions.NewCookieStore, which
// interprets them (see its documentation for how keys are paired).
func SetSecretKey(keyPairs ...[]byte) {
	cookieStore := sessions.NewCookieStore(keyPairs...)
	Store = cookieStore
}
// Middleware gets the Gorilla session store and makes it available on the
// Request context.
//
// Middleware is the first custom middleware applied, so it takes the current
// datetime to make available later in the request and stores it on the request
// context.
func Middleware(next http.Handler) http.Handler {
middleware := func(w http.ResponseWriter, r *http.Request) {
// Set the HTML content-type header by default until overridden by a handler.
w.Header().Set("Content-Type", "text/html; charset=utf-8")
2019-11-26 03:55:28 +00:00
// Store the current datetime on the request context.
ctx := context.WithValue(r.Context(), StartTimeKey, time.Now())
// Get the Gorilla session and make it available in the request context.
session, _ := Store.Get(r, "session")
ctx = context.WithValue(ctx, SessionKey, session)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(middleware)
}
// Get returns the current request's session.
func Get(r *http.Request) *sessions.Session {
if r == nil {
panic("Session(*http.Request) with a nil argument!?")
}
ctx := r.Context()
if session, ok := ctx.Value(SessionKey).(*sessions.Session); ok {
2019-11-26 03:55:28 +00:00
return session
}
// If the session wasn't on the request, it means I broke something.
console.Warn(
"Session(): didn't find session in request context! Getting it " +
2019-11-26 03:55:28 +00:00
"from the session store instead.",
)
session, _ := Store.Get(r, "session")
return session
}
// Flash adds a flashed message to the session for the next template rendering.
//
// msg is a fmt.Sprintf format string, formatted with args. The message is
// appended to the []string stored under the session's "flashes" key, and the
// session is saved immediately. Because saving writes a cookie header, this
// should run before the handler starts writing the response body.
func Flash(w http.ResponseWriter, r *http.Request, msg string, args ...interface{}) {
	sess := Get(r)

	var flashes []string
	if v, ok := sess.Values["flashes"].([]string); ok {
		flashes = v
	}
	flashes = append(flashes, fmt.Sprintf(msg, args...))
	sess.Values["flashes"] = flashes

	// If the save fails (e.g. the encoded cookie grows past the store's size
	// limit), the flash never reaches the user. Report it rather than
	// dropping it silently.
	if err := sess.Save(r, w); err != nil {
		console.Warn("session.Flash(): couldn't save session: " + err.Error())
	}
}
// GetFlashes returns all the flashes from the session and clears the queue.
//
// If the session has no "flashes" entry, it returns an empty non-nil slice
// and does not touch the session. Otherwise the queue is reset to an empty
// slice and the session is saved, even when the queue was already empty.
func GetFlashes(w http.ResponseWriter, r *http.Request) []string {
	sess := Get(r)

	flashes, ok := sess.Values["flashes"].([]string)
	if !ok {
		return []string{}
	}

	sess.Values["flashes"] = []string{}
	if err := sess.Save(r, w); err != nil {
		// The messages are still returned. Because the cleared queue was not
		// persisted, they may be shown again on a later request.
		console.Warn("session.GetFlashes(): couldn't save session: " + err.Error())
	}
	return flashes
}