package ratelimit

import (
	"context"
	"sync"
	"time"
)

type RateLimiter struct {
	mu          sync.Mutex
	tokens      float64
	maxTokens   float64
	refillRate  float64
	lastRefill  time.Time
}

func NewRateLimiter(rate float64, burst int) *RateLimiter {
	return &RateLimiter{
		tokens:     float64(burst),
		maxTokens:  float64(burst),
		refillRate: rate,
		lastRefill: time.Now(),
	}
}

func (rl *RateLimiter) refill() {
	now := time.Now()
	elapsed := now.Sub(rl.lastRefill).Seconds()
	tokensToAdd := elapsed * rl.refillRate
	
	rl.mu.Lock()
	defer rl.mu.Unlock()
	
	rl.tokens = min(rl.tokens+tokensToAdd, rl.maxTokens)
	rl.lastRefill = now
}

func (rl *RateLimiter) Allow() bool {
	rl.refill()
	
	rl.mu.Lock()
	defer rl.mu.Unlock()
	
	if rl.tokens >= 1.0 {
		rl.tokens -= 1.0
		return true
	}
	return false
}

func (rl *RateLimiter) Wait(ctx context.Context) error {
	for !rl.Allow() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(time.Second / time.Duration(rl.refillRate)):
		}
	}
	return nil
}

func min(a, b float64) float64 {
	if a < b {
		return a
	}
	return b
}