from dataclasses import dataclass, field
from typing import Optional, List, Dict, Set, Any, Union

import numpy as np
from numpy.typing import NDArray
from chromadb.api.types import (
    Embeddings,
    IDs,
    Include,
    SparseVector,
    TYPE_KEY,
    SPARSE_VECTOR_TYPE_VALUE,
    normalize_embeddings,
    validate_embeddings,
)
from chromadb.types import (
    Collection,
    RequestVersionContext,
    Segment,
)


@dataclass
class Scan:
    """Bundle of a collection and the knn, metadata and record segments."""

    collection: Collection
    knn: Segment
    metadata: Segment
    record: Segment

    @property
    def version(self) -> RequestVersionContext:
        """Version context built from the collection's version and log position."""
        source = self.collection
        context = RequestVersionContext(
            collection_version=source.version,
            log_position=source.log_position,
        )
        return context


# Where expression types for filtering
@dataclass
class Where:
    """Root of the Where expression tree (algebraic data type).

    Conditions compose with the bitwise operators:
        - AND: where1 & where2
        - OR: where1 | where2

    Examples:
        # Simple conditions
        where1 = Key("status") == "active"
        where2 = Key("score") > 0.5

        # Combining with AND
        combined_and = where1 & where2

        # Combining with OR
        combined_or = where1 | where2

        # Complex expressions
        complex_where = (Key("status") == "active") & ((Key("score") > 0.5) | (Key("priority") == "high"))
    """

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this expression to a JSON-compatible dict.

        The base class has no representation of its own; every concrete
        subclass overrides this.
        """
        raise NotImplementedError("Subclasses must implement to_dict()")

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Where":
        """Build a Where expression from its dictionary form.

        Accepted MongoDB-style shapes:
        - {"field": "value"} -> Key("field") == "value" (shorthand for equality)
        - {"field": {"$eq": value}} -> Key("field") == value
        - {"field": {"$ne": value}} -> Key("field") != value
        - {"field": {"$gt": value}} -> Key("field") > value
        - {"field": {"$gte": value}} -> Key("field") >= value
        - {"field": {"$lt": value}} -> Key("field") < value
        - {"field": {"$lte": value}} -> Key("field") <= value
        - {"field": {"$in": [values]}} -> Key("field").is_in([values])
        - {"field": {"$nin": [values]}} -> Key("field").not_in([values])
        - {"field": {"$contains": "text"}} -> Key("field").contains("text")
        - {"field": {"$not_contains": "text"}} -> Key("field").not_contains("text")
        - {"field": {"$regex": "pattern"}} -> Key("field").regex("pattern")
        - {"field": {"$not_regex": "pattern"}} -> Key("field").not_regex("pattern")
        - {"$and": [conditions]} -> condition1 & condition2 & ...
        - {"$or": [conditions]} -> condition1 | condition2 | ...

        Raises:
            TypeError: if ``data``, a field name, or an operator operand has
                the wrong type.
            ValueError: if a dict is empty, has the wrong number of entries,
                or names an unknown operator.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Where, got {type(data).__name__}")
        if not data:
            raise ValueError("Where dict cannot be empty")

        # Logical combinators. "$and" is probed first, so it wins when both
        # keys are present (and then fails the "no other fields" check).
        for logical_op in ("$and", "$or"):
            if logical_op not in data:
                continue

            children = data[logical_op]
            if not isinstance(children, list):
                raise TypeError(
                    f"{logical_op} must be a list, got {type(children).__name__}"
                )
            if len(children) == 0:
                raise ValueError(f"{logical_op} requires at least one condition")
            if len(data) > 1:
                raise ValueError(
                    f"{logical_op} cannot be combined with other fields in the same dict"
                )

            parsed = [Where.from_dict(child) for child in children]
            # A single child is returned as-is rather than wrapped.
            combined = parsed[0]
            for following in parsed[1:]:
                if logical_op == "$and":
                    combined = combined & following
                else:
                    combined = combined | following
            return combined

        # Anything else must be a single {field: condition} entry.
        if len(data) != 1:
            raise ValueError(
                f"Where dict must contain exactly one field, got {len(data)}"
            )

        ((field_name, condition),) = data.items()

        if not isinstance(field_name, str):
            raise TypeError(
                f"Field name must be a string, got {type(field_name).__name__}"
            )

        # A bare value is shorthand for equality.
        if not isinstance(condition, dict):
            return Key(field_name) == condition

        if not condition:
            raise ValueError(
                f"Operator dict for field '{field_name}' cannot be empty"
            )
        if len(condition) != 1:
            raise ValueError(
                f"Operator dict for field '{field_name}' must contain exactly one operator"
            )

        ((op, value),) = condition.items()

        # Operators whose operand is passed straight to a Key comparison.
        comparisons = {
            "$eq": lambda key, operand: key == operand,
            "$ne": lambda key, operand: key != operand,
            "$gt": lambda key, operand: key > operand,
            "$gte": lambda key, operand: key >= operand,
            "$lt": lambda key, operand: key < operand,
            "$lte": lambda key, operand: key <= operand,
        }
        compare = comparisons.get(op)
        if compare is not None:
            return compare(Key(field_name), value)

        # Operators with a type-checked operand:
        # op -> (required type, wording for the error, Key builder method).
        typed_builders = {
            "$in": (list, "a list", "is_in"),
            "$nin": (list, "a list", "not_in"),
            "$contains": (str, "a string", "contains"),
            "$not_contains": (str, "a string", "not_contains"),
            "$regex": (str, "a string pattern", "regex"),
            "$not_regex": (str, "a string pattern", "not_regex"),
        }
        spec = typed_builders.get(op)
        if spec is None:
            raise ValueError(f"Unknown operator: {op}")

        required_type, wording, builder_name = spec
        if not isinstance(value, required_type):
            raise TypeError(f"{op} requires {wording}, got {type(value).__name__}")
        return getattr(Key(field_name), builder_name)(value)

    def __and__(self, other: "Where") -> "And":
        """Overload & for AND, flattening operands that are already And nodes."""
        left_terms = self.conditions if isinstance(self, And) else [self]
        right_terms = other.conditions if isinstance(other, And) else [other]
        return And(left_terms + right_terms)

    def __or__(self, other: "Where") -> "Or":
        """Overload | for OR, flattening operands that are already Or nodes."""
        left_terms = self.conditions if isinstance(self, Or) else [self]
        right_terms = other.conditions if isinstance(other, Or) else [other]
        return Or(left_terms + right_terms)

@dataclass
class And(Where):
    """Conjunction node: matches when all child conditions match."""

    conditions: List[Where]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$and": [...]}`` with each child serialized in order."""
        serialized = [condition.to_dict() for condition in self.conditions]
        return {"$and": serialized}


@dataclass
class Or(Where):
    """Disjunction node: matches when any child condition matches."""

    conditions: List[Where]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$or": [...]}`` with each child serialized in order."""
        serialized = [condition.to_dict() for condition in self.conditions]
        return {"$or": serialized}


@dataclass
class Eq(Where):
    """Condition: field equals value."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$eq": value}}``."""
        operator_clause = {"$eq": self.value}
        return {self.key: operator_clause}


@dataclass
class Ne(Where):
    """Condition: field does not equal value."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$ne": value}}``."""
        operator_clause = {"$ne": self.value}
        return {self.key: operator_clause}


@dataclass
class Gt(Where):
    """Condition: field is strictly greater than value."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$gt": value}}``."""
        operator_clause = {"$gt": self.value}
        return {self.key: operator_clause}


@dataclass
class Gte(Where):
    """Condition: field is greater than or equal to value."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$gte": value}}``."""
        operator_clause = {"$gte": self.value}
        return {self.key: operator_clause}


@dataclass
class Lt(Where):
    """Condition: field is strictly less than value."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$lt": value}}``."""
        operator_clause = {"$lt": self.value}
        return {self.key: operator_clause}


@dataclass
class Lte(Where):
    """Condition: field is less than or equal to value."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$lte": value}}``."""
        operator_clause = {"$lte": self.value}
        return {self.key: operator_clause}


@dataclass
class In(Where):
    """Condition: field value is one of the listed values."""

    key: str
    values: List[Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$in": values}}`` (the list is not copied)."""
        operator_clause = {"$in": self.values}
        return {self.key: operator_clause}


@dataclass
class Nin(Where):
    """Condition: field value is none of the listed values."""

    key: str
    values: List[Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$nin": values}}`` (the list is not copied)."""
        operator_clause = {"$nin": self.values}
        return {self.key: operator_clause}


@dataclass
class Contains(Where):
    """Condition: field contains the given text (used for document content)."""

    key: str
    content: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$contains": content}}``."""
        operator_clause = {"$contains": self.content}
        return {self.key: operator_clause}


@dataclass
class NotContains(Where):
    """Condition: field does not contain the given text (used for document content)."""

    key: str
    content: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$not_contains": content}}``."""
        operator_clause = {"$not_contains": self.content}
        return {self.key: operator_clause}


@dataclass
class Regex(Where):
    """Condition: field matches the regular expression pattern."""

    key: str
    pattern: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$regex": pattern}}``."""
        operator_clause = {"$regex": self.pattern}
        return {self.key: operator_clause}


@dataclass
class NotRegex(Where):
    """Condition: field does not match the regular expression pattern."""

    key: str
    pattern: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$not_regex": pattern}}``."""
        operator_clause = {"$not_regex": self.pattern}
        return {self.key: operator_clause}


# Field proxy for building Where conditions
class Key:
    """Field proxy that turns comparisons and builder calls into Where nodes.

    A Key names the field a condition applies to. Fields can be referenced by:
    1. Predefined constants for special fields: K.EMBEDDING, K.DOCUMENT, K.SCORE, etc.
    2. String literals with # prefix for special fields: Key("#embedding")
    3. Metadata field names without # prefix: Key("my_metadata_field")

    Predefined field constants (special fields with # prefix):
        Key.ID - ID field (equivalent to Key("#id"))
        Key.DOCUMENT - Document field (equivalent to Key("#document"))
        Key.EMBEDDING - Embedding field (equivalent to Key("#embedding"))
        Key.METADATA - Metadata field (equivalent to Key("#metadata"))
        Key.SCORE - Score field (equivalent to Key("#score"))

    Note: K is an alias for Key, so you can use K.DOCUMENT or Key.DOCUMENT interchangeably.

    Because ``==`` and ``!=`` are overloaded to build conditions, comparing two
    Key objects yields an Eq/Ne node rather than a bool.

    Examples:
        # Using predefined keys with K alias for special fields
        from chromadb.execution.expression import K
        K.DOCUMENT.contains("search text")  # Searches document field

        # Custom metadata field names (without # prefix)
        K("status") == "active"  # Metadata field named "status"
        K("category").is_in(["science", "tech"])  # Metadata field named "category"
        K("sparse_embedding")  # Example: metadata field (could store anything)

        # Using with Knn for different fields
        Knn(query=[0.1, 0.2])  # Default: searches "#embedding"
        Knn(query=[0.1, 0.2], key=K.EMBEDDING)  # Explicit: searches "#embedding"
        Knn(query=sparse, key="sparse_embedding")  # Example: searches a metadata field

        # Combining conditions
        (K("status") == "active") & (K.SCORE > 0.5)
    """

    # Declared here for type checkers; the values are assigned after the
    # class body, since each constant is itself a Key instance.
    ID: "Key"
    DOCUMENT: "Key"
    EMBEDDING: "Key"
    METADATA: "Key"
    SCORE: "Key"

    def __init__(self, name: str):
        """Store the field name this key refers to."""
        self.name = name

    def __hash__(self) -> int:
        """Hash by field name so Key instances can be stored in sets."""
        return hash(self.name)

    # --- Comparison operators: each returns a Where node, not a bool ---

    def __eq__(self, value: Any) -> Eq:  # type: ignore[override]
        """Key('field') == value -> Eq condition"""
        return Eq(key=self.name, value=value)

    def __ne__(self, value: Any) -> Ne:  # type: ignore[override]
        """Key('field') != value -> Ne condition"""
        return Ne(key=self.name, value=value)

    def __gt__(self, value: Any) -> Gt:
        """Key('field') > value -> Gt condition"""
        return Gt(key=self.name, value=value)

    def __ge__(self, value: Any) -> Gte:
        """Key('field') >= value -> Gte condition"""
        return Gte(key=self.name, value=value)

    def __lt__(self, value: Any) -> Lt:
        """Key('field') < value -> Lt condition"""
        return Lt(key=self.name, value=value)

    def __le__(self, value: Any) -> Lte:
        """Key('field') <= value -> Lte condition"""
        return Lte(key=self.name, value=value)

    # --- Builder methods for conditions that have no Python operator ---

    def is_in(self, values: List[Any]) -> In:
        """Membership test: Key('field').is_in(['a', 'b'])"""
        return In(key=self.name, values=values)

    def not_in(self, values: List[Any]) -> Nin:
        """Negated membership test: Key('field').not_in(['a', 'b'])"""
        return Nin(key=self.name, values=values)

    def regex(self, pattern: str) -> Regex:
        """Regex match: Key('field').regex('^pattern')"""
        return Regex(key=self.name, pattern=pattern)

    def not_regex(self, pattern: str) -> NotRegex:
        """Negated regex match: Key('field').not_regex('^pattern')"""
        return NotRegex(key=self.name, pattern=pattern)

    def contains(self, content: str) -> Contains:
        """Text containment: Key('field').contains('text')"""
        return Contains(key=self.name, content=content)

    def not_contains(self, content: str) -> NotContains:
        """Negated text containment: Key('field').not_contains('text')"""
        return NotContains(key=self.name, content=content)


# Bind the predefined special-field keys onto the class. This is done after
# the class body because each constant is itself a Key instance, which cannot
# be constructed until the class exists. Special fields use a "#" prefix.
Key.ID = Key("#id")
Key.DOCUMENT = Key("#document")
Key.EMBEDDING = Key("#embedding")
Key.METADATA = Key("#metadata")
Key.SCORE = Key("#score")

# Short alias: K and Key are the same class object, so K("x") and K.DOCUMENT
# behave exactly like Key("x") and Key.DOCUMENT.
K = Key


@dataclass
class Filter:
    """Optional filter criteria; every field defaults to None.

    NOTE(review): ``where`` / ``where_document`` are typed ``Any`` and, per the
    inline comments, carry the older dict-based filter types rather than the
    ``Where`` expression class defined in this module - confirm with callers.
    """

    # Explicit record IDs to restrict to (presumably; verify against callers).
    user_ids: Optional[IDs] = None
    where: Optional[Any] = None  # Old Where type from chromadb.types
    where_document: Optional[Any] = None  # Old WhereDocument type


@dataclass
class KNN:
    """Parameters for a k-nearest-neighbour step: query embeddings and a fetch count."""

    # Query embeddings to search with.
    embeddings: Embeddings
    # Presumably the number of neighbours to fetch per query - TODO confirm.
    fetch: int


@dataclass
class Limit:
    """Pagination window described by an ``offset`` and an optional ``limit``.

    ``offset`` defaults to 0. A ``limit`` of ``None`` is left out of the
    serialized form produced by :meth:`to_dict`.
    """

    offset: int = 0
    limit: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Limit to a dictionary for JSON serialization.

        Returns:
            ``{"offset": ...}``, plus a ``"limit"`` entry when one is set.
        """
        result: Dict[str, Any] = {"offset": self.offset}
        if self.limit is not None:
            result["limit"] = self.limit
        return result

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Limit":
        """Create Limit from dictionary.

        Examples:
        - {"offset": 10} -> Limit(offset=10)
        - {"offset": 10, "limit": 20} -> Limit(offset=10, limit=20)
        - {"limit": 20} -> Limit(offset=0, limit=20)

        Raises:
            TypeError: if ``data`` is not a dict, or ``offset`` / ``limit`` is
                not an integer (booleans are rejected).
            ValueError: if ``offset`` is negative, ``limit`` is not positive,
                or ``data`` contains keys other than ``offset`` and ``limit``.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Limit, got {type(data).__name__}")

        offset = data.get("offset", 0)
        # bool is a subclass of int, so it must be excluded explicitly;
        # otherwise a JSON `true` would silently become offset=1.
        if isinstance(offset, bool) or not isinstance(offset, int):
            raise TypeError(
                f"Limit offset must be an integer, got {type(offset).__name__}"
            )
        if offset < 0:
            raise ValueError(f"Limit offset must be non-negative, got {offset}")

        limit = data.get("limit")
        if limit is not None:
            # Same bool exclusion as for offset.
            if isinstance(limit, bool) or not isinstance(limit, int):
                raise TypeError(
                    f"Limit limit must be an integer, got {type(limit).__name__}"
                )
            if limit <= 0:
                raise ValueError(f"Limit limit must be positive, got {limit}")

        # Check for unexpected keys
        allowed_keys = {"offset", "limit"}
        unexpected_keys = set(data.keys()) - allowed_keys
        if unexpected_keys:
            raise ValueError(f"Unexpected keys in Limit dict: {unexpected_keys}")

        return Limit(offset=offset, limit=limit)


@dataclass
class Projection:
    """Boolean switches selecting which parts of a record to include.

    All switches default to False.
    """

    document: bool = False
    embedding: bool = False
    metadata: bool = False
    rank: bool = False
    uri: bool = False

    @property
    def included(self) -> Include:
        """Names of the enabled parts, in a fixed order.

        Note that the ``rank`` switch is reported as ``"distances"``.
        """
        # (switch, include name) pairs; tuple order fixes the output order.
        switches = (
            (self.document, "documents"),
            (self.embedding, "embeddings"),
            (self.metadata, "metadatas"),
            (self.rank, "distances"),
            (self.uri, "uris"),
        )
        return [label for enabled, label in switches if enabled]  # type: ignore[return-value]


# Rank expression types for hybrid search
@dataclass
class Rank:
    """Base class for Rank expressions (algebraic data type).

    Supports arithmetic operations for combining rank expressions:
        - Addition: rank1 + rank2, rank + 0.5
        - Subtraction: rank1 - rank2, rank - 0.5
        - Multiplication: rank1 * rank2, rank * 0.8
        - Division: rank1 / rank2, rank / 2.0
        - Negation: -rank
        - Absolute value: abs(rank)

    Supports mathematical functions:
        - Exponential: rank.exp()
        - Logarithm: rank.log()
        - Maximum: rank.max(other)
        - Minimum: rank.min(other)

    Plain ``int`` / ``float`` operands are wrapped in ``Val`` automatically.
    Sum, Mul, Max and Min nodes are flattened when combined with a node of the
    same kind, so ``a + b + c`` yields a single Sum over three ranks.

    Examples:
        # Weighted combination
        Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2

        # Normalization
        Knn(query=[0.1, 0.2]) / Val(10.0)

        # Clamping
        Knn(query=[0.1, 0.2]).min(1.0).max(0.0)
    """

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Rank expression to a dictionary for JSON serialization"""
        raise NotImplementedError("Subclasses must implement to_dict()")

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Rank":
        """Create Rank expression from dictionary.

        Supports operators:
        - {"$val": number} -> Val(number)
        - {"$knn": {...}} -> Knn(...)
        - {"$sum": [ranks]} -> rank1 + rank2 + ...
        - {"$sub": {"left": ..., "right": ...}} -> left - right
        - {"$mul": [ranks]} -> rank1 * rank2 * ...
        - {"$div": {"left": ..., "right": ...}} -> left / right
        - {"$abs": rank} -> abs(rank)
        - {"$exp": rank} -> rank.exp()
        - {"$log": rank} -> rank.log()
        - {"$max": [ranks]} -> rank1.max(rank2).max(rank3)...
        - {"$min": [ranks]} -> rank1.min(rank2).min(rank3)...

        Raises:
            TypeError: if ``data`` or an operator payload has the wrong type
                (booleans are not accepted where a number is required).
            ValueError: if the dict is empty, holds more than one operator,
                names an unknown operator, or a payload is structurally
                invalid (too few operands, missing fields, bad limit).
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Rank, got {type(data).__name__}")

        if not data:
            raise ValueError("Rank dict cannot be empty")

        if len(data) != 1:
            raise ValueError(
                f"Rank dict must contain exactly one operator, got {len(data)}"
            )

        op, payload = next(iter(data.items()))

        if op == "$val":
            # bool is a subclass of int; reject it so a JSON `true` does not
            # silently become the constant 1.
            if isinstance(payload, bool) or not isinstance(payload, (int, float)):
                raise TypeError(
                    f"$val requires a number, got {type(payload).__name__}"
                )
            return Val(payload)
        if op == "$knn":
            return Rank._knn_from_dict(payload)
        if op in ("$sum", "$mul", "$max", "$min"):
            return Rank._nary_from_dict(op, payload)
        if op in ("$sub", "$div"):
            return Rank._binary_from_dict(op, payload)
        if op in ("$abs", "$exp", "$log"):
            return Rank._unary_from_dict(op, payload)

        raise ValueError(f"Unknown rank operator: {op}")

    @staticmethod
    def _knn_from_dict(knn_data: Any) -> "Knn":
        """Validate a ``$knn`` payload and build the corresponding Knn node."""
        if not isinstance(knn_data, dict):
            raise TypeError(f"$knn requires a dict, got {type(knn_data).__name__}")

        if "query" not in knn_data:
            raise ValueError("$knn requires 'query' field")

        query = knn_data["query"]

        if isinstance(query, dict):
            # SparseVector case - deserialize from transport format
            if query.get(TYPE_KEY) == SPARSE_VECTOR_TYPE_VALUE:
                query = SparseVector.from_dict(query)
            else:
                # Dicts without the sparse-vector type tag are rejected.
                raise ValueError(
                    f"Expected dict with {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {query}"
                )

        elif isinstance(query, (list, tuple, np.ndarray)):
            # Dense vector case - normalize then validate
            normalized = normalize_embeddings(query)
            if not normalized or len(normalized) > 1:
                raise ValueError("$knn requires exactly one query embedding")

            # Validate the normalized version
            validate_embeddings(normalized)

            query = normalized[0]

        else:
            raise TypeError(
                f"$knn query must be a list, numpy array, or SparseVector dict, got {type(query).__name__}"
            )

        key = knn_data.get("key", "#embedding")
        if not isinstance(key, str):
            raise TypeError(f"$knn key must be a string, got {type(key).__name__}")

        limit = knn_data.get("limit", 16)
        # Exclude bool explicitly: isinstance(True, int) is True.
        if isinstance(limit, bool) or not isinstance(limit, int):
            raise TypeError(
                f"$knn limit must be an integer, got {type(limit).__name__}"
            )
        if limit <= 0:
            raise ValueError(f"$knn limit must be positive, got {limit}")

        return_rank = knn_data.get("return_rank", False)
        if not isinstance(return_rank, bool):
            raise TypeError(
                f"$knn return_rank must be a boolean, got {type(return_rank).__name__}"
            )

        return Knn(
            query=query,
            key=key,
            limit=limit,
            default=knn_data.get("default"),
            return_rank=return_rank,
        )

    @staticmethod
    def _nary_from_dict(op: str, ranks_data: Any) -> "Rank":
        """Parse a list payload for ``$sum`` / ``$mul`` / ``$max`` / ``$min``."""
        if not isinstance(ranks_data, (list, tuple)):
            raise TypeError(f"{op} requires a list, got {type(ranks_data).__name__}")
        if len(ranks_data) < 2:
            raise ValueError(f"{op} requires at least 2 ranks, got {len(ranks_data)}")

        combiners = {
            "$sum": lambda acc, nxt: acc + nxt,
            "$mul": lambda acc, nxt: acc * nxt,
            "$max": lambda acc, nxt: acc.max(nxt),
            "$min": lambda acc, nxt: acc.min(nxt),
        }
        combine = combiners[op]

        ranks = [Rank.from_dict(r) for r in ranks_data]
        result = ranks[0]
        # Left fold; the operators themselves flatten same-kind nodes.
        for r in ranks[1:]:
            result = combine(result, r)
        return result

    @staticmethod
    def _binary_from_dict(op: str, operands: Any) -> "Rank":
        """Parse a ``{"left": ..., "right": ...}`` payload for ``$sub`` / ``$div``."""
        if not isinstance(operands, dict):
            raise TypeError(
                f"{op} requires a dict with 'left' and 'right', got {type(operands).__name__}"
            )
        if "left" not in operands or "right" not in operands:
            raise ValueError(f"{op} requires 'left' and 'right' fields")

        left = Rank.from_dict(operands["left"])
        right = Rank.from_dict(operands["right"])
        return left - right if op == "$sub" else left / right

    @staticmethod
    def _unary_from_dict(op: str, child_data: Any) -> "Rank":
        """Parse a single nested rank payload for ``$abs`` / ``$exp`` / ``$log``."""
        if not isinstance(child_data, dict):
            raise TypeError(
                f"{op} requires a rank dict, got {type(child_data).__name__}"
            )
        child = Rank.from_dict(child_data)
        if op == "$abs":
            return abs(child)
        if op == "$exp":
            return child.exp()
        return child.log()

    @staticmethod
    def _as_rank(operand: Any) -> Any:
        """Wrap a plain int/float in Val; return anything else unchanged."""
        return Val(operand) if isinstance(operand, (int, float)) else operand

    def _combine_flat(self, other: Any, node_type: Any) -> Any:
        """Build ``node_type`` over self and other, flattening same-kind operands.

        If either operand is already a ``node_type`` node its ``ranks`` list is
        spliced in instead of nesting the node.
        """
        other_rank = Rank._as_rank(other)
        left_terms = self.ranks if isinstance(self, node_type) else [self]
        right_terms = (
            other_rank.ranks if isinstance(other_rank, node_type) else [other_rank]
        )
        return node_type(left_terms + right_terms)

    # Arithmetic operators
    def __add__(self, other: Union["Rank", float, int]) -> "Sum":
        """Addition: rank1 + rank2 or rank + value"""
        return self._combine_flat(other, Sum)

    def __radd__(self, other: Union[float, int]) -> "Sum":
        """Right addition: value + rank"""
        return Val(other) + self

    def __sub__(self, other: Union["Rank", float, int]) -> "Sub":
        """Subtraction: rank1 - rank2 or rank - value"""
        return Sub(self, Rank._as_rank(other))

    def __rsub__(self, other: Union[float, int]) -> "Sub":
        """Right subtraction: value - rank"""
        return Sub(Val(other), self)

    def __mul__(self, other: Union["Rank", float, int]) -> "Mul":
        """Multiplication: rank1 * rank2 or rank * value"""
        return self._combine_flat(other, Mul)

    def __rmul__(self, other: Union[float, int]) -> "Mul":
        """Right multiplication: value * rank"""
        return Val(other) * self

    def __truediv__(self, other: Union["Rank", float, int]) -> "Div":
        """Division: rank1 / rank2 or rank / value"""
        return Div(self, Rank._as_rank(other))

    def __rtruediv__(self, other: Union[float, int]) -> "Div":
        """Right division: value / rank"""
        return Div(Val(other), self)

    def __neg__(self) -> "Mul":
        """Negation: -rank (equivalent to -1 * rank)"""
        return Mul([Val(-1), self])

    def __abs__(self) -> "Abs":
        """Absolute value: abs(rank)"""
        return Abs(self)

    def abs(self) -> "Abs":
        """Absolute value builder: rank.abs()"""
        return Abs(self)

    # Builder methods for functions
    def exp(self) -> "Exp":
        """Exponential: e^rank"""
        return Exp(self)

    def log(self) -> "Log":
        """Natural logarithm: ln(rank)"""
        return Log(self)

    def max(self, other: Union["Rank", float, int]) -> "Max":
        """Maximum of this rank and another: rank.max(rank2)"""
        return self._combine_flat(other, Max)

    def min(self, other: Union["Rank", float, int]) -> "Min":
        """Minimum of this rank and another: rank.min(rank2)"""
        return self._combine_flat(other, Min)

@dataclass
class Abs(Rank):
    """Rank node: absolute value of the wrapped rank."""

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$abs": <rank dict>}``."""
        inner = self.rank.to_dict()
        return {"$abs": inner}


@dataclass
class Div(Rank):
    """Rank expression for ``left / right``.

    Serialized as ``{"$div": {"left": <left>, "right": <right>}}``.
    """

    left: Rank
    right: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$div`` dictionary holding both serialized operands."""
        operands = {"left": self.left.to_dict(), "right": self.right.to_dict()}
        return {"$div": operands}


@dataclass
class Exp(Rank):
    """Rank expression for the exponentiation of a wrapped rank.

    Serialized as ``{"$exp": <operand>}``.
    """

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$exp`` dictionary wrapping the serialized operand."""
        operand = self.rank.to_dict()
        return {"$exp": operand}


@dataclass
class Log(Rank):
    """Rank expression for the logarithm of a wrapped rank.

    Serialized as ``{"$log": <operand>}``.
    """

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$log`` dictionary wrapping the serialized operand."""
        operand = self.rank.to_dict()
        return {"$log": operand}


@dataclass
class Max(Rank):
    """Rank expression for the maximum over a list of ranks.

    Serialized as ``{"$max": [<operand>, ...]}`` in list order.
    """

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$max`` dictionary listing every serialized operand."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$max": operands}


@dataclass
class Min(Rank):
    """Rank expression for the minimum over a list of ranks.

    Serialized as ``{"$min": [<operand>, ...]}`` in list order.
    """

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$min`` dictionary listing every serialized operand."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$min": operands}


@dataclass
class Mul(Rank):
    """Rank expression for the product of a list of ranks.

    Serialized as ``{"$mul": [<operand>, ...]}`` in list order.
    """

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$mul`` dictionary listing every serialized operand."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$mul": operands}


@dataclass
class Knn(Rank):
    """Rank expression that scores records with a K-nearest-neighbour search.

    Args:
        query: What to search for. One of:
               - A string (will be automatically embedded using the collection's embedding function)
               - A dense vector (list or numpy array)
               - A sparse vector (SparseVector instance)
        key: The embedding key to search against. One of:
             - "#embedding" (default) - searches the main embedding field
             - A metadata field name (e.g., "my_custom_field") - searches that metadata field
        limit: Maximum number of results to consider (default: 16)
        default: Default score for records not in KNN results (default: None)
        return_rank: If True, return the rank position (0, 1, 2, ...) instead of distance (default: False)

    Examples:
        # String query (automatically embedded)
        Knn(query="hello world")  # Will use collection's embedding function

        # Vector query against the main embeddings (equivalent forms)
        Knn(query=[0.1, 0.2])  # Uses default key="#embedding"
        Knn(query=[0.1, 0.2], key=K.EMBEDDING)
        Knn(query=[0.1, 0.2], key="#embedding")

        # String query against sparse embeddings stored in metadata
        Knn(query="hello world", key="custom_embedding")  # Will use schema's embedding function

        # Vector query against sparse embeddings stored in metadata
        Knn(query=my_vector, key="custom_embedding")  # Example: searches a metadata field
    """

    query: Union[
        str,
        List[float],
        SparseVector,
        "NDArray[np.float32]",
        "NDArray[np.float64]",
        "NDArray[np.int32]",
    ]
    key: str = "#embedding"
    limit: int = 16
    default: Optional[float] = None
    return_rank: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$knn`` dictionary for this expression.

        ``query``, ``key`` and ``limit`` are always present; ``default`` and
        ``return_rank`` are emitted only when they differ from their defaults.
        """
        # Bring the query into its transport form.
        transport_query: Any
        if isinstance(self.query, SparseVector):
            # SparseVector instances serialize themselves to a dict
            transport_query = self.query.to_dict()
        elif isinstance(self.query, np.ndarray):
            # numpy arrays become plain lists
            transport_query = self.query.tolist()
        else:
            transport_query = self.query

        payload: Dict[str, Any] = {
            "query": transport_query,
            "key": self.key,
            "limit": self.limit,
        }

        # Optional fields are left out at their default values to keep the JSON clean.
        if self.default is not None:
            payload["default"] = self.default
        if self.return_rank:
            payload["return_rank"] = self.return_rank

        return {"$knn": payload}


@dataclass
class Sub(Rank):
    """Rank expression for ``left - right``.

    Serialized as ``{"$sub": {"left": <left>, "right": <right>}}``.
    """

    left: Rank
    right: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$sub`` dictionary holding both serialized operands."""
        operands = {"left": self.left.to_dict(), "right": self.right.to_dict()}
        return {"$sub": operands}


@dataclass
class Sum(Rank):
    """Rank expression for the sum of a list of ranks.

    Serialized as ``{"$sum": [<operand>, ...]}`` in list order.
    """

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$sum`` dictionary listing every serialized operand."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$sum": operands}


@dataclass
class Val(Rank):
    """Rank expression holding a constant value.

    Serialized as ``{"$val": <value>}``.
    """

    value: float

    def to_dict(self) -> Dict[str, Any]:
        """Return the ``$val`` dictionary carrying the stored constant as-is."""
        constant = self.value
        return {"$val": constant}


@dataclass
class Rrf(Rank):
    """Reciprocal Rank Fusion for combining multiple ranking strategies.

    RRF formula: score = -sum(weight_i / (k + rank_i)) for each ranking strategy
    The negative is used because RRF produces higher scores for better results,
    but Chroma uses ascending order (lower scores = better results).

    Parameters are validated when ``to_dict()`` is called, not at construction.

    Args:
        ranks: List of Rank expressions to fuse (must have at least one)
        k: Smoothing constant (default: 60, standard in literature)
        weights: Optional weights for each ranking strategy. If not provided,
                all ranks are weighted equally (weight=1.0 each).
        normalize: If True, normalize weights to sum to 1.0 (default: False).
                  When False, weights are used as-is for relative importance.
                  When True, weights are scaled so they sum to 1.0.

    Examples:
        # Note: metadata fields (like "custom_embedding" below) are user-defined and can store any data.
        # The field name is just an example - use whatever name matches your metadata structure.
        # Basic RRF combining KNN rankings (equal weight)
        Rrf([
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
        ])

        # Weighted RRF with relative weights (not normalized)
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
            ],
            weights=[2.0, 1.0],  # First ranking is 2x more important
            k=100
        )

        # Weighted RRF with normalized weights
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
            ],
            weights=[3.0, 1.0],  # Will be normalized to [0.75, 0.25]
            normalize=True,
            k=100
        )
    """

    ranks: List[Rank]
    k: int = 60
    weights: Optional[List[float]] = None
    normalize: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert RRF to a composition of existing expression operators.

        Builds ``-sum(weight_i / (k + rank_i))`` over all ranks using the
        overloaded arithmetic operators on ``Rank`` and returns the dictionary
        of the resulting expression.

        Returns:
            The dictionary produced by ``to_dict()`` on the negated sum.

        Raises:
            ValueError: If ``ranks`` is empty, ``k`` is not positive, the number
                of weights differs from the number of ranks, any weight is
                negative, or ``normalize`` is True and the weights sum to zero.
        """
        # Validate RRF parameters
        if not self.ranks:
            raise ValueError("RRF requires at least one rank")
        if self.k <= 0:
            raise ValueError(f"k must be positive, got {self.k}")

        # Validate explicit weights, or fall back to equal weighting. The choice
        # is made on "is None" so that it matches the validation branch exactly
        # rather than depending on the truthiness of the weights container.
        if self.weights is not None:
            if len(self.weights) != len(self.ranks):
                raise ValueError(
                    f"Number of weights ({len(self.weights)}) must match number of ranks ({len(self.ranks)})"
                )
            if any(w < 0.0 for w in self.weights):
                raise ValueError("All weights must be non-negative")
            weights = self.weights
        else:
            weights = [1.0] * len(self.ranks)

        # Normalize weights if requested
        if self.normalize:
            weight_sum = sum(weights)
            if weight_sum == 0:
                raise ValueError("Sum of weights must be positive when normalize=True")
            weights = [w / weight_sum for w in weights]

        # Pair each weight with its rank and build the term weight / (k + rank)
        terms = [w / (self.k + rank) for w, rank in zip(weights, self.ranks)]

        # Fold the terms with "+" starting from the first one (there is always
        # at least one); the built-in sum() would start from the integer 0.
        rrf_sum: Rank = terms[0]
        for term in terms[1:]:
            rrf_sum = rrf_sum + term

        # Negate (RRF gives higher scores for better, Chroma needs lower for better)
        return (-rrf_sum).to_dict()


@dataclass
class Select:
    """Which fields a search should return for each result.

    Each entry in ``keys`` can be:
    - Key.DOCUMENT - Select document key (equivalent to Key("#document"))
    - Key.EMBEDDING - Select embedding key (equivalent to Key("#embedding"))
    - Key.SCORE - Select score key (equivalent to Key("#score"))
    - Any other string - Select specific metadata property

    Note: You can use K as an alias for Key for more concise code.

    Examples:
        # Predefined keys via the K alias (K is shorthand for Key)
        from chromadb.execution.expression import K
        Select(keys={K.DOCUMENT, K.SCORE})

        # Specific metadata properties
        Select(keys={"title", "author", "date"})

        # Mixed selection
        Select(keys={K.DOCUMENT, "title", "author"})
    """

    keys: Set[Union[Key, str]] = field(default_factory=set)

    def to_dict(self) -> Dict[str, Any]:
        """Return ``{"keys": [...]}`` with every key as a string.

        ``Key`` entries contribute their ``name``; plain strings are used as
        given. A name that occurs more than once is emitted only once.
        """
        names = [k.name if isinstance(k, Key) else k for k in self.keys]
        # dict.fromkeys drops repeated names while keeping first-seen order
        return {"keys": list(dict.fromkeys(names))}

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Select":
        """Build a Select from its dictionary form.

        Examples:
        - {"keys": ["#document", "#score"]} -> Select(keys={Key.DOCUMENT, Key.SCORE})
        - {"keys": ["title", "author"]} -> Select(keys={Key("title"), Key("author")})

        Raises:
            TypeError: If ``data`` is not a dict, ``"keys"`` is not a
                list/tuple/set, or any individual key is not a string.
            ValueError: If ``data`` contains entries other than ``"keys"``.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Select, got {type(data).__name__}")

        raw_keys = data.get("keys", [])
        if not isinstance(raw_keys, (list, tuple, set)):
            raise TypeError(
                f"Select keys must be a list/tuple/set, got {type(raw_keys).__name__}"
            )

        # Special "#..." names map onto the predefined Key instances
        predefined = {
            "#id": Key.ID,
            "#document": Key.DOCUMENT,
            "#embedding": Key.EMBEDDING,
            "#metadata": Key.METADATA,
            "#score": Key.SCORE,
        }

        parsed = []
        for raw in raw_keys:
            if not isinstance(raw, str):
                raise TypeError(f"Select key must be a string, got {type(raw).__name__}")
            # Anything that is not a special name is a regular metadata field
            parsed.append(predefined[raw] if raw in predefined else Key(raw))

        # Reject any dictionary entry besides "keys"
        unexpected_keys = set(data.keys()) - {"keys"}
        if unexpected_keys:
            raise ValueError(f"Unexpected keys in Select dict: {unexpected_keys}")

        # Collect into a set; the Key instances themselves are kept
        return Select(keys=set(parsed))
