
"""Main API Gateway - Updated with orchestration integration."""
from fastapi import FastAPI, Depends, HTTPException, status, BackgroundTasks, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
import json
import redis
from datetime import datetime, timedelta
from typing import List, Optional
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from shared.database import get_db, init_db
from shared.models import User, APIKey, Project, Subscription, RefreshToken
from shared.schemas import (
    UserCreate, UserResponse, Token, ProjectCreate, ProjectResponse,
    MessageResponse, HealthResponse, ProjectUpdate, UserUpdate, AdminStats,
    SubscriptionResponse, CheckoutSessionCreate, CheckoutSessionResponse,
    SubscriptionCancel, EmailVerificationRequest, PasswordResetRequest,
    PasswordResetConfirm, RefreshTokenRequest, LoginRequest
)
from shared.auth import (
    verify_password, get_password_hash, create_access_token, generate_api_key,
    create_refresh_token, verify_refresh_token
)
from shared.dependencies import get_current_user, get_current_admin_user
from shared.config import settings, ensure_directories
from shared.logging_config import setup_logging
from shared.security import setup_security_middleware
from shared.validation import InputValidator
from shared.error_handlers import setup_error_handlers
from services.orchestration.workflow import get_workflow
from services.payments.stripe_service import StripeService
from services.email.email_service import EmailService, generate_token
import stripe

# Set up logging
logger = setup_logging("gateway")

# Create FastAPI app
app = FastAPI(
    title=settings.app_name,
    description="AI-Powered Content Creation Platform API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Register the shared error handlers from shared.error_handlers.
setup_error_handlers(app)

# Origins accepted by CORS. These are local development hosts only.
# TODO(review): production (settings.debug == False) currently uses this same
# localhost-only list; load the real production origins from configuration
# before deploying.
allowed_origins = [
    "http://localhost:3000",
    "http://localhost:5173",  # Vite dev server
    "http://localhost:8000",
    "http://localhost",
]

# Security middleware (CORS, rate limiting, security headers). The rate limit
# is more permissive in debug mode to ease local development.
setup_security_middleware(
    app,
    rate_limit=100 if settings.debug else 60,
    allowed_origins=allowed_origins
)


@app.on_event("startup")
async def startup_event():
    """Prepare the service before it accepts requests.

    Creates the required directories via ``ensure_directories()`` and then
    calls ``init_db()``, logging before and after.
    """
    app_name = settings.app_name
    logger.info(f"Starting {app_name}...")

    # Directories first, then the database.
    ensure_directories()
    init_db()

    logger.info("Application started successfully")


@app.on_event("shutdown")
async def shutdown_event():
    """Log that the application is stopping (no resources are released here)."""
    message = "Shutting down application..."
    logger.info(message)


# ===== Health Check =====
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report gateway liveness.

    Always returns a static ``healthy`` status with the current UTC time;
    no downstream dependency is probed.
    """
    payload = {
        "status": "healthy",
        "service": "gateway",
        "timestamp": datetime.utcnow(),
        "version": "1.0.0",
    }
    return HealthResponse(**payload)


# ===== Authentication Endpoints =====
@app.post("/auth/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
    """Register a new user account.

    Email, username and password are passed through ``InputValidator`` first.
    The new account is created active, on the ``free`` tier.

    Raises:
        HTTPException(400): the email or username is already registered. This
            also covers a concurrent request that inserts the same values
            between the existence check and the commit.
    """
    from sqlalchemy.exc import IntegrityError

    # Validate inputs
    email = InputValidator.validate_email(user_data.email)
    username = InputValidator.validate_username(user_data.username)
    password = InputValidator.validate_password(user_data.password)

    # Fast-path duplicate check (not race-free on its own; see the commit below).
    existing_user = db.query(User).filter(
        (User.email == email) | (User.username == username)
    ).first()

    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email or username already registered"
        )

    user = User(
        email=email,
        username=username,
        hashed_password=get_password_hash(password),
        is_active=True,
        tier="free"
    )

    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # Another request inserted the same email/username after the check
        # above. Without this, the client would get a 500.
        # NOTE(review): assumes unique constraints on User.email/User.username;
        # confirm in shared.models.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email or username already registered"
        )
    db.refresh(user)

    logger.info(f"New user registered: {user.email}")

    return user


@app.post("/auth/login", response_model=Token)
async def login(credentials: LoginRequest, db: Session = Depends(get_db)):
    """Exchange credentials for an access token and a refresh token.

    ``credentials.username`` is matched against either the username or the
    email column. Each successful login persists a new ``RefreshToken`` row
    that expires 30 days later.

    Raises:
        HTTPException(401): unknown user or wrong password (same message for both).
        HTTPException(403): the account is inactive.
    """
    identifier = credentials.username
    user = (
        db.query(User)
        .filter((User.username == identifier) | (User.email == identifier))
        .first()
    )

    # NOTE(review): when the user is unknown, verify_password is skipped, so the
    # response time may reveal whether an account exists.
    if not (user and verify_password(credentials.password, user.hashed_password)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user account")

    access_token = create_access_token(data={"sub": user.id})

    refresh_value = create_refresh_token()
    db.add(
        RefreshToken(
            user_id=user.id,
            token=refresh_value,
            expires_at=datetime.utcnow() + timedelta(days=30),
        )
    )
    db.commit()

    logger.info(f"User logged in: {user.email}")

    lifetime_seconds = settings.jwt_expiration_hours * 3600
    return Token(
        access_token=access_token,
        refresh_token=refresh_value,
        token_type="bearer",
        expires_in=lifetime_seconds,
    )


@app.get("/auth/me", response_model=UserResponse)
async def get_current_user_info(current_user: User = Depends(get_current_user)):
    """Return the authenticated caller's profile.

    Serialisation is handled by the ``UserResponse`` response model.
    """
    me = current_user
    return me


@app.patch("/auth/me", response_model=UserResponse)
async def update_current_user(
    user_update: UserUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Update the authenticated user's email, username and/or password.

    New values go through the same ``InputValidator`` checks as
    ``/auth/register``, so a profile update cannot store a value registration
    would reject. Falsy fields are left unchanged. All checks run before
    ``current_user`` is modified, so a rejected request leaves the object
    untouched.

    Raises:
        HTTPException(400): the email or username already belongs to another user.
        Whatever ``InputValidator`` raises for invalid values (same as registration).
    """
    changes = {}

    if user_update.email:
        email = InputValidator.validate_email(user_update.email)
        existing = db.query(User).filter(User.email == email).first()
        if existing and existing.id != current_user.id:
            raise HTTPException(status_code=400, detail="Email already taken")
        # NOTE(review): consider resetting email_verified when the address changes.
        changes["email"] = email

    if user_update.username:
        username = InputValidator.validate_username(user_update.username)
        existing = db.query(User).filter(User.username == username).first()
        if existing and existing.id != current_user.id:
            raise HTTPException(status_code=400, detail="Username already taken")
        changes["username"] = username

    if user_update.password:
        password = InputValidator.validate_password(user_update.password)
        changes["hashed_password"] = get_password_hash(password)
        # NOTE(review): existing refresh tokens remain valid after a password change.

    # Apply only after every check passed, so no half-updated state is
    # autoflushed by the uniqueness queries above.
    for field, value in changes.items():
        setattr(current_user, field, value)

    db.commit()
    db.refresh(current_user)
    return current_user


# ===== Admin Endpoints =====
@app.get("/admin/users", response_model=List[UserResponse])
async def list_users(
    skip: int = 0,
    limit: int = 50,
    current_user: User = Depends(get_current_admin_user),
    db: Session = Depends(get_db)
):
    """List users with offset pagination (admin only).

    Args:
        skip: Number of users to skip. Negative values are treated as 0.
        limit: Page size, clamped to 0..50. A negative limit used to bypass the
            cap, because SQLite treats a negative LIMIT as "no limit".
    """
    skip = max(skip, 0)
    limit = min(max(limit, 0), 50)

    # NOTE(review): there is no ORDER BY, so page contents are not guaranteed
    # stable between requests.
    return db.query(User).offset(skip).limit(limit).all()


@app.patch("/admin/users/{user_id}", response_model=UserResponse)
async def update_user_admin(
    user_id: str,
    user_update: UserUpdate,
    current_user: User = Depends(get_current_admin_user),
    db: Session = Depends(get_db)
):
    """Update another user's status, admin flag or tier (admin only).

    Only ``is_active``, ``is_admin`` and ``tier`` are applied, each only when
    provided (not None). Applied changes are logged with the acting admin's id
    as an audit trail.

    Raises:
        HTTPException(404): no user with ``user_id``.
    """
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # NOTE(review): nothing stops an admin from deactivating or demoting their
    # own account here.
    requested = (
        ("is_active", user_update.is_active),
        ("is_admin", user_update.is_admin),
        ("tier", user_update.tier),
    )
    changes = {field: value for field, value in requested if value is not None}
    for field, value in changes.items():
        setattr(user, field, value)

    db.commit()
    db.refresh(user)

    if changes:
        logger.info(f"Admin {current_user.id} updated user {user.id}: {changes}")

    return user


@app.get("/admin/stats", response_model=AdminStats)
async def get_admin_stats(
    current_user: User = Depends(get_current_admin_user),
    db: Session = Depends(get_db)
):
    """Return aggregate counts of users, projects and API keys (admin only)."""
    users = db.query(User)
    projects = db.query(Project)

    return AdminStats(
        total_users=users.count(),
        active_users=users.filter(User.is_active == True).count(),  # noqa: E712 (SQL expression)
        total_projects=projects.count(),
        completed_projects=projects.filter(Project.status == "completed").count(),
        failed_projects=projects.filter(Project.status == "failed").count(),
        total_api_keys=db.query(APIKey).count(),
    )


# ===== API Key Management =====
@app.post("/api-keys", response_model=dict)
async def create_api_key(
    name: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Issue a new API key for the caller.

    The raw key appears only in this response; ``GET /api-keys`` omits it.

    Args:
        name: Human-readable label for the key.
    """
    raw_key = generate_api_key()

    # NOTE(review): the raw key is passed to APIKey.key as-is; unless the model
    # hashes it, keys are stored in plaintext.
    record = APIKey(user_id=current_user.id, key=raw_key, name=name, is_active=True)
    db.add(record)
    db.commit()
    db.refresh(record)

    logger.info(f"API key created for user {current_user.email}: {name}")

    return dict(
        id=record.id,
        key=raw_key,
        name=name,
        created_at=record.created_at,
    )


@app.get("/api-keys", response_model=List[dict])
async def list_api_keys(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the caller's API keys as metadata only (the key values are never returned)."""

    def summarise(api_key):
        # Metadata only; `api_key.key` is intentionally omitted.
        return {
            "id": api_key.id,
            "name": api_key.name,
            "is_active": api_key.is_active,
            "created_at": api_key.created_at,
            "last_used_at": api_key.last_used_at,
        }

    owned = db.query(APIKey).filter(APIKey.user_id == current_user.id).all()
    return list(map(summarise, owned))


@app.delete("/api-keys/{key_id}", response_model=MessageResponse)
async def delete_api_key(
    key_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the caller's API keys.

    Keys that belong to other users are reported as not found.

    Raises:
        HTTPException(404): no such key owned by the caller.
    """
    owned_key = (
        db.query(APIKey)
        .filter(APIKey.id == key_id, APIKey.user_id == current_user.id)
        .first()
    )

    if owned_key:
        db.delete(owned_key)
        db.commit()
        logger.info(f"API key deleted: {key_id}")
        return MessageResponse(message="API key deleted successfully")

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="API key not found"
    )


# Background task for content generation
async def generate_content_task(project_id: str):
    """Background task that runs the orchestration workflow for one project.

    Loads the project in a short-lived session, which is always closed even if
    the query raises. It then passes the project's prompt and audio/video
    settings to the workflow returned by ``get_workflow()``. Any failure is
    logged with its traceback and not re-raised, because a background task has
    no caller to report to.

    Args:
        project_id: ID of the project to generate content for.
    """
    try:
        db = SessionLocal()
        try:
            project = db.query(Project).filter(Project.id == project_id).first()
        finally:
            # Release the session before the long-running workflow starts.
            db.close()

        if not project:
            logger.error(f"Project not found: {project_id}")
            return

        workflow = get_workflow()

        # NOTE(review): create_project also pushes this project onto the Redis
        # 'video_queue'; confirm the two paths don't process the same job twice.
        await workflow.create_content(
            project_id=project.id,
            prompt=project.prompt,
            audio_settings=project.audio_settings,
            video_settings=project.video_settings
        )

    except Exception as e:
        logger.error(f"Content generation failed for project {project_id}: {e}", exc_info=True)


# ===== Project Management =====
@app.post("/projects", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
async def create_project(
    project_data: ProjectCreate,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a content-generation project and dispatch it for processing.

    The project is stored with status ``pending``. It is then:

    1. Pushed as JSON onto the Redis list ``video_queue``. This is best-effort:
       a failure is logged and the request still succeeds.
    2. Scheduled as a FastAPI background task via ``generate_content_task``.

    Each request's Redis client is disconnected afterwards, and socket timeouts
    bound how long an unreachable Redis can block this handler.
    """
    # Validate inputs
    name = InputValidator.validate_project_name(project_data.name)
    prompt = InputValidator.validate_prompt(project_data.prompt)

    project = Project(
        user_id=current_user.id,
        name=name,
        prompt=prompt,
        status="pending",
        audio_settings=project_data.audio_settings.model_dump() if project_data.audio_settings else None,
        video_settings=project_data.video_settings.model_dump() if project_data.video_settings else None
    )

    db.add(project)
    db.commit()
    db.refresh(project)

    # TODO(review): Redis host/port are hard-coded; move them to settings.
    redis_client = None
    try:
        redis_client = redis.Redis(
            host='localhost',
            port=6379,
            db=0,
            socket_connect_timeout=5,
            socket_timeout=5,
        )
        job_data = {
            "id": str(project.id),  # Ensure UUID is stringified for JSON
            "prompt": project.prompt,
            "audio_settings": project.audio_settings,
            "video_settings": project.video_settings
        }
        redis_client.rpush('video_queue', json.dumps(job_data))
        logger.info(f"Project {project.id} pushed to Redis queue.")
    except Exception as e:
        # Best-effort: the project stays "pending".
        # NOTE(review): consider marking the project failed here.
        logger.error(f"Failed to push job to Redis for project {project.id}: {e}", exc_info=True)
    finally:
        if redis_client is not None:
            # Each request builds its own client/pool; release its sockets now.
            redis_client.connection_pool.disconnect()

    logger.info(f"Project created: {project.id} by user {current_user.email}")

    # NOTE(review): this dispatches the same project a second time (in addition
    # to the Redis queue above); confirm both paths are intended.
    background_tasks.add_task(generate_content_task, project.id)

    return project


@app.get("/projects", response_model=List[ProjectResponse])
async def list_projects(
    skip: int = 0,
    limit: int = 100,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the caller's projects, newest first, with offset pagination.

    Args:
        skip: Number of projects to skip. Negative values are treated as 0.
        limit: Page size, clamped to 0..100 (the default). Previously the limit
            was unbounded, and on SQLite a negative limit meant "no limit".
    """
    skip = max(skip, 0)
    limit = min(max(limit, 0), 100)

    return (
        db.query(Project)
        .filter(Project.user_id == current_user.id)
        .order_by(Project.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )


@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(
    project_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return one of the caller's projects.

    Projects owned by other users get the same 404 as missing ones.

    Raises:
        HTTPException(404): no such project owned by the caller.
    """
    owned = db.query(Project).filter(
        Project.id == project_id,
        Project.user_id == current_user.id,
    )
    project = owned.first()

    if project:
        return project

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Project not found"
    )


@app.patch("/projects/{project_id}", response_model=ProjectResponse)
async def update_project(
    project_id: str,
    project_update: ProjectUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Update the name and/or status of one of the caller's projects.

    A new name goes through ``InputValidator.validate_project_name``, the same
    check ``create_project`` applies. Fields that are None are left unchanged.

    Raises:
        HTTPException(404): no such project owned by the caller.
        Whatever ``InputValidator`` raises for an invalid name.
    """
    project = db.query(Project).filter(
        Project.id == project_id,
        Project.user_id == current_user.id
    ).first()

    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found"
        )

    if project_update.name is not None:
        project.name = InputValidator.validate_project_name(project_update.name)
    if project_update.status is not None:
        # NOTE(review): owners can set any status the ProjectUpdate schema
        # allows (e.g. "completed"); confirm that is intended.
        project.status = project_update.status

    db.commit()
    db.refresh(project)

    return project


@app.delete("/projects/{project_id}", response_model=MessageResponse)
async def delete_project(
    project_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the caller's projects.

    Raises:
        HTTPException(404): no such project owned by the caller.
    """
    ownership = (Project.id == project_id, Project.user_id == current_user.id)
    target = db.query(Project).filter(*ownership).first()

    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found"
        )

    db.delete(target)
    db.commit()
    logger.info(f"Project deleted: {project_id}")

    return MessageResponse(message="Project deleted successfully")


# ===== Subscription Endpoints =====
@app.post("/subscriptions/checkout", response_model=CheckoutSessionResponse)
async def create_checkout_session(
    checkout_data: CheckoutSessionCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a Stripe checkout session for the requested plan.

    Raises:
        HTTPException: any HTTPException raised by ``StripeService`` is passed
            through unchanged.
        HTTPException(500): any other failure. The exception text is logged
            with its traceback but not returned, so internal or Stripe error
            details are not exposed to the client.
    """
    try:
        session_data = StripeService.create_checkout_session(
            user=current_user,
            plan=checkout_data.plan,
            success_url=checkout_data.success_url,
            cancel_url=checkout_data.cancel_url,
            db=db
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Checkout session creation failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create checkout session"
        ) from e

    return CheckoutSessionResponse(**session_data)


@app.get("/subscriptions/current", response_model=Optional[SubscriptionResponse])
async def get_current_subscription(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the caller's subscription row, or ``None`` if there is none."""
    # NOTE(review): if a user can have several Subscription rows, first()
    # without ORDER BY picks an arbitrary one.
    return (
        db.query(Subscription)
        .filter(Subscription.user_id == current_user.id)
        .first()
    )


@app.post("/subscriptions/cancel", response_model=SubscriptionResponse)
async def cancel_subscription(
    cancel_data: SubscriptionCancel,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Cancel the caller's subscription, immediately or as requested.

    Raises:
        HTTPException(404): the caller has no subscription row.
        HTTPException: any HTTPException raised by ``StripeService`` is passed
            through unchanged.
        HTTPException(500): any other failure. Details are logged, not returned.
    """
    subscription = db.query(Subscription).filter(
        Subscription.user_id == current_user.id
    ).first()

    if not subscription:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No active subscription found"
        )

    try:
        return StripeService.cancel_subscription(
            subscription=subscription,
            immediately=cancel_data.immediately,
            db=db
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Subscription cancellation failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel subscription"
        ) from e


@app.post("/webhook/stripe")
async def stripe_webhook(request: Request, db: Session = Depends(get_db)):
    """Verify and process a Stripe webhook event.

    Signature verification and event handling use separate ``try`` scopes. A
    verification problem yields 400; a failure inside
    ``StripeService.handle_webhook_event`` yields 500, even if that failure is
    a ``ValueError``.

    Raises:
        HTTPException(500): webhook secret not configured, or event processing failed.
        HTTPException(400): missing or invalid signature, or an unparseable payload.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")

    if not settings.stripe_webhook_secret:
        logger.error("Stripe webhook secret not configured")
        raise HTTPException(status_code=500, detail="Webhook configuration error")

    # Passing None to construct_event crashes with a non-Stripe error, which
    # used to surface as a 500. Reject a missing header up front instead.
    if not sig_header:
        logger.error("Invalid webhook signature: missing Stripe-Signature header")
        raise HTTPException(status_code=400, detail="Invalid signature")

    try:
        event = stripe.Webhook.construct_event(
            payload, sig_header, settings.stripe_webhook_secret
        )
    except ValueError as e:
        logger.error(f"Invalid webhook payload: {e}")
        raise HTTPException(status_code=400, detail="Invalid payload")
    except stripe.error.SignatureVerificationError as e:
        logger.error(f"Invalid webhook signature: {e}")
        raise HTTPException(status_code=400, detail="Invalid signature")

    try:
        StripeService.handle_webhook_event(event, db)
    except Exception as e:
        logger.error(f"Webhook processing failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Webhook processing failed") from e

    return {"status": "success"}


# ===== Email Verification & Password Reset ===== 
@app.post("/auth/verify-email", response_model=MessageResponse)
async def verify_email(
    verification: EmailVerificationRequest,
    db: Session = Depends(get_db)
):
    """Mark a user's email as verified using an unexpired verification token.

    The token and its expiry are cleared on success, so each token works once.

    Raises:
        HTTPException(400): the token is unknown or expired.
    """
    now = datetime.utcnow()
    user = (
        db.query(User)
        .filter(User.email_verification_token == verification.token)
        .filter(User.email_verification_expires > now)
        .first()
    )

    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired verification token"
        )

    # Mark as verified and consume the token.
    user.email_verified = True
    user.email_verification_token = None
    user.email_verification_expires = None
    db.commit()

    logger.info(f"Email verified for user {user.id}")
    return MessageResponse(message="Email verified successfully")


@app.post("/auth/resend-verification", response_model=MessageResponse)
async def resend_verification(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Issue a fresh 24-hour verification token and email it to the caller.

    Does nothing if the caller's email is already verified.
    """
    if not current_user.email_verified:
        fresh_token = generate_token()
        current_user.email_verification_token = fresh_token
        current_user.email_verification_expires = datetime.utcnow() + timedelta(hours=24)
        db.commit()

        # TODO(review): "https://yourdomain.com" is a placeholder, and
        # settings.api_host may be a bare host rather than a URL.
        if settings.debug:
            link_base = settings.api_host
        else:
            link_base = "https://yourdomain.com"

        EmailService.send_verification_email(
            user_email=current_user.email,
            username=current_user.username,
            token=fresh_token,
            base_url=link_base
        )
        return MessageResponse(message="Verification email sent")

    return MessageResponse(message="Email already verified")


@app.post("/auth/forgot-password", response_model=MessageResponse)
async def forgot_password(
    reset_request: PasswordResetRequest,
    db: Session = Depends(get_db)
):
    """Start a password reset by emailing a one-hour reset token.

    Always returns the same message so callers cannot tell whether an account
    exists. A failed email send is therefore logged rather than raised. A 500
    that occurred only for existing accounts would reveal which emails are
    registered.
    """
    generic_reply = "If the email exists, a reset link has been sent"

    user = db.query(User).filter(User.email == reset_request.email).first()
    if not user:
        return MessageResponse(message=generic_reply)

    token = generate_token()
    user.password_reset_token = token
    user.password_reset_expires = datetime.utcnow() + timedelta(hours=1)
    db.commit()

    # TODO(review): "https://yourdomain.com" is a placeholder, and
    # settings.api_host may be a bare host rather than a URL.
    base_url = settings.api_host if settings.debug else "https://yourdomain.com"
    try:
        EmailService.send_password_reset_email(
            user_email=user.email,
            username=user.username,
            token=token,
            base_url=base_url
        )
    except Exception as e:
        logger.error(f"Failed to send password reset email for user {user.id}: {e}", exc_info=True)

    logger.info(f"Password reset requested for user {user.id}")

    return MessageResponse(message=generic_reply)


@app.post("/auth/reset-password", response_model=MessageResponse)
async def reset_password(
    reset_data: PasswordResetConfirm,
    db: Session = Depends(get_db)
):
    """Set a new password using an unexpired reset token.

    The new password must pass ``InputValidator.validate_password``, the same
    rule as registration. On success the reset token is cleared and every
    refresh token belonging to the user is revoked, so sessions opened before
    the reset cannot mint new access tokens.

    Raises:
        HTTPException(400): the reset token is unknown or expired.
        Whatever ``InputValidator`` raises for a password it rejects.
    """
    new_password = InputValidator.validate_password(reset_data.new_password)

    user = db.query(User).filter(
        User.password_reset_token == reset_data.token,
        User.password_reset_expires > datetime.utcnow()
    ).first()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired reset token"
        )

    user.hashed_password = get_password_hash(new_password)
    user.password_reset_token = None
    user.password_reset_expires = None

    # Invalidate existing sessions; a reset often follows a compromise.
    db.query(RefreshToken).filter(
        RefreshToken.user_id == user.id
    ).update({RefreshToken.is_revoked: True}, synchronize_session=False)

    db.commit()

    logger.info(f"Password reset for user {user.id}")

    return MessageResponse(message="Password reset successfully")


@app.post("/auth/refresh", response_model=Token)
async def refresh_access_token(
    refresh_request: RefreshTokenRequest,
    db: Session = Depends(get_db)
):
    """Exchange a valid refresh token for a new access/refresh token pair.

    Implements rotation: the presented refresh token is marked revoked and a
    new one, valid for 30 days, is stored and returned.

    Raises:
        HTTPException(401): the token is unknown, fails ``verify_refresh_token``,
            or its user is missing or inactive.
    """
    presented = refresh_request.refresh_token
    stored = db.query(RefreshToken).filter(RefreshToken.token == presented).first()

    if not stored:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token"
        )

    token_ok = verify_refresh_token(
        presented,
        stored.token,
        stored.expires_at,
        stored.is_revoked
    )
    if not token_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token"
        )

    user = db.query(User).filter(User.id == stored.user_id).first()
    if not (user and user.is_active):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive"
        )

    access_token = create_access_token(data={"sub": user.id})

    rotated_value = create_refresh_token()
    rotated = RefreshToken(
        user_id=user.id,
        token=rotated_value,
        expires_at=datetime.utcnow() + timedelta(days=30)
    )

    # NOTE(review): there is no row lock, so two concurrent requests with the
    # same token could both pass verification before either commits.
    stored.is_revoked = True
    db.add(rotated)
    db.commit()

    logger.info(f"Token refreshed for user {user.id}")

    return Token(
        access_token=access_token,
        refresh_token=rotated_value,
        token_type="bearer",
        expires_in=settings.jwt_expiration_hours * 3600
    )


# ===== Error Handlers =====
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Catch-all handler: log the exception and return a generic 500 response.

    The exception text is logged with ``exc_info`` but never sent to the client.
    """
    # NOTE(review): setup_error_handlers(app) may already register an Exception
    # handler; this decorator runs later and would replace it.
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    body = {"detail": "Internal server error"}
    return JSONResponse(content=body, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)


# Import SessionLocal for background tasks
from shared.database import SessionLocal


if __name__ == "__main__":
    import uvicorn

    # Auto-reload and debug-level logging follow settings.debug.
    run_options = {
        "host": settings.api_host,
        "port": settings.api_port,
        "reload": settings.debug,
        "log_level": "debug" if settings.debug else "info",
    }
    uvicorn.run("main:app", **run_options)
