"""
Redis Checkpointer for LangGraph
Persists workflow state to Redis for human-in-the-loop interrupts
"""

import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple

import redis.asyncio as redis
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
)


class RedisCheckpointer(BaseCheckpointSaver):
    """
    Redis-based checkpoint saver for LangGraph workflows.

    Stores campaign state in Redis, enabling:
    - Human-in-the-loop approval workflows
    - Campaign pause/resume functionality
    - State persistence across service restarts

    Key layout (every key starts with ``prefix``)::

        {prefix}{thread_id}:{checkpoint_id}                 JSON checkpoint document
        {prefix}thread:{thread_id}:latest                   id of the newest checkpoint
        {prefix}thread:{thread_id}:list                     newest-first list of checkpoint ids
        {prefix}thread:{thread_id}:pending_writes:{task_id} JSON pending writes of one task

    Every key is (re)written with a TTL of ``ttl_seconds``. Checkpoints are
    serialized with plain ``json``, so channel values must be JSON-serializable.
    Only the async API is implemented; the sync methods raise
    ``NotImplementedError``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        redis_url: Optional[str] = None,
        prefix: str = "langgraph:checkpoint:",
        ttl_seconds: int = 86400 * 30,  # 30 days default
        max_checkpoints: int = 100,
    ):
        """
        Args:
            redis_url: Redis connection URL. Falls back to the ``REDIS_URL``
                environment variable, then ``redis://localhost:6379/0``.
            prefix: Namespace prepended to every key this saver touches.
            ttl_seconds: Expiry (seconds) applied to every key on each write.
            max_checkpoints: Number of checkpoint ids kept in each thread's
                list. Older ids are trimmed from the list; their documents are
                left to expire via the TTL.

        Raises:
            ValueError: If ``max_checkpoints`` is less than 1.
        """
        super().__init__()
        if max_checkpoints < 1:
            raise ValueError("max_checkpoints must be at least 1")
        self.redis_url = redis_url or os.getenv(
            "REDIS_URL",
            "redis://localhost:6379/0"
        )
        self.prefix = prefix
        self.ttl_seconds = ttl_seconds
        self.max_checkpoints = max_checkpoints
        # Created lazily on first use by _get_client().
        self._client: Optional[redis.Redis] = None

    async def _get_client(self) -> redis.Redis:
        """Get or create the Redis client (responses are decoded to ``str``)."""
        if self._client is None:
            self._client = redis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
        return self._client

    def _make_key(self, thread_id: str, checkpoint_id: str) -> str:
        """Generate Redis key for a checkpoint"""
        return f"{self.prefix}{thread_id}:{checkpoint_id}"

    def _make_thread_key(self, thread_id: str) -> str:
        """Generate Redis key for thread metadata"""
        return f"{self.prefix}thread:{thread_id}"

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Escape Redis glob metacharacters so *text* matches only literally."""
        return "".join(f"\\{ch}" if ch in "*?[]\\" else ch for ch in text)

    @staticmethod
    def _parent_config(
        configurable: Dict[str, Any],
        thread_id: str,
        checkpoint_id: str,
    ) -> Optional[Dict[str, Any]]:
        """Build the parent pointer stored alongside a new checkpoint.

        LangGraph calls ``aput`` with the config of the checkpoint being
        superseded, so its ``checkpoint_id`` is the new checkpoint's parent.
        An explicit ``configurable["parent_config"]`` still takes precedence.
        """
        explicit = configurable.get("parent_config")
        if explicit:
            return explicit
        parent_id = configurable.get("checkpoint_id")
        if not parent_id or parent_id == checkpoint_id:
            return None
        return {"configurable": {"thread_id": thread_id, "checkpoint_id": parent_id}}

    async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """
        Load one checkpoint (implements ``BaseCheckpointSaver.aget_tuple``).

        Reads ``thread_id`` and the optional ``checkpoint_id`` from
        ``config["configurable"]``. Without a ``checkpoint_id``, the thread's
        latest checkpoint is loaded.

        Returns:
            The checkpoint tuple. Its ``config`` is the caller's config with
            the resolved ``checkpoint_id`` filled in. Returns ``None`` when
            there is no ``thread_id``, no such checkpoint, or the stored
            document cannot be parsed.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id")
        checkpoint_id = configurable.get("checkpoint_id")

        if not thread_id:
            return None

        client = await self._get_client()

        # If no checkpoint_id, get the latest
        if not checkpoint_id:
            thread_key = self._make_thread_key(thread_id)
            checkpoint_id = await client.get(f"{thread_key}:latest")
            if not checkpoint_id:
                return None

        key = self._make_key(thread_id, checkpoint_id)
        data = await client.get(key)

        if not data:
            return None

        try:
            parsed = json.loads(data)
            stored_metadata = parsed.get("metadata") or {}
            checkpoint = Checkpoint(
                v=parsed.get("v", 1),
                id=parsed["id"],
                ts=parsed["ts"],
                channel_values=parsed.get("channel_values", {}),
                channel_versions=parsed.get("channel_versions", {}),
                versions_seen=parsed.get("versions_seen", {}),
                pending_sends=parsed.get("pending_sends", []),
            )
            metadata = CheckpointMetadata(
                source=stored_metadata.get("source", "unknown"),
                step=stored_metadata.get("step", 0),
                writes=stored_metadata.get("writes", {}),
            )
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
            self._logger.warning(
                "[RedisCheckpointer] Failed to parse checkpoint %s: %s", key, e
            )
            return None

        # Pin the id actually loaded so callers that asked for "latest" learn
        # which checkpoint they got.
        resolved_config = {
            **config,
            "configurable": {**configurable, "checkpoint_id": checkpoint_id},
        }
        return CheckpointTuple(
            config=resolved_config,
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=parsed.get("parent_config"),
        )

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Save a checkpoint and make it the thread's latest.

        The document, the ``latest`` pointer and the per-thread id list are
        written in a single MULTI/EXEC transaction, so a failure cannot leave
        them pointing at each other inconsistently.

        Returns:
            A config addressing the saved checkpoint.

        Raises:
            ValueError: If ``config`` has no ``thread_id``.
            TypeError: If the checkpoint contains values ``json`` cannot
                serialize. Nothing is written in that case.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id")
        if not thread_id:
            raise ValueError("thread_id required in config")

        checkpoint_id = checkpoint["id"]
        meta = metadata or {}

        # Serialize checkpoint
        data = {
            "v": checkpoint.get("v", 1),
            "id": checkpoint_id,
            "ts": checkpoint["ts"],
            "channel_values": checkpoint.get("channel_values", {}),
            "channel_versions": checkpoint.get("channel_versions", {}),
            "versions_seen": checkpoint.get("versions_seen", {}),
            "pending_sends": checkpoint.get("pending_sends", []),
            "metadata": {
                "source": meta.get("source", "unknown"),
                "step": meta.get("step", 0),
                "writes": meta.get("writes", {}),
            },
            "parent_config": self._parent_config(configurable, thread_id, checkpoint_id),
            "saved_at": datetime.now(timezone.utc).isoformat(),
        }
        # Serialize before touching Redis so a serialization error writes nothing.
        payload = json.dumps(data)

        key = self._make_key(thread_id, checkpoint_id)
        thread_key = self._make_thread_key(thread_id)
        list_key = f"{thread_key}:list"

        client = await self._get_client()
        async with client.pipeline(transaction=True) as pipe:
            pipe.set(key, payload, ex=self.ttl_seconds)
            pipe.set(f"{thread_key}:latest", checkpoint_id, ex=self.ttl_seconds)
            # Drop any earlier occurrence so re-saving an id does not duplicate it.
            pipe.lrem(list_key, 0, checkpoint_id)
            pipe.lpush(list_key, checkpoint_id)
            pipe.ltrim(list_key, 0, self.max_checkpoints - 1)
            pipe.expire(list_key, self.ttl_seconds)
            await pipe.execute()

        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_id": checkpoint_id,
            }
        }

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """
        Store pending writes (for interrupted workflows).

        One key per ``(thread_id, task_id)`` is written; a later call for the
        same task overwrites it. Values are serialized with ``default=str``, so
        anything not JSON-native is stored as its ``str()`` (lossy). Nothing
        in this class reads these keys back except ``delete_thread``.

        Does nothing when ``config`` has no ``thread_id``.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id")
        if not thread_id:
            return

        client = await self._get_client()
        thread_key = self._make_thread_key(thread_id)

        pending_writes = {
            "task_id": task_id,
            "task_path": task_path,
            "checkpoint_id": configurable.get("checkpoint_id"),
            "writes": [(channel, value) for channel, value in writes],
            "created_at": datetime.now(timezone.utc).isoformat(),
        }

        await client.set(
            f"{thread_key}:pending_writes:{task_id}",
            json.dumps(pending_writes, default=str),
            ex=self.ttl_seconds,
        )

    async def alist(
        self,
        config: Optional[Dict[str, Any]],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """
        Yield a thread's checkpoints, newest first.

        Args:
            config: Must carry ``configurable.thread_id``; otherwise nothing
                is yielded.
            filter: Only checkpoints whose metadata equals every given
                key/value pair are yielded.
            before: Only checkpoints whose id sorts before
                ``before["configurable"]["checkpoint_id"]`` are yielded.
            limit: Maximum number of checkpoints to yield. ``None`` means all
                ids kept in the thread's list (at most ``max_checkpoints``).

        Ids whose documents have expired or cannot be parsed are skipped and
        do not count towards ``limit``.
        """
        thread_id = ((config or {}).get("configurable") or {}).get("thread_id")
        if not thread_id or (limit is not None and limit <= 0):
            return

        before_id = ((before or {}).get("configurable") or {}).get("checkpoint_id")

        client = await self._get_client()
        thread_key = self._make_thread_key(thread_id)
        checkpoint_ids = await client.lrange(f"{thread_key}:list", 0, -1)

        yielded = 0
        for checkpoint_id in checkpoint_ids:
            # LangGraph checkpoint ids are time-ordered strings, so "before" is
            # a lexicographic comparison (the same test LangGraph's
            # InMemorySaver uses).
            if before_id and checkpoint_id >= before_id:
                continue
            result = await self.aget_tuple({
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_id": checkpoint_id,
                }
            })
            if result is None:
                continue
            if filter and any(
                result.metadata.get(field) != wanted for field, wanted in filter.items()
            ):
                continue
            yield result
            yielded += 1
            if limit is not None and yielded >= limit:
                return

    # Sync methods (for compatibility)
    def get(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """Sync version of aget - raises NotImplementedError, use async"""
        raise NotImplementedError("Use async method aget() instead")

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Sync version of aput - raises NotImplementedError, use async"""
        raise NotImplementedError("Use async method aput() instead")

    def put_writes(
        self,
        config: Dict[str, Any],
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Sync version of aput_writes - raises NotImplementedError, use async"""
        raise NotImplementedError("Use async method aput_writes() instead")

    # NOTE: this method shadows the builtin ``list`` inside the class body for
    # every definition below it; annotations there must use ``typing.List``.
    def list(
        self,
        config: Dict[str, Any],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
    ):
        """Sync version of alist - raises NotImplementedError, use async"""
        raise NotImplementedError("Use async method alist() instead")

    # Utility methods
    async def delete_thread(self, thread_id: str) -> bool:
        """
        Delete a thread's checkpoints, pointers and pending writes.

        Only checkpoints whose ids are still in the thread's list are deleted.
        Ids already trimmed from the list are left to expire via their TTL.
        Pending-write keys are found with a non-blocking SCAN (not KEYS).

        Returns:
            Always ``True``.
        """
        client = await self._get_client()
        thread_key = self._make_thread_key(thread_id)
        list_key = f"{thread_key}:list"

        checkpoint_ids = await client.lrange(list_key, 0, -1)
        keys_to_delete = [self._make_key(thread_id, cid) for cid in checkpoint_ids]
        keys_to_delete.extend([f"{thread_key}:latest", list_key])

        # Escape the literal part so glob characters in the prefix or
        # thread_id cannot match other threads' keys.
        pattern = f"{self._escape_glob(thread_key)}:pending_writes:*"
        keys_to_delete.extend([key async for key in client.scan_iter(match=pattern)])

        await client.delete(*keys_to_delete)
        return True

    async def get_thread_ids(self, pattern: str = "*") -> List[str]:
        """
        Return ids of threads that have a ``latest`` pointer.

        Args:
            pattern: Redis glob applied to the thread id.

        Uses a non-blocking SCAN. Ids are de-duplicated, since SCAN may report
        a key more than once. Order is unspecified.
        """
        client = await self._get_client()
        head = f"{self.prefix}thread:"
        tail = ":latest"
        match = f"{self._escape_glob(head)}{pattern}{tail}"

        thread_ids: List[str] = []
        seen = set()
        async for key in client.scan_iter(match=match):
            # Slice rather than str.replace so ids containing the prefix or
            # ":latest" survive intact.
            thread_id = key[len(head):-len(tail)]
            if thread_id not in seen:
                seen.add(thread_id)
                thread_ids.append(thread_id)
        return thread_ids

    async def close(self):
        """Close the Redis connection; the next call reconnects lazily."""
        client, self._client = self._client, None
        if client is None:
            return
        # redis-py 5 deprecates close() in favour of aclose().
        aclose = getattr(client, "aclose", None)
        if aclose is not None:
            await aclose()
        else:
            await client.close()
