package ratelimit
import (
	"context"
	"sync"
	"time"
)
// RateLimiter holds the state of a token bucket: the current token balance,
// the bucket capacity, the refill rate, and the time of the last refill.
// The methods lock mu around reads and writes of that state.
//
// A RateLimiter contains a sync.Mutex by value and therefore must not be
// copied after first use; pass it around as *RateLimiter.
type RateLimiter struct {
	mu sync.Mutex

	// tokens is the current balance, maxTokens is the cap applied when
	// refilling, and refillRate is the number of tokens credited per second
	// of elapsed time.
	tokens, maxTokens, refillRate float64

	// lastRefill is the instant up to which tokens have already been credited.
	lastRefill time.Time
}
// NewRateLimiter returns a limiter whose bucket starts full.
//
// rate is the number of tokens credited per second of elapsed time. burst is
// used both as the initial token balance and as the bucket capacity. The
// refill clock starts at the moment of the call.
func NewRateLimiter(rate float64, burst int) *RateLimiter {
	capacity := float64(burst)

	rl := new(RateLimiter)
	rl.tokens = capacity
	rl.maxTokens = capacity
	rl.refillRate = rate
	rl.lastRefill = time.Now()
	return rl
}
// refill credits the tokens earned since lastRefill (elapsed seconds times
// refillRate), caps the balance at maxTokens, and advances lastRefill to the
// current time.
//
// It acquires rl.mu itself, so callers must not hold the lock when calling it.
func (rl *RateLimiter) refill() {
	// Take the lock before touching lastRefill. Reading it unlocked is a data
	// race with concurrent refills, and sampling the clock outside the lock
	// lets a goroutine holding an older timestamp move lastRefill backwards,
	// after which the same interval would be credited a second time.
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	elapsed := now.Sub(rl.lastRefill).Seconds()
	rl.tokens = min(rl.tokens+elapsed*rl.refillRate, rl.maxTokens)
	rl.lastRefill = now
}
// Allow reports whether a request may proceed right now. It refills the
// bucket and then, if at least one whole token is available, takes one and
// returns true. Otherwise it returns false and leaves the balance unchanged.
// It does not wait for tokens to become available.
func (rl *RateLimiter) Allow() bool {
	// refill acquires rl.mu on its own, so it has to run before we lock.
	rl.refill()

	rl.mu.Lock()
	defer rl.mu.Unlock()

	granted := rl.tokens >= 1.0
	if granted {
		rl.tokens -= 1.0
	}
	return granted
}
// Wait blocks until a token has been taken or ctx is done.
//
// It returns nil as soon as Allow succeeds, and ctx.Err() if ctx is cancelled
// or its deadline passes first. Allow is tried before ctx is consulted, so an
// already-cancelled context still yields nil when a token is available.
//
// Between attempts it sleeps for the time one token takes to accrue
// (1/refillRate seconds). If refillRate is not positive the bucket can never
// refill, so Wait blocks until ctx is done instead of polling.
func (rl *RateLimiter) Wait(ctx context.Context) error {
	// maxInterval is the largest value a time.Duration can represent.
	const maxInterval = time.Duration(1<<63 - 1)

	// Compute the per-token interval in floating point. Converting the rate to
	// an integer Duration first would truncate it: rates below 1 become 0 and
	// the division panics, and fractional rates give the wrong interval.
	// NOTE(review): refillRate is read without rl.mu; in the code visible here
	// it is only written by NewRateLimiter — confirm nothing mutates it later.
	var interval time.Duration // 0 means "no refill will ever happen"
	if rl.refillRate > 0 {
		ns := float64(time.Second) / rl.refillRate
		switch {
		case ns >= float64(maxInterval):
			// Out of range for Duration; the float->int conversion would be
			// implementation-defined, so saturate instead.
			interval = maxInterval
		case ns < 1:
			// Sub-nanosecond intervals would truncate to 0; keep it positive.
			interval = time.Nanosecond
		default:
			interval = time.Duration(ns)
		}
	}

	for !rl.Allow() {
		// A nil channel never becomes ready, so when interval is 0 only
		// ctx can end the wait.
		var retry <-chan time.Time
		if interval > 0 {
			retry = time.After(interval)
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-retry:
		}
	}
	return nil
}
// min returns a when a < b and b otherwise. In particular b is returned when
// the two are equal or when the comparison is false because either is NaN.
func min(a, b float64) float64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}