package ratelimit

import (
	"fmt"
	"strings"
	"time"

	"git.kirsle.net/apps/gophertype/pkg/cache"
	"git.kirsle.net/apps/gophertype/pkg/constants"
)

// Limiter implements a Redis-backed rate limit for logins or otherwise.
//
// Each call to Ping increments a counter saved via cache.SetJSON under Key(),
// with Window as the key's expiration. Once the counter exceeds CooldownAt,
// every further counted ping starts a Cooldown delay; once it reaches Limit,
// pings are refused until the key expires or Clear is called.
type Limiter struct {
	Namespace  string        // kind of rate limiter ("login")
	ID         interface{}   // unique ID of the resource being pinged (string or any integer type)
	Limit      int           // ping count at which the lockout is enforced
	Window     time.Duration // the window period/expiration of the Redis key
	CooldownAt int           // pings beyond this count trigger a cooldown (0 disables cooldowns)
	Cooldown   time.Duration // time to wait between pings once cooling down
}

// Data is the object stored in Redis behind the rate limiter.
type Data struct {
	Pings     int       // pings counted within the current window
	NotBefore time.Time // pings before this instant are refused (cooldown)
}

// Ping records one hit against the rate limiter.
//
// It returns a non-nil error, with a message suitable for showing the user,
// when:
//   - a cooldown is in effect (the ping is refused and not counted),
//   - the ping count reaches Limit (the lockout), or
//   - the ping count exceeds CooldownAt (a new cooldown is started and saved).
//
// It also returns a wrapped error if the updated data cannot be saved to Redis.
func (l *Limiter) Ping() error {
	var (
		key = l.Key()
		now = time.Now()
	)

	// Get stored data from Redis if any. The outcome of GetJSON is not
	// checked: when nothing is stored, data presumably keeps its zero value
	// (no pings, no cooldown).
	// NOTE(review): a Redis read failure is handled the same way, i.e. the
	// limiter fails open — confirm that is intended.
	// NOTE(review): this read-modify-write is not atomic; concurrent pings for
	// the same key can lose increments.
	var data Data
	cache.GetJSON(key, &data)

	// Are we cooling down? The ping is refused without being counted.
	if now.Before(data.NotBefore) {
		return fmt.Errorf(
			"You are doing that too often. Please wait %s before trying again",
			FormatDurationCoarse(data.NotBefore.Sub(now)),
		)
	}

	// Increment the ping count.
	data.Pings++

	// Have we hit the wall? The incremented count is not saved here, so the
	// stored count does not advance and every further ping lands here again
	// until the key expires.
	// NOTE(review): since nothing is saved, the key's TTL is not refreshed and
	// the real remaining lockout may be shorter than the full Window quoted.
	if data.Pings >= l.Limit {
		return fmt.Errorf(
			"You have hit the rate limit; please wait the full %s before trying again",
			FormatDurationCoarse(l.Window),
		)
	}

	// Are we throttled? Past CooldownAt, each counted ping starts a cooldown.
	if l.CooldownAt > 0 && data.Pings > l.CooldownAt {
		data.NotBefore = now.Add(l.Cooldown)
		if err := cache.SetJSON(key, data, l.Window); err != nil {
			return fmt.Errorf("Couldn't set Redis key for rate limiter: %w", err)
		}
		return fmt.Errorf(
			"Please wait %s before trying again. You have %d more attempt(s) remaining before you will be locked "+
				"out for %s",
			FormatDurationCoarse(l.Cooldown),
			l.Limit-data.Pings,
			FormatDurationCoarse(l.Window),
		)
	}

	// Save their ping count to Redis.
	if err := cache.SetJSON(key, data, l.Window); err != nil {
		return fmt.Errorf("Couldn't set Redis key for rate limiter: %w", err)
	}

	return nil
}

// Clear deletes the limiter's Redis key, discarding its ping count and any
// cooldown (e.g., after a successful login).
func (l *Limiter) Clear() error {
	return cache.Delete(l.Key())
}

// Key formats the Redis key from the Namespace and the ID.
//
// Every integer type is rendered in decimal and strings are used as-is; any
// other ID is formatted with %s (so fmt.Stringer and error values use their
// String/Error methods).
func (l *Limiter) Key() string {
	var str string
	switch t := l.ID.(type) {
	case string:
		str = t
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		// Previously only some integer kinds were listed; the rest (e.g. uint)
		// fell through to %s and produced keys like "%!s(uint=5)".
		str = fmt.Sprintf("%d", t)
	default:
		str = fmt.Sprintf("%s", t)
	}
	return fmt.Sprintf(constants.RateLimitRedisKey, l.Namespace, str)
}

// FormatDurationCoarse returns a pretty printed duration with coarse granularity.
//
// The duration is truncated to whole units of the largest unit that fits:
// seconds, minutes, hours, days, months (30 days) or years (365 days). The
// unit is singular when the count is exactly 1 (e.g. "1 minute"). Negative
// durations are reported in (negative) seconds.
func FormatDurationCoarse(duration time.Duration) string {
	var result = func(text string, v int64) string {
		if v == 1 {
			text = strings.TrimSuffix(text, "s")
		}
		return fmt.Sprintf(text, v)
	}

	if duration.Seconds() < 60.0 {
		return result("%d seconds", int64(duration.Seconds()))
	}

	if duration.Minutes() < 60.0 {
		return result("%d minutes", int64(duration.Minutes()))
	}

	if duration.Hours() < 24.0 {
		return result("%d hours", int64(duration.Hours()))
	}

	days := int64(duration.Hours() / 24)
	if days < 30 {
		return result("%d days", days)
	}

	// Compare against the year length rather than "months < 12": with 30-day
	// months and 365-day years, 360-364 days previously printed "0 years".
	if days < 365 {
		return result("%d months", days/30)
	}

	return result("%d years", days/365)
}