import functools
from typing import (
    TYPE_CHECKING,
    Callable,
    Dict,
    Generic,
    Optional,
    Any,
    Set,
    TypeVar,
    Union,
    cast,
    List,
)
from chromadb.types import Metadata
import numpy as np
from uuid import UUID

from chromadb.api.types import (
    URI,
    Schema,
    SparseVectorIndexConfig,
    URIs,
    AddRequest,
    BaseRecordSet,
    CollectionMetadata,
    DataLoader,
    DeleteRequest,
    Embedding,
    Embeddings,
    FilterSet,
    GetRequest,
    PyEmbedding,
    Embeddable,
    GetResult,
    Include,
    Loadable,
    Document,
    Image,
    QueryRequest,
    QueryResult,
    IDs,
    EmbeddingFunction,
    SparseEmbeddingFunction,
    ID,
    OneOrMany,
    UpdateRequest,
    UpsertRequest,
    get_default_embeddable_record_set_fields,
    maybe_cast_one_to_many,
    normalize_base_record_set,
    normalize_insert_record_set,
    validate_base_record_set,
    validate_ids,
    validate_include,
    validate_insert_record_set,
    validate_metadata,
    validate_metadatas,
    validate_embedding_function,
    validate_sparse_embedding_function,
    validate_n_results,
    validate_record_set_contains_any,
    validate_record_set_for_embedding,
    validate_filter_set,
    DefaultEmbeddingFunction,
    EMBEDDING_KEY,
    DOCUMENT_KEY,
)
from chromadb.api.collection_configuration import (
    UpdateCollectionConfiguration,
    overwrite_collection_configuration,
    load_collection_configuration_from_json,
    CollectionConfiguration,
)

# TODO: We should rename the types in chromadb.types to be Models where
# appropriate. This will help to distinguish between manipulation objects,
# which are essentially API views, and the actual data models, which are
# stored / retrieved / transmitted.
from chromadb.types import Collection as CollectionModel, Where, WhereDocument
import logging

logger = logging.getLogger(name=__name__)

if TYPE_CHECKING:
    # Imported for the type checker only; ClientT below refers to these
    # classes through string forward references.
    from chromadb.api import AsyncServerAPI, ServerAPI

# The client a collection is bound to: either the sync or the async server API.
ClientT = TypeVar("ClientT", "ServerAPI", "AsyncServerAPI")

# Generic return type used by validation_context.
T = TypeVar("T")


def validation_context(name: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """A decorator that wraps a method with a try-except block that catches
    exceptions and adds the method name to the error message. This allows us to
    provide more context when an error occurs, without rewriting validators.

    The caught exception object is re-raised unchanged in type. Its first
    argument is replaced by ``"<message> in <name>."`` and any further args
    are kept as-is. When the exception carries no message at all, the first
    argument becomes ``"Error in <name>."``.

    Args:
        name: Operation name appended to error messages, e.g. ``"add"``.

    Returns:
        A decorator for methods (the wrapper passes ``self`` through).
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
            try:
                return func(self, *args, **kwargs)
            except Exception as e:
                # With several args, str(e) renders the whole args tuple, which
                # would duplicate the trailing args kept below; use only the
                # first arg as the message in that case.
                detail = str(e.args[0]) if len(e.args) > 1 else str(e)
                msg = f"{detail} in {name}." if detail else f"Error in {name}."
                e.args = (msg,) + e.args[1:]
                # Re-raise the same exception object with the augmented message.
                raise

        return wrapper

    return decorator


class CollectionCommon(Generic[ClientT]):
    """Shared state and request/response plumbing for collection handles.

    Wraps a ``CollectionModel`` and the client it was obtained from. It also
    keeps the optional dense embedding function and data loader, which are used
    when records are supplied without precomputed embeddings.

    The ``_validate_and_prepare_*`` methods turn user arguments into request
    objects, and the ``_transform_*`` methods post-process results.
    NOTE(review): the concrete client calls are presumably made by subclasses
    (sync / async); confirm.

    Because ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    _model: CollectionModel
    _client: ClientT
    _embedding_function: Optional[EmbeddingFunction[Embeddable]]
    _data_loader: Optional[DataLoader[Loadable]]

    def __init__(
        self,
        client: ClientT,
        model: CollectionModel,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> None:
        """Initializes a new instance of the Collection class.

        Args:
            client: Client used to issue requests for this collection.
            model: Collection model backing the read-only properties.
            embedding_function: Dense embedding function, validated when not
                None. The default ``DefaultEmbeddingFunction()`` instance is
                created once, when this method is defined, and shared by every
                collection constructed without an explicit argument.
            data_loader: Optional loader that turns URIs into data.
        """
        self._client = client
        self._model = model

        # Check to make sure the embedding function has the right signature,
        # as defined by the EmbeddingFunction protocol.
        if embedding_function is not None:
            validate_embedding_function(embedding_function)

        self._embedding_function = embedding_function
        self._data_loader = data_loader

    # Expose the model properties as read-only properties on the Collection class

    @property
    def id(self) -> UUID:
        """The collection's UUID."""
        return self._model.id

    @property
    def name(self) -> str:
        """The collection's name."""
        return self._model.name

    @property
    def configuration(self) -> CollectionConfiguration:
        """The configuration, loaded from the model's JSON on every access."""
        return load_collection_configuration_from_json(self._model.configuration_json)

    @property
    def configuration_json(self) -> Dict[str, Any]:
        """The raw configuration JSON stored on the model."""
        return self._model.configuration_json

    @property
    def schema(self) -> Optional[Schema]:
        """The schema, deserialized on every access (``{}`` if none is stored)."""
        return Schema.deserialize_from_json(
            self._model.serialized_schema if self._model.serialized_schema else {}
        )

    @property
    def metadata(self) -> CollectionMetadata:
        """The collection-level metadata."""
        return cast(CollectionMetadata, self._model.metadata)

    @property
    def tenant(self) -> str:
        """The tenant the collection belongs to."""
        return self._model.tenant

    @property
    def database(self) -> str:
        """The database the collection belongs to."""
        return self._model.database

    def __eq__(self, other: object) -> bool:
        """Compare model fields plus the attached embedding function and loader."""
        if not isinstance(other, CollectionCommon):
            return False
        # Cheap fields first so a mismatch short-circuits before the schema
        # is deserialized on both sides.
        return (
            self.id == other.id
            and self.name == other.name
            and self.tenant == other.tenant
            and self.database == other.database
            and self.configuration_json == other.configuration_json
            and self.metadata == other.metadata
            and self.schema == other.schema
            and self._embedding_function == other._embedding_function
            and self._data_loader == other._data_loader
        )

    def __repr__(self) -> str:
        return f"Collection(name={self.name})"

    def get_model(self) -> CollectionModel:
        """Return the underlying collection model (not a copy)."""
        return self._model

    @validation_context("add")
    def _validate_and_prepare_add_request(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ],
        metadatas: Optional[OneOrMany[Metadata]],
        documents: Optional[OneOrMany[Document]],
        images: Optional[OneOrMany[Image]],
        uris: Optional[OneOrMany[URI]],
    ) -> AddRequest:
        """Normalize, validate and prepare an add request.

        Embeddings are computed from the embeddable fields when none are given.
        Sparse embeddings configured in the schema are written into metadatas.
        """
        # Unpack
        add_records = normalize_insert_record_set(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            images=images,
            uris=uris,
        )

        # Validate
        validate_insert_record_set(record_set=add_records)
        validate_record_set_contains_any(record_set=add_records, contains_any={"ids"})

        # Prepare
        if add_records["embeddings"] is None:
            validate_record_set_for_embedding(record_set=add_records)
            add_embeddings = self._embed_record_set(record_set=add_records)
        else:
            add_embeddings = add_records["embeddings"]

        add_metadatas = self._apply_sparse_embeddings_to_metadatas(
            add_records["metadatas"], add_records["documents"]
        )

        return AddRequest(
            ids=add_records["ids"],
            embeddings=add_embeddings,
            metadatas=add_metadatas,
            documents=add_records["documents"],
            uris=add_records["uris"],
        )

    @validation_context("get")
    def _validate_and_prepare_get_request(
        self,
        ids: Optional[OneOrMany[ID]],
        where: Optional[Where],
        where_document: Optional[WhereDocument],
        include: Include,
    ) -> GetRequest:
        """Normalize and validate a get request.

        Raises:
            ValueError: If ``"data"`` is requested but no data loader is set
                (validators may raise their own errors as well).
        """
        # Unpack
        unpacked_ids: Optional[IDs] = maybe_cast_one_to_many(target=ids)
        filters = FilterSet(where=where, where_document=where_document)

        # Validate
        if unpacked_ids is not None:
            validate_ids(ids=unpacked_ids)

        validate_filter_set(filter_set=filters)
        validate_include(include=include, dissalowed=["distances"])

        if "data" in include and self._data_loader is None:
            raise ValueError(
                "You must set a data loader on the collection if loading from URIs."
            )

        # Prepare
        # Copy so the caller's list (possibly a shared default argument) is
        # never mutated. We need uris in the API result to load data.
        request_include: Include = list(include)
        if "data" in request_include and "uris" not in request_include:
            request_include.append("uris")

        return GetRequest(
            ids=unpacked_ids,
            where=filters["where"],
            where_document=filters["where_document"],
            include=request_include,
        )

    @validation_context("query")
    def _validate_and_prepare_query_request(
        self,
        query_embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ],
        query_texts: Optional[OneOrMany[Document]],
        query_images: Optional[OneOrMany[Image]],
        query_uris: Optional[OneOrMany[URI]],
        ids: Optional[OneOrMany[ID]],
        n_results: int,
        where: Optional[Where],
        where_document: Optional[WhereDocument],
        include: Include,
    ) -> QueryRequest:
        """Normalize, validate and prepare a query request.

        Query embeddings are computed (with ``is_query=True``) from texts,
        images or uris when none are given.
        """
        # Unpack
        query_records = normalize_base_record_set(
            embeddings=query_embeddings,
            documents=query_texts,
            images=query_images,
            uris=query_uris,
        )

        filter_ids = maybe_cast_one_to_many(ids)

        filters = FilterSet(
            where=where,
            where_document=where_document,
        )

        # Validate
        validate_base_record_set(record_set=query_records)
        validate_filter_set(filter_set=filters)
        validate_include(include=include)
        validate_n_results(n_results=n_results)

        # Prepare
        if query_records["embeddings"] is None:
            validate_record_set_for_embedding(record_set=query_records)
            request_embeddings = self._embed_record_set(
                record_set=query_records, is_query=True
            )
        else:
            request_embeddings = query_records["embeddings"]

        request_where = filters["where"]
        request_where_document = filters["where_document"]

        # Copy so the caller's list is never mutated. We need uris in the
        # API result to load data.
        request_include: Include = list(include)
        if "data" in request_include and "uris" not in request_include:
            request_include.append("uris")

        return QueryRequest(
            embeddings=request_embeddings,
            ids=filter_ids,
            where=request_where,
            where_document=request_where_document,
            include=request_include,
            n_results=n_results,
        )

    @validation_context("update")
    def _validate_and_prepare_update_request(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ],
        metadatas: Optional[OneOrMany[Metadata]],
        documents: Optional[OneOrMany[Document]],
        images: Optional[OneOrMany[Image]],
        uris: Optional[OneOrMany[URI]],
    ) -> UpdateRequest:
        """Normalize, validate and prepare an update request.

        Without explicit embeddings, embeddings are recomputed only when
        documents or images change. URI-only updates are not re-embedded.
        """
        # Unpack
        update_records = normalize_insert_record_set(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            images=images,
            uris=uris,
        )

        # Validate
        validate_insert_record_set(record_set=update_records)

        # Prepare
        if update_records["embeddings"] is None:
            # TODO: Handle URI updates.
            if (
                update_records["documents"] is not None
                or update_records["images"] is not None
            ):
                validate_record_set_for_embedding(
                    update_records, embeddable_fields={"documents", "images"}
                )
                update_embeddings = self._embed_record_set(record_set=update_records)
            else:
                update_embeddings = None
        else:
            update_embeddings = update_records["embeddings"]

        update_metadatas = self._apply_sparse_embeddings_to_metadatas(
            update_records["metadatas"], update_records["documents"]
        )

        return UpdateRequest(
            ids=update_records["ids"],
            embeddings=update_embeddings,
            metadatas=update_metadatas,
            documents=update_records["documents"],
            uris=update_records["uris"],
        )

    @validation_context("upsert")
    def _validate_and_prepare_upsert_request(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> UpsertRequest:
        """Normalize, validate and prepare an upsert request.

        Without explicit embeddings, documents or images must be present so
        embeddings can be computed.
        """
        # Unpack
        upsert_records = normalize_insert_record_set(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            images=images,
            uris=uris,
        )

        # Validate
        validate_insert_record_set(record_set=upsert_records)

        # Prepare
        if upsert_records["embeddings"] is None:
            validate_record_set_for_embedding(
                record_set=upsert_records, embeddable_fields={"documents", "images"}
            )
            upsert_embeddings = self._embed_record_set(record_set=upsert_records)
        else:
            upsert_embeddings = upsert_records["embeddings"]

        upsert_metadatas = self._apply_sparse_embeddings_to_metadatas(
            upsert_records["metadatas"], upsert_records["documents"]
        )

        return UpsertRequest(
            ids=upsert_records["ids"],
            metadatas=upsert_metadatas,
            embeddings=upsert_embeddings,
            documents=upsert_records["documents"],
            uris=upsert_records["uris"],
        )

    @validation_context("delete")
    def _validate_and_prepare_delete_request(
        self,
        ids: Optional[IDs],
        where: Optional[Where],
        where_document: Optional[WhereDocument],
    ) -> DeleteRequest:
        """Validate and prepare a delete request.

        Raises:
            ValueError: If none of ``ids``, ``where`` or ``where_document`` is
                given (validators may raise their own errors as well).
        """
        if ids is None and where is None and where_document is None:
            raise ValueError(
                "At least one of ids, where, or where_document must be provided"
            )

        # Unpack
        if ids is not None:
            request_ids = cast(IDs, maybe_cast_one_to_many(ids))
        else:
            request_ids = None
        filters = FilterSet(where=where, where_document=where_document)

        # Validate
        if request_ids is not None:
            validate_ids(ids=request_ids)
        validate_filter_set(filter_set=filters)

        return DeleteRequest(
            ids=request_ids,
            where=filters["where"],
            where_document=filters["where_document"],
        )

    def _transform_peek_response(self, response: GetResult) -> GetResult:
        """Convert peeked embeddings to a numpy array (mutates ``response``)."""
        if response["embeddings"] is not None:
            response["embeddings"] = np.array(response["embeddings"])

        return response

    def _transform_get_response(
        self, response: GetResult, include: Include
    ) -> GetResult:
        """Post-process a get result in place and return it.

        Loads data from uris when requested, converts embeddings to a numpy
        array, and drops uris that the caller did not ask for.
        """
        if (
            "data" in include
            and self._data_loader is not None
            and response["uris"] is not None
        ):
            response["data"] = self._data_loader(response["uris"])

        # Guard None like _transform_query_response: np.array(None) would yield
        # a meaningless 0-d object array.
        if "embeddings" in include and response["embeddings"] is not None:
            response["embeddings"] = np.array(response["embeddings"])

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            response["uris"] = None

        return response

    def _transform_query_response(
        self, response: QueryResult, include: Include
    ) -> QueryResult:
        """Post-process a query result in place and return it.

        Same as the get variant, but applied per query: one data load and one
        numpy array per entry of the outer list.
        """
        if (
            "data" in include
            and self._data_loader is not None
            and response["uris"] is not None
        ):
            response["data"] = [self._data_loader(uris) for uris in response["uris"]]

        if "embeddings" in include and response["embeddings"] is not None:
            response["embeddings"] = [
                np.array(embedding) for embedding in response["embeddings"]
            ]

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            response["uris"] = None

        return response

    def _validate_modify_request(self, metadata: Optional[CollectionMetadata]) -> None:
        """Validate new collection metadata before a modify call.

        Raises:
            ValueError: If the metadata tries to change ``hnsw:space``.
        """
        if metadata is not None:
            validate_metadata(metadata)
            if "hnsw:space" in metadata:
                raise ValueError(
                    "Changing the distance function of a collection once it is created is not supported currently."
                )

    def _update_model_after_modify_success(
        self,
        name: Optional[str],
        metadata: Optional[CollectionMetadata],
        configuration: Optional[UpdateCollectionConfiguration],
    ) -> None:
        """Mirror a successful modify into the local model.

        NOTE(review): falsy values (e.g. an empty metadata dict) are not
        applied locally; confirm this matches server-side semantics.
        """
        if name:
            self._model["name"] = name
        if metadata:
            self._model["metadata"] = metadata
        if configuration:
            self._model.set_configuration(
                overwrite_collection_configuration(
                    self._model.get_configuration(), configuration
                )
            )

            # If schema exists, also update it with the configuration changes.
            # Deserialize once and reuse.
            schema = self.schema
            if schema:
                from chromadb.api.collection_configuration import (
                    update_schema_from_collection_configuration,
                )

                updated_schema = update_schema_from_collection_configuration(
                    schema, configuration
                )
                self._model["serialized_schema"] = updated_schema.serialize_to_json()

    def _get_sparse_embedding_targets(self) -> Dict[str, "SparseVectorIndexConfig"]:
        """Map schema keys to sparse index configs that can auto-embed.

        A key qualifies when its sparse vector index is present and enabled
        and has both an embedding function and a source key configured.
        """
        schema = self.schema
        if schema is None:
            return {}

        targets: Dict[str, "SparseVectorIndexConfig"] = {}
        for key, value_types in schema.keys.items():
            if value_types.sparse_vector is None:
                continue
            sparse_index = value_types.sparse_vector.sparse_vector_index
            if sparse_index is None or not sparse_index.enabled:
                continue
            config = sparse_index.config
            if config.embedding_function is None or config.source_key is None:
                continue
            targets[key] = config

        return targets

    def _apply_sparse_embeddings_to_metadatas(
        self,
        metadatas: Optional[List[Metadata]],
        documents: Optional[List[Document]] = None,
    ) -> Optional[List[Metadata]]:
        """Fill sparse-vector metadata keys from their configured sources.

        For each target from ``_get_sparse_embedding_targets``, records that do
        not already carry the target key are embedded and the result is stored
        under the target key. The input is either the record's document (source
        key ``DOCUMENT_KEY``) or a string metadata value (any other source key).
        Non-string sources are skipped.

        Args:
            metadatas: Per-record metadata, or None if none was supplied.
            documents: Per-record documents. Used for ``DOCUMENT_KEY`` sources
                and to size the metadata list when ``metadatas`` is None.

        Returns:
            One of:

            - ``metadatas`` unchanged when there are no sparse targets.
            - None when no metadatas were supplied and nothing was embedded.
            - Otherwise, a new list in which empty dicts are replaced by None.

        Raises:
            ValueError: If a sparse embedding function returns a different
                number of embeddings than inputs (validators may also raise).
        """
        sparse_targets = self._get_sparse_embedding_targets()
        if not sparse_targets:
            return metadatas

        metadatas_supplied = metadatas is not None
        if metadatas is None:
            if documents is None:
                return None
            # One empty metadata per document so embeddings have a home.
            metadatas = [{} for _ in range(len(documents))]

        # Work on copies so the caller's dicts are not mutated; None -> {}.
        updated_metadatas: List[Dict[str, Any]] = [
            dict(metadata) if metadata is not None else {} for metadata in metadatas
        ]

        documents_list = list(documents) if documents is not None else None
        embedded_any = False

        for target_key, config in sparse_targets.items():
            source_key = config.source_key
            if source_key is None or config.embedding_function is None:
                continue

            # cast() is a runtime no-op; it only informs the type checker.
            embedding_func = cast(
                SparseEmbeddingFunction[Any], config.embedding_function
            )
            validate_sparse_embedding_function(embedding_func)

            # Collect inputs and their record positions for one batched call.
            inputs: List[str] = []
            positions: List[int] = []

            if source_key == DOCUMENT_KEY:
                if documents_list is None:
                    continue
                for idx, metadata in enumerate(updated_metadatas):
                    # Never overwrite a value the caller already provided.
                    if target_key in metadata or idx >= len(documents_list):
                        continue
                    doc = documents_list[idx]
                    if isinstance(doc, str):
                        inputs.append(doc)
                        positions.append(idx)
            else:
                for idx, metadata in enumerate(updated_metadatas):
                    if target_key in metadata:
                        continue
                    source_value = metadata.get(source_key)
                    if isinstance(source_value, str):
                        inputs.append(source_value)
                        positions.append(idx)

            if not inputs:
                continue

            sparse_embeddings = self._sparse_embed(
                input=inputs,
                sparse_embedding_function=embedding_func,
            )

            if len(sparse_embeddings) != len(positions):
                raise ValueError(
                    "Sparse embedding function returned unexpected number of embeddings."
                )

            for position, embedding in zip(positions, sparse_embeddings):
                updated_metadatas[position][target_key] = embedding
            embedded_any = True

        # Preserve "no metadatas supplied" when nothing was added, matching the
        # no-sparse-target path above.
        if not metadatas_supplied and not embedded_any:
            return None

        # Convert empty dicts back to None, validation requires non-empty dicts or None
        result_metadatas: List[Optional[Metadata]] = [
            metadata if metadata else None for metadata in updated_metadatas
        ]

        validate_metadatas(cast(List[Metadata], result_metadatas))
        return cast(List[Metadata], result_metadatas)

    def _embed_record_set(
        self,
        record_set: BaseRecordSet,
        embeddable_fields: Optional[Set[str]] = None,
        is_query: bool = False,
    ) -> Embeddings:
        """Embed the first non-None embeddable field of ``record_set``.

        Uris are first turned into data via the data loader.

        Raises:
            ValueError: If uris must be loaded but no data loader is set, or
                if no embeddable field is populated.
        """
        if embeddable_fields is None:
            embeddable_fields = get_default_embeddable_record_set_fields()

        for field in embeddable_fields:
            if record_set[field] is not None:  # type: ignore[literal-required]
                # uris require special handling
                if field == "uris":
                    if self._data_loader is None:
                        raise ValueError(
                            "You must set a data loader on the collection if loading from URIs."
                        )
                    return self._embed(
                        input=self._data_loader(uris=cast(URIs, record_set[field])),  # type: ignore[literal-required]
                        is_query=is_query,
                    )
                else:
                    return self._embed(
                        input=record_set[field],  # type: ignore[literal-required]
                        is_query=is_query,
                    )
        raise ValueError(
            "Record does not contain any non-None fields that can be embedded. "
            f"Embeddable Fields: {embeddable_fields}. "
            f"Record Fields: {record_set}"
        )

    def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
        """Compute dense embeddings, picking the embedding function by priority.

        Priority order:

        1. An explicit non-default ``self._embedding_function``.
        2. The configuration's embedding function.
        3. The schema's ``EMBEDDING_KEY`` override, then the schema defaults.
        4. ``self._embedding_function`` (typically the default function).

        ``embed_query`` is used when ``is_query`` is set.

        Raises:
            ValueError: If no embedding function is available at all.
        """
        if self._embedding_function is not None and not isinstance(
            self._embedding_function, DefaultEmbeddingFunction
        ):
            if is_query:
                return self._embedding_function.embed_query(input=input)
            else:
                return self._embedding_function(input=input)

        config_ef = self.configuration.get("embedding_function")
        if config_ef is not None:
            if is_query:
                return config_ef.embed_query(input=input)
            else:
                return config_ef(input=input)
        schema = self.schema
        schema_embedding_function: Optional[EmbeddingFunction[Embeddable]] = None
        if schema is not None:
            override = schema.keys.get(EMBEDDING_KEY)
            if (
                override is not None
                and override.float_list is not None
                and override.float_list.vector_index is not None
                and override.float_list.vector_index.config.embedding_function
                is not None
            ):
                schema_embedding_function = cast(
                    EmbeddingFunction[Embeddable],
                    override.float_list.vector_index.config.embedding_function,
                )
            elif (
                schema.defaults.float_list is not None
                and schema.defaults.float_list.vector_index is not None
                and schema.defaults.float_list.vector_index.config.embedding_function
                is not None
            ):
                schema_embedding_function = cast(
                    EmbeddingFunction[Embeddable],
                    schema.defaults.float_list.vector_index.config.embedding_function,
                )

        if schema_embedding_function is not None:
            if is_query and hasattr(schema_embedding_function, "embed_query"):
                return schema_embedding_function.embed_query(input=input)
            return schema_embedding_function(input=input)
        if self._embedding_function is None:
            raise ValueError(
                "You must provide an embedding function to compute embeddings. "
                "https://docs.trychroma.com/guides/embeddings"
            )
        if is_query:
            return self._embedding_function.embed_query(input=input)
        else:
            return self._embedding_function(input=input)

    def _sparse_embed(
        self,
        input: Any,
        sparse_embedding_function: SparseEmbeddingFunction[Any],
        is_query: bool = False,
    ) -> Any:
        """Call the sparse embedding function, using ``embed_query`` for queries."""
        if is_query:
            return sparse_embedding_function.embed_query(input=input)
        return sparse_embedding_function(input=input)

    def _embed_knn_string_queries(self, knn: Any) -> Any:
        """Embed string queries in Knn objects using the appropriate embedding function.

        Args:
            knn: A Knn object that may have a string query

        Returns:
            A Knn object with the string query replaced by an embedding. Any
            other input, or a Knn whose query is not a string, is returned
            unchanged.

        Raises:
            ValueError: If the query is a string but no embedding function is
                available, or if an embedding function does not return exactly
                one embedding.
        """
        from chromadb.execution.expression.operator import Knn

        if not isinstance(knn, Knn):
            return knn

        # If query is not a string, nothing to do
        if not isinstance(knn.query, str):
            return knn

        query_text = knn.query
        key = knn.key

        def with_query(embedded_query: Any) -> Any:
            # Rebuild the Knn with the embedded query; other fields are copied.
            return Knn(
                query=embedded_query,
                key=knn.key,
                limit=knn.limit,
                default=knn.default,
                return_rank=knn.return_rank,
            )

        # Handle main embedding field
        if key == EMBEDDING_KEY:
            # Use the collection's main embedding function
            embedding = self._embed(input=[query_text], is_query=True)
            # Explicit None/len checks: truthiness raises on ndarray results.
            if embedding is None or len(embedding) != 1:
                raise ValueError(
                    "Embedding function returned unexpected number of embeddings"
                )
            return with_query(embedding[0])

        # Handle metadata field with potential sparse embedding
        schema = self.schema
        if schema is None or key not in schema.keys:
            raise ValueError(
                f"Cannot embed string query for key '{key}': "
                f"key not found in schema. Please provide an embedded vector or "
                f"configure an embedding function for this key in the schema."
            )

        value_type = schema.keys[key]

        # Check for sparse vector with embedding function
        if value_type.sparse_vector is not None:
            sparse_index = value_type.sparse_vector.sparse_vector_index
            if sparse_index is not None and sparse_index.enabled:
                sparse_config = sparse_index.config
                if sparse_config.embedding_function is not None:
                    # cast() is a runtime no-op; it only informs the type checker.
                    sparse_func = cast(
                        SparseEmbeddingFunction[Any], sparse_config.embedding_function
                    )
                    validate_sparse_embedding_function(sparse_func)

                    sparse_embedding = self._sparse_embed(
                        input=[query_text],
                        sparse_embedding_function=sparse_func,
                        is_query=True,
                    )

                    if sparse_embedding is None or len(sparse_embedding) != 1:
                        raise ValueError(
                            "Sparse embedding function returned unexpected number of embeddings"
                        )

                    return with_query(sparse_embedding[0])

        # Check for dense vector with embedding function (float_list)
        if value_type.float_list is not None:
            vector_index = value_type.float_list.vector_index
            if vector_index is not None and vector_index.enabled:
                dense_config = vector_index.config
                if dense_config.embedding_function is not None:
                    dense_func = dense_config.embedding_function
                    validate_embedding_function(dense_func)

                    # Probe for embed_query instead of catching AttributeError,
                    # which would also swallow errors raised inside embed_query.
                    embed_query = getattr(dense_func, "embed_query", None)
                    if embed_query is not None:
                        embeddings = embed_query(input=[query_text])
                    else:
                        # Fallback if embed_query doesn't exist
                        embeddings = dense_func([query_text])

                    if embeddings is None or len(embeddings) != 1:
                        raise ValueError(
                            "Embedding function returned unexpected number of embeddings"
                        )

                    return with_query(embeddings[0])

        raise ValueError(
            f"Cannot embed string query for key '{key}': "
            f"no embedding function configured for this key in the schema. "
            f"Please provide an embedded vector or configure an embedding function."
        )

    def _embed_rank_string_queries(self, rank: Any) -> Any:
        """Recursively embed string queries in Rank expressions.

        Args:
            rank: A Rank expression that may contain Knn objects with string queries

        Returns:
            A Rank expression with all string queries embedded. None stays None,
            and unknown node types are returned unchanged.
        """
        # Import here to avoid circular dependency
        from chromadb.execution.expression.operator import (
            Knn,
            Abs,
            Div,
            Exp,
            Log,
            Max,
            Min,
            Mul,
            Sub,
            Sum,
            Val,
            Rrf,
        )

        if rank is None:
            return None

        # Base case: Knn - embed if it has a string query
        if isinstance(rank, Knn):
            return self._embed_knn_string_queries(rank)

        # Base case: Val - no embedding needed
        if isinstance(rank, Val):
            return rank

        # Recursive cases: walk through child ranks
        if isinstance(rank, Abs):
            return Abs(self._embed_rank_string_queries(rank.rank))

        if isinstance(rank, Div):
            return Div(
                self._embed_rank_string_queries(rank.left),
                self._embed_rank_string_queries(rank.right),
            )

        if isinstance(rank, Exp):
            return Exp(self._embed_rank_string_queries(rank.rank))

        if isinstance(rank, Log):
            return Log(self._embed_rank_string_queries(rank.rank))

        if isinstance(rank, Max):
            return Max([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Min):
            return Min([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Mul):
            return Mul([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Sub):
            return Sub(
                self._embed_rank_string_queries(rank.left),
                self._embed_rank_string_queries(rank.right),
            )

        if isinstance(rank, Sum):
            return Sum([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Rrf):
            return Rrf(
                ranks=[self._embed_rank_string_queries(r) for r in rank.ranks],
                k=rank.k,
                weights=rank.weights,
                normalize=rank.normalize,
            )

        # Unknown rank type - return as is
        return rank

    def _embed_search_string_queries(self, search: Any) -> Any:
        """Embed string queries in a Search object.

        Args:
            search: A Search object that may contain Knn objects with string queries

        Returns:
            A new Search with all string queries embedded. Non-Search inputs are
            returned unchanged. NOTE(review): only where/rank/limit/select are
            carried over; confirm Search has no other fields that would be lost.
        """
        # Import here to avoid circular dependency
        from chromadb.execution.expression.plan import Search

        if not isinstance(search, Search):
            return search

        # Embed the rank expression if it exists
        embedded_rank = self._embed_rank_string_queries(search._rank)

        # Create a new Search with the embedded rank
        return Search(
            where=search._where,
            rank=embedded_rank,
            limit=search._limit,
            select=search._select,
        )
