2022-08-10 05:10:47 +00:00
|
|
|
// Package session manages server-side user sessions (login state and flash messages), stored in Redis and identified by a browser cookie.
|
|
|
|
package session
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"git.kirsle.net/apps/gosocial/pkg/config"
|
|
|
|
"git.kirsle.net/apps/gosocial/pkg/log"
|
|
|
|
"git.kirsle.net/apps/gosocial/pkg/models"
|
|
|
|
"git.kirsle.net/apps/gosocial/pkg/redis"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Session is the per-visitor session state. It is kept server side in Redis;
// the browser only holds the session's UUID in a cookie.
type Session struct {
	// UUID identifies the session. It is used as the cookie value and to build
	// the Redis key, so it is excluded from the stored JSON.
	UUID string `json:"-"`

	// LoggedIn reports whether a user has signed in on this session.
	LoggedIn bool `json:"loggedIn"`

	// UserID is the ID of the signed-in user (0 when nobody is logged in).
	UserID uint64 `json:"userId,omitempty"`

	// Flashes holds informational messages queued for the next page view.
	Flashes []string `json:"flashes,omitempty"`

	// Errors holds error messages queued for the next page view.
	Errors []string `json:"errors,omitempty"`

	// LastSeen is refreshed every time the session is saved.
	LastSeen time.Time `json:"lastSeen"`
}
|
|
|
|
|
|
|
|
// ContextKey is the key under which the request's *Session is stored in its
// context (see Get).
const ContextKey = "session"

// CSRFKey is the key name for the CSRF token. NOTE(review): not referenced in
// this file; presumably used by CSRF middleware or templates — confirm.
const CSRFKey = "csrf"
|
|
|
|
|
|
|
|
// New creates a blank session object.
|
|
|
|
func New() *Session {
|
|
|
|
return &Session{
|
|
|
|
UUID: uuid.New().String(),
|
|
|
|
Flashes: []string{},
|
|
|
|
Errors: []string{},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Load the session from the browser session_id token and Redis or creates a new session.
|
|
|
|
func LoadOrNew(r *http.Request) *Session {
|
|
|
|
var sess = New()
|
|
|
|
|
|
|
|
// Read the session cookie value.
|
|
|
|
cookie, err := r.Cookie(config.SessionCookieName)
|
|
|
|
if err != nil {
|
|
|
|
log.Debug("session.LoadOrNew: cookie error, new sess: %s", err)
|
|
|
|
return sess
|
|
|
|
}
|
|
|
|
|
|
|
|
// Look up this UUID in Redis.
|
|
|
|
sess.UUID = cookie.Value
|
|
|
|
key := fmt.Sprintf(config.SessionRedisKeyFormat, sess.UUID)
|
|
|
|
|
|
|
|
err = redis.Get(key, sess)
|
|
|
|
log.Error("LoadOrNew: raw from Redis: %+v", sess)
|
|
|
|
if err != nil {
|
|
|
|
log.Error("session.LoadOrNew: didn't find %s in Redis: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return sess
|
|
|
|
}
|
|
|
|
|
|
|
|
// Save the session and send a cookie header.
|
|
|
|
func (s *Session) Save(w http.ResponseWriter) {
|
|
|
|
// Roll a UUID session_id value.
|
|
|
|
if s.UUID == "" {
|
|
|
|
s.UUID = uuid.New().String()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ensure it is a valid UUID.
|
|
|
|
if _, err := uuid.Parse(s.UUID); err != nil {
|
|
|
|
log.Error("Session.Save: got an invalid UUID session_id: %s", err)
|
|
|
|
s.UUID = uuid.New().String()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ping last seen.
|
|
|
|
s.LastSeen = time.Now()
|
|
|
|
|
|
|
|
// Save their session object in Redis.
|
|
|
|
key := fmt.Sprintf(config.SessionRedisKeyFormat, s.UUID)
|
|
|
|
if err := redis.Set(key, s, config.SessionCookieMaxAge*time.Second); err != nil {
|
|
|
|
log.Error("Session.Save: couldn't write to Redis: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
cookie := &http.Cookie{
|
|
|
|
Name: config.SessionCookieName,
|
|
|
|
Value: s.UUID,
|
|
|
|
MaxAge: config.SessionCookieMaxAge,
|
|
|
|
HttpOnly: true,
|
|
|
|
}
|
|
|
|
http.SetCookie(w, cookie)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get the session from the current HTTP request context.
|
|
|
|
func Get(r *http.Request) *Session {
|
|
|
|
if r == nil {
|
|
|
|
panic("session.Get: http.Request is required")
|
|
|
|
}
|
|
|
|
|
|
|
|
ctx := r.Context()
|
|
|
|
if sess, ok := ctx.Value(ContextKey).(*Session); ok {
|
|
|
|
return sess
|
|
|
|
}
|
|
|
|
|
|
|
|
// If the session isn't on the request, it means I broke something.
|
|
|
|
log.Error("session.Get(): didn't find session in request context!")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// ReadFlashes returns and clears the Flashes and Errors for this session.
|
|
|
|
func (s *Session) ReadFlashes(w http.ResponseWriter) (flashes, errors []string) {
|
|
|
|
flashes = s.Flashes
|
|
|
|
errors = s.Errors
|
|
|
|
s.Flashes = []string{}
|
|
|
|
s.Errors = []string{}
|
|
|
|
if len(flashes)+len(errors) > 0 {
|
|
|
|
s.Save(w)
|
|
|
|
}
|
|
|
|
return flashes, errors
|
|
|
|
}
|
|
|
|
|
|
|
|
// Flash adds a transient message to the user's session to show on next page load.
|
|
|
|
func Flash(w http.ResponseWriter, r *http.Request, msg string, args ...interface{}) {
|
|
|
|
sess := Get(r)
|
|
|
|
sess.Flashes = append(sess.Flashes, fmt.Sprintf(msg, args...))
|
|
|
|
sess.Save(w)
|
|
|
|
}
|
|
|
|
|
|
|
|
// FlashError adds a transient error message to the session.
|
|
|
|
func FlashError(w http.ResponseWriter, r *http.Request, msg string, args ...interface{}) {
|
|
|
|
sess := Get(r)
|
|
|
|
sess.Errors = append(sess.Flashes, fmt.Sprintf(msg, args...))
|
|
|
|
sess.Save(w)
|
|
|
|
}
|
|
|
|
|
|
|
|
// LoginUser marks a session as logged in to an account.
|
|
|
|
func LoginUser(w http.ResponseWriter, r *http.Request, u *models.User) error {
|
|
|
|
if u == nil || u.ID == 0 {
|
|
|
|
return errors.New("not a valid user account")
|
|
|
|
}
|
|
|
|
|
|
|
|
sess := Get(r)
|
|
|
|
sess.LoggedIn = true
|
|
|
|
sess.UserID = u.ID
|
|
|
|
sess.Save(w)
|
|
|
|
|
2022-08-11 03:59:59 +00:00
|
|
|
// Ping the user's last login time.
|
|
|
|
u.LastLoginAt = time.Now()
|
|
|
|
return u.Save()
|
2022-08-10 05:10:47 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// LogoutUser signs a user out.
|
|
|
|
func LogoutUser(w http.ResponseWriter, r *http.Request) {
|
|
|
|
sess := Get(r)
|
|
|
|
sess.LoggedIn = false
|
|
|
|
sess.UserID = 0
|
|
|
|
sess.Save(w)
|
|
|
|
}
|