diff --git a/rate.go b/rate.go index 1cfc36e..81fe8ef 100644 --- a/rate.go +++ b/rate.go @@ -9,7 +9,8 @@ import ( "github.com/redis/go-redis/v9" ) -const redisPrefix = "rate:" +// DefaultRedisPrefix is the default prefix for redis keys +const DefaultRedisPrefix = "rate:" type rediser interface { Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd @@ -76,14 +77,29 @@ func PerHour(rate int) Limit { // Limiter controls how frequently events are allowed to happen. type Limiter struct { - rdb rediser + rdb rediser + keyPrefix string +} + +type Options func(l *Limiter) + +func WithKeyPrefix(prefix string) Options { + return func(l *Limiter) { + l.keyPrefix = prefix + } } // NewLimiter returns a new Limiter. -func NewLimiter(rdb rediser) *Limiter { - return &Limiter{ - rdb: rdb, +func NewLimiter(rdb rediser, opts ...Options) *Limiter { + l := &Limiter{ + rdb: rdb, + keyPrefix: DefaultRedisPrefix, + } + for _, opt := range opts { + opt(l) } + + return l } // Allow is a shortcut for AllowN(ctx, key, limit, 1). @@ -99,7 +115,7 @@ func (l Limiter) AllowN( n int, ) (*Result, error) { values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n} - v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result() + v, err := allowN.Run(ctx, l.rdb, []string{l.keyPrefix + key}, values...).Result() if err != nil { return nil, err } @@ -135,7 +151,7 @@ func (l Limiter) AllowAtMost( n int, ) (*Result, error) { values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n} - v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result() + v, err := allowAtMost.Run(ctx, l.rdb, []string{l.keyPrefix + key}, values...).Result() if err != nil { return nil, err } @@ -164,7 +180,7 @@ func (l Limiter) AllowAtMost( // Reset gets a key and reset all limitations and previous usages func (l *Limiter) Reset(ctx context.Context, key string) error { - return l.rdb.Del(ctx, redisPrefix+key).Err() + return l.rdb.Del(ctx, l.keyPrefix+key).Err() } func dur(f float64) time.Duration { diff --git a/rate_test.go b/rate_test.go index 19ede04..303e3f5 100644 --- a/rate_test.go +++ b/rate_test.go @@ -11,61 +11,75 @@ import ( "github.com/go-redis/redis_rate/v10" ) -func rateLimiter() *redis_rate.Limiter { +func rateLimiterWithKeyPrefix(keyPrefix string) *redis_rate.Limiter { ring := redis.NewRing(&redis.RingOptions{ Addrs: map[string]string{"server0": ":6379"}, }) if err := ring.FlushDB(context.TODO()).Err(); err != nil { panic(err) } - return redis_rate.NewLimiter(ring) + return redis_rate.NewLimiter(ring, redis_rate.WithKeyPrefix(keyPrefix)) +} + +func rateLimiter() *redis_rate.Limiter { + return rateLimiterWithKeyPrefix(redis_rate.DefaultRedisPrefix) } func TestAllow(t *testing.T) { ctx := context.Background() - l := rateLimiter() - - limit := redis_rate.PerSecond(10) - require.Equal(t, limit.String(), "10 req/s (burst 10)") - require.False(t, limit.IsZero()) - - res, err := l.Allow(ctx, "test_id", limit) - require.Nil(t, err) - require.Equal(t, res.Allowed, 1) - require.Equal(t, res.Remaining, 9) - require.Equal(t, res.RetryAfter, time.Duration(-1)) - require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond)) - - err = l.Reset(ctx, "test_id") - require.Nil(t, err) - res, err = l.Allow(ctx, "test_id", limit) - require.Nil(t, err) - require.Equal(t, res.Allowed, 1) - require.Equal(t, res.Remaining, 9) - require.Equal(t, res.RetryAfter, time.Duration(-1)) - require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond)) - - res, err = l.AllowN(ctx, "test_id", limit, 2) - require.Nil(t, err) - require.Equal(t, res.Allowed, 2) - require.Equal(t, res.Remaining, 7) - require.Equal(t, res.RetryAfter, time.Duration(-1)) - require.InDelta(t, res.ResetAfter, 300*time.Millisecond, float64(10*time.Millisecond)) - - res, err = l.AllowN(ctx, "test_id", limit, 7) - require.Nil(t, err) - require.Equal(t, res.Allowed, 7) - require.Equal(t, res.Remaining, 0) - require.Equal(t, res.RetryAfter, time.Duration(-1)) - require.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond)) - - res, err = l.AllowN(ctx, "test_id", limit, 1000) - require.Nil(t, err) - require.Equal(t, res.Allowed, 0) - require.Equal(t, res.Remaining, 0) - require.InDelta(t, res.RetryAfter, 99*time.Second, float64(time.Second)) - require.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond)) + tcs := []struct { + keyPrefix string + }{ + {keyPrefix: ""}, + {keyPrefix: redis_rate.DefaultRedisPrefix}, + {keyPrefix: "test:"}, + } + for _, tc := range tcs { + t.Run(tc.keyPrefix, func(t *testing.T) { + l := rateLimiterWithKeyPrefix(tc.keyPrefix) + limit := redis_rate.PerSecond(10) + require.Equal(t, limit.String(), "10 req/s (burst 10)") + require.False(t, limit.IsZero()) + + res, err := l.Allow(ctx, "test_id", limit) + require.Nil(t, err) + require.Equal(t, res.Allowed, 1) + require.Equal(t, res.Remaining, 9) + require.Equal(t, res.RetryAfter, time.Duration(-1)) + require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond)) + + err = l.Reset(ctx, "test_id") + require.Nil(t, err) + res, err = l.Allow(ctx, "test_id", limit) + require.Nil(t, err) + require.Equal(t, res.Allowed, 1) + require.Equal(t, res.Remaining, 9) + require.Equal(t, res.RetryAfter, time.Duration(-1)) + require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond)) + + res, err = l.AllowN(ctx, "test_id", limit, 2) + require.Nil(t, err) + require.Equal(t, res.Allowed, 2) + require.Equal(t, res.Remaining, 7) + require.Equal(t, res.RetryAfter, time.Duration(-1)) + require.InDelta(t, res.ResetAfter, 300*time.Millisecond, float64(10*time.Millisecond)) + + res, err = l.AllowN(ctx, "test_id", limit, 7) + require.Nil(t, err) + require.Equal(t, res.Allowed, 7) + require.Equal(t, res.Remaining, 0) + require.Equal(t, res.RetryAfter, time.Duration(-1)) + require.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond)) + + res, err = l.AllowN(ctx, "test_id", limit, 1000) + require.Nil(t, err) + require.Equal(t, res.Allowed, 0) + require.Equal(t, res.Remaining, 0) + require.InDelta(t, res.RetryAfter, 99*time.Second, float64(time.Second)) + require.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond)) + }) + } } func TestAllowN_IncrementZero(t *testing.T) {