"""Redis-based rate limiting middleware for production scalability."""
import time
import uuid
from typing import Optional

import redis
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from shared.config import settings
from shared.logging_config import setup_logging

logger = setup_logging("redis_rate_limit")


class RedisRateLimitMiddleware(BaseHTTPMiddleware):
    """Redis-backed rate limiting using a sliding-window log.

    Each admitted request is stored as a member of a per-client sorted set,
    scored by its arrival time. On every request, entries older than
    ``window_size`` seconds are trimmed. The set's size is then the number of
    requests admitted in the trailing window.

    The limiter fails open: if Redis is not configured, or a Redis call
    raises ``redis.RedisError``, requests pass through unthrottled.

    NOTE(review): the client is the synchronous ``redis`` client, so every
    Redis round-trip below blocks the event loop inside the async
    ``dispatch``. Consider ``redis.asyncio`` (or a thread pool) under load.
    """

    def __init__(self, app, requests_per_minute: int = 60, redis_url: Optional[str] = None):
        """Configure the limiter.

        Args:
            app: The ASGI application being wrapped.
            requests_per_minute: Maximum requests admitted per client within
                the 60-second sliding window.
            redis_url: Redis connection URL. Defaults to
                ``settings.redis_url``. When neither is set, rate limiting is
                disabled.
        """
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.window_size = 60  # 1 minute in seconds
        self.redis_client: Optional[redis.Redis] = None
        self.redis_enabled = False

        try:
            if redis_url is None:
                redis_url = settings.redis_url

            if redis_url:
                # redis-py connects lazily, so connection failures surface on
                # first use and are handled (fail-open) in dispatch().
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                self.redis_enabled = True
                logger.info(f"Redis rate limiting enabled: {requests_per_minute} req/min")
            else:
                logger.warning("Redis URL not configured, rate limiting disabled")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.redis_enabled = False

    async def dispatch(self, request: Request, call_next):
        """Apply the per-client limit, then forward the request downstream.

        If the client has already made ``requests_per_minute`` requests in
        the trailing window, returns a 429 JSON response with a
        ``Retry-After`` header. Otherwise returns the downstream response,
        annotated with ``X-RateLimit-*`` headers.

        Requests pass through unthrottled when:
        - Redis is disabled;
        - the path is ``/health``;
        - a Redis call raises ``redis.RedisError``.
        """
        if not self.redis_enabled:
            # Fall back to no rate limiting if Redis unavailable
            return await call_next(request)

        # Skip rate limiting for health checks
        if request.url.path == "/health":
            return await call_next(request)

        client_key = self._get_client_key(request)
        now = time.time()
        # str(now) alone can collide when concurrent requests (e.g. from other
        # workers) share a timestamp. zadd would then update the existing
        # member instead of adding one, and the request would go uncounted.
        member = f"{now}:{uuid.uuid4().hex}"

        # Keep this try scoped to Redis calls only. If it also wrapped
        # call_next, a RedisError raised by the downstream app would cause
        # the request to be executed a second time.
        try:
            request_count = self._record_request(client_key, member, now)
            retry_after = None
            if request_count >= self.requests_per_minute:
                # A rejected attempt must not consume quota. Otherwise a client
                # that keeps retrying stays blocked past the advertised
                # Retry-After.
                self.redis_client.zrem(client_key, member)
                retry_after = self._compute_retry_after(client_key, now)
        except redis.RedisError as e:
            logger.error(f"Redis error during rate limiting: {e}")
            # Fall back to allowing the request if Redis fails
            return await call_next(request)

        if retry_after is not None:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": "Rate limit exceeded. Please try again later.",
                    "retry_after": retry_after
                },
                headers={"Retry-After": str(retry_after)}
            )

        response = await call_next(request)

        # request_count excludes the current request, hence the extra -1.
        remaining = max(0, self.requests_per_minute - request_count - 1)
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(int(now + self.window_size))

        return response

    def _get_client_key(self, request: Request) -> str:
        """Return the Redis key that identifies the caller.

        The key is per user when a user ID is available, otherwise per
        client IP.
        """
        user_id = self._get_user_id_from_request(request)
        if user_id:
            return f"rate_limit:user:{user_id}"
        # request.client may be None (e.g. when the ASGI scope carries no peer
        # address). All such callers share a single "unknown" bucket rather
        # than crashing with AttributeError.
        client_ip = request.client.host if request.client is not None else "unknown"
        return f"rate_limit:{client_ip}"

    def _record_request(self, client_key: str, member: str, now: float) -> int:
        """Trim the window, count it, and log this request in one pipeline.

        The pipeline removes entries at or before ``now - window_size``,
        counts what remains, adds ``member``, and refreshes the key TTL so
        idle keys expire.

        Returns:
            The number of requests in the window *before* ``member`` was added.
        """
        pipe = self.redis_client.pipeline()
        pipe.zremrangebyscore(client_key, 0, now - self.window_size)
        pipe.zcard(client_key)
        pipe.zadd(client_key, {member: now})
        pipe.expire(client_key, self.window_size + 10)
        results = pipe.execute()
        return results[1]  # result of ZCARD, issued before ZADD

    def _compute_retry_after(self, client_key: str, now: float) -> int:
        """Return the whole seconds until the oldest logged request leaves the window."""
        oldest_in_window = self.redis_client.zrange(client_key, 0, 0, withscores=True)
        if oldest_in_window:
            oldest_timestamp = oldest_in_window[0][1]
            return int(self.window_size - (now - oldest_timestamp)) + 1
        return self.window_size

    def _get_user_id_from_request(self, request: Request) -> Optional[str]:
        """Extract a user ID from the request's credentials.

        This is currently a stub that always returns ``None``, so limits are
        applied per client IP.

        TODO: decode the ``Authorization: Bearer`` JWT to key limits per user.
        Only trust the token's subject after verifying its signature.
        Otherwise a client could forge user IDs to dodge the limit.
        """
        return None


def get_redis_client() -> Optional[redis.Redis]:
    """Build a Redis client from the configured ``settings.redis_url``.

    Returns:
        A new ``redis.Redis`` client created with ``decode_responses=True``.
        Returns ``None`` when no URL is configured or when creating the
        client raises (the error is logged).
    """
    client: Optional[redis.Redis] = None
    try:
        url = settings.redis_url
        if url:
            client = redis.from_url(url, decode_responses=True)
    except Exception as exc:
        logger.error(f"Failed to create Redis client: {exc}")
    return client
