140 lines
3.4 KiB
Go
140 lines
3.4 KiB
Go
|
package ratelimit
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"git.kirsle.net/apps/gophertype/pkg/cache"
|
||
|
"git.kirsle.net/apps/gophertype/pkg/constants"
|
||
|
)
|
||
|
|
||
|
// Limiter is a Redis-backed rate limiter for logins or other actions.
// Its state for a given Namespace and ID lives under the key returned by Key.
type Limiter struct {
	// Namespace is the kind of rate limiter, e.g. "login".
	Namespace string

	// ID uniquely identifies the resource being pinged (a string or an integer).
	ID interface{}

	// Limit is how many pings are allowed within the Window.
	Limit int

	// Window is the window period, which is also the expiration of the Redis key.
	Window time.Duration

	// CooldownAt is how many pings are allowed before the Cooldown is enforced.
	// A value of zero or less disables the cooldown.
	CooldownAt int

	// Cooldown is the time to wait between fails once CooldownAt is exceeded.
	Cooldown time.Duration
}
|
||
|
|
||
|
// Data is the Redis object behind the rate limiter: the per-key state that
// Ping stores with cache.SetJSON and reads back with cache.GetJSON.
type Data struct {
	// Pings counts the pings recorded within the current window.
	Pings int
	// NotBefore is when the current cooldown ends; the zero time means no cooldown.
	NotBefore time.Time
}
|
||
|
|
||
|
// Ping the rate limiter.
|
||
|
func (l *Limiter) Ping() error {
|
||
|
var (
|
||
|
key = l.Key()
|
||
|
now = time.Now()
|
||
|
)
|
||
|
|
||
|
// Get stored data from Redis if any.
|
||
|
var data Data
|
||
|
cache.GetJSON(key, &data)
|
||
|
|
||
|
// Are we cooling down?
|
||
|
if now.Before(data.NotBefore) {
|
||
|
return fmt.Errorf(
|
||
|
"You are doing that too often. Please wait %s before trying again",
|
||
|
FormatDurationCoarse(data.NotBefore.Sub(now)),
|
||
|
)
|
||
|
}
|
||
|
|
||
|
// Increment the ping count.
|
||
|
data.Pings++
|
||
|
|
||
|
// Have we hit the wall?
|
||
|
if data.Pings >= l.Limit {
|
||
|
return fmt.Errorf(
|
||
|
"You have hit the rate limit; please wait the full %s before trying again",
|
||
|
FormatDurationCoarse(l.Window),
|
||
|
)
|
||
|
}
|
||
|
|
||
|
// Are we throttled?
|
||
|
if l.CooldownAt > 0 && data.Pings > l.CooldownAt {
|
||
|
data.NotBefore = now.Add(l.Cooldown)
|
||
|
if err := cache.SetJSON(key, data, l.Window); err != nil {
|
||
|
return fmt.Errorf("Couldn't set Redis key for rate limiter: %s", err)
|
||
|
}
|
||
|
return fmt.Errorf(
|
||
|
"Please wait %s before trying again. You have %d more attempt(s) remaining before you will be locked "+
|
||
|
"out for %s",
|
||
|
FormatDurationCoarse(l.Cooldown),
|
||
|
l.Limit-data.Pings,
|
||
|
FormatDurationCoarse(l.Window),
|
||
|
)
|
||
|
}
|
||
|
|
||
|
// Save their ping count to Redis.
|
||
|
if err := cache.SetJSON(key, data, l.Window); err != nil {
|
||
|
return fmt.Errorf("Couldn't set Redis key for rate limiter: %s", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Clear the rate limiter, cleaning up the Redis key (e.g., after successful login).
|
||
|
func (l *Limiter) Clear() error {
|
||
|
return cache.Delete(l.Key())
|
||
|
}
|
||
|
|
||
|
// Key formats the Redis key.
|
||
|
func (l *Limiter) Key() string {
|
||
|
var str string
|
||
|
switch t := l.ID.(type) {
|
||
|
case int:
|
||
|
str = fmt.Sprintf("%d", t)
|
||
|
case uint64:
|
||
|
str = fmt.Sprintf("%d", t)
|
||
|
case int64:
|
||
|
str = fmt.Sprintf("%d", t)
|
||
|
case uint32:
|
||
|
str = fmt.Sprintf("%d", t)
|
||
|
case int32:
|
||
|
str = fmt.Sprintf("%d", t)
|
||
|
default:
|
||
|
str = fmt.Sprintf("%s", t)
|
||
|
}
|
||
|
return fmt.Sprintf(constants.RateLimitRedisKey, l.Namespace, str)
|
||
|
}
|
||
|
|
||
|
// FormatDurationCoarse returns a pretty printed duration with coarse
// granularity: only the single largest unit is shown, e.g. "45 seconds",
// "1 minute", "3 hours", "12 days", "5 months" or "2 years".
//
// Counts are truncated toward zero (90s is "1 minute"), and the unit is
// singular when the count is exactly 1. A month is approximated as 30 days
// and a year as 365 days. Durations under a minute, including negative
// ones, are reported in seconds.
func FormatDurationCoarse(duration time.Duration) string {
	// result formats v into text, dropping the plural "s" when v is 1.
	var result = func(text string, v int64) string {
		if v == 1 {
			text = strings.TrimSuffix(text, "s")
		}
		return fmt.Sprintf(text, v)
	}

	if duration.Seconds() < 60.0 {
		return result("%d seconds", int64(duration.Seconds()))
	}

	if duration.Minutes() < 60.0 {
		return result("%d minutes", int64(duration.Minutes()))
	}

	if duration.Hours() < 24.0 {
		return result("%d hours", int64(duration.Hours()))
	}

	days := int64(duration.Hours() / 24)
	switch {
	case days < 30:
		return result("%d days", days)
	case days < 365:
		// Bound months by days < 365 rather than months < 12: 360-364 days
		// give 12 months, which would otherwise fall through to "0 years".
		return result("%d months", days/30)
	default:
		return result("%d years", days/365)
	}
}
|