package ratelimit import ( "errors" "fmt" "time" "github.com/kirsle/blog/jsondb/caches/redis" ) // Limiter implements a Redis-backed rate limit for logins or otherwise. type Limiter struct { Namespace string // kind of rate limiter ("login") ID interface{} // unique ID of the resource being pinged (str or ints) Limit int // how many pings within the window period Window int // the window period/expiration of Redis key CooldownAt int // how many pings before the cooldown is enforced Cooldown int // time to wait between fails } // The active Redis cache given by the webapp. var Cache *redis.Redis // Redis object behind the rate limiter. type Data struct { Pings int NotBefore time.Time } // Ping the rate limiter. func (l *Limiter) Ping() error { if Cache == nil { return errors.New("redis not ready") } var ( key = l.Key() now = time.Now() ) // Get stored data from Redis if any. var data Data Cache.GetJSON(key, &data) // Are we cooling down? if now.Before(data.NotBefore) { return fmt.Errorf( "You are doing that too often.", ) } // Increment the ping count. data.Pings++ // Have we hit the wall? if data.Pings >= l.Limit { return fmt.Errorf( "You have hit the rate limit; come back later.", ) } // Are we throttled? if l.CooldownAt > 0 && data.Pings > l.CooldownAt { data.NotBefore = now.Add(time.Duration(l.Cooldown) * time.Second) if err := Cache.SetJSON(key, data, l.Window); err != nil { return fmt.Errorf("Couldn't set Redis key for rate limiter: %s", err) } return fmt.Errorf( "Please wait %ds before trying again. You have %d more attempt(s) remaining before you will be locked "+ "out for %ds.", l.Cooldown, l.Limit-data.Pings, l.Window, ) } // Save their ping count to Redis. if err := Cache.SetJSON(key, data, l.Window); err != nil { return fmt.Errorf("Couldn't set Redis key for rate limiter: %s", err) } return nil } // Clear the rate limiter, cleaning up the Redis key (e.g., after successful login). 
func (l *Limiter) Clear() {
	// Mirror Ping's readiness check: when Redis was never configured there is
	// nothing to clear, and calling through a nil Cache may panic.
	if Cache == nil {
		return
	}
	// Best-effort: any result of Delete is discarded, as before.
	Cache.Delete(l.Key())
}

// Key formats the Redis key, "rlimit/<Namespace>/<ID>".
//
// Integer IDs of every width and signedness are rendered in decimal. Any
// other ID is formatted with the %s verb, so strings, []byte, fmt.Stringer and
// error values produce their natural text form.
func (l *Limiter) Key() string {
	var id string
	switch t := l.ID.(type) {
	case string:
		id = t
	case int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64:
		// Every integer kind must be listed: any kind left out falls through
		// to %s and yields a malformed key such as "%!s(uint=5)".
		id = fmt.Sprintf("%d", t)
	default:
		id = fmt.Sprintf("%s", t)
	}
	return fmt.Sprintf("rlimit/%s/%s", l.Namespace, id)
}