# type: ignore
import os
import shutil
import sys
import tempfile
import traceback
from datetime import datetime, timedelta
from typing import Any

import httpx
import numpy as np
import pytest

import chromadb
import chromadb.server.fastapi
from chromadb.api.fastapi import FastAPI
from chromadb.api.types import (
    Document,
    EmbeddingFunction,
    QueryResult,
    TYPE_KEY,
    SPARSE_VECTOR_TYPE_VALUE,
)
from chromadb.config import Settings
from chromadb.errors import (
    ChromaError,
    NotFoundError,
    InvalidArgumentError,
)
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction


@pytest.fixture
def persist_dir():
    """Yield a fresh temporary directory for a persistent Chroma client.

    The directory is removed on teardown even when a dependent fixture fails
    during setup (before it reaches its own cleanup code), so no temp dir leaks.
    """
    path = tempfile.mkdtemp()
    yield path
    # ignore_errors: dependent fixtures may already have deleted the directory.
    shutil.rmtree(path, ignore_errors=True)


@pytest.fixture
def local_persist_api(persist_dir):
    """Yield a persistent, resettable SQLite-backed client rooted at ``persist_dir``.

    Teardown clears Chroma's system cache and removes the persist directory.
    """
    settings = Settings(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        allow_reset=True,
        is_persistent=True,
        persist_directory=persist_dir,
    )
    persistent_client = chromadb.Client(settings)
    yield persistent_client

    persistent_client.clear_system_cache()
    if os.path.exists(persist_dir):
        shutil.rmtree(persist_dir, ignore_errors=True)


@pytest.fixture
def local_persist_api_cache_bust(persist_dir):
    """Build another persistent client on the same ``persist_dir``.

    pytest caches fixture values within a single test, so requesting
    ``local_persist_api`` twice returns the same object. Requesting this
    identically configured fixture instead calls ``chromadb.Client`` again.
    See https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached
    """
    second_client = chromadb.Client(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=persist_dir,
        )
    )
    yield second_client

    second_client.clear_system_cache()
    if os.path.exists(persist_dir):
        shutil.rmtree(persist_dir, ignore_errors=True)


def approx_equal(a, b, tolerance=1e-6) -> bool:
    """Return True when ``a`` and ``b`` differ by strictly less than ``tolerance``."""
    gap = abs(a - b)
    return gap < tolerance


def vector_approx_equal(a, b, tolerance: float = 1e-6) -> bool:
    """Return True when ``a`` and ``b`` have the same length and every pair of
    corresponding elements satisfies ``approx_equal`` with ``tolerance``."""
    if len(a) != len(b):
        return False
    # A generator lets all() stop at the first mismatch, and distinct loop
    # names avoid shadowing the ``a``/``b`` parameters.
    return all(approx_equal(x, y, tolerance) for x, y in zip(a, b))


@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist_index_loading(api_fixture, request):
    """A document added through one client is queryable through a second client
    (``local_persist_api_cache_bust``) built on the same persist directory."""
    # Resolve the parametrized fixture rather than hard-coding its name, so the
    # parametrize list actually controls which client is used.
    client = request.getfixturevalue(api_fixture.__name__)
    client.reset()
    collection = client.create_collection("test")
    collection.add(ids="id1", documents="hello")

    api2 = request.getfixturevalue("local_persist_api_cache_bust")
    collection = api2.get_collection("test")

    includes = ["embeddings", "documents", "metadatas", "distances"]
    nn = collection.query(
        query_texts="hello",
        n_results=1,
        include=includes,
    )
    for key in nn.keys():
        if (key in includes) or (key == "ids"):
            assert len(nn[key]) == 1
        elif key == "included":
            assert set(nn[key]) == set(includes)
        else:
            assert nn[key] is None


@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist_index_loading_embedding_function(api_fixture, request):
    """A collection created with a custom embedding function can be reopened
    (with that function) from a second client and queried by text."""

    class TestEF(EmbeddingFunction[Document]):
        # Constant embedding: every input maps to [1, 2, 3].
        def __call__(self, input):
            return [np.array([1, 2, 3]) for _ in range(len(input))]

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)

        def name(self) -> str:
            return "test"

        def build_from_config(self, config: dict[str, Any]) -> None:
            pass

        def get_config(self) -> dict[str, Any]:
            return {}

    # Resolve the parametrized fixture rather than hard-coding its name.
    client = request.getfixturevalue(api_fixture.__name__)
    client.reset()
    collection = client.create_collection("test", embedding_function=TestEF())
    collection.add(ids="id1", documents="hello")

    client2 = request.getfixturevalue("local_persist_api_cache_bust")
    collection = client2.get_collection("test", embedding_function=TestEF())

    includes = ["embeddings", "documents", "metadatas", "distances"]
    nn = collection.query(
        query_texts="hello",
        n_results=1,
        include=includes,
    )
    for key in nn.keys():
        if (key in includes) or (key == "ids"):
            assert len(nn[key]) == 1
        elif key == "included":
            assert set(nn[key]) == set(includes)
        else:
            assert nn[key] is None


@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist_index_get_or_create_embedding_function(api_fixture, request):
    """get_or_create_collection with a custom embedding function reopens the
    existing collection from a second client, and the query returns the stored
    record with the function's constant embedding at distance 0."""

    class TestEF(EmbeddingFunction[Document]):
        # Constant embedding: every input maps to [1, 2, 3].
        def __call__(self, input):
            return [np.array([1, 2, 3]) for _ in range(len(input))]

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)

        def name(self) -> str:
            return "test"

        def build_from_config(self, config: dict[str, Any]) -> None:
            pass

        def get_config(self) -> dict[str, Any]:
            return {}

    # Resolve the parametrized fixture rather than hard-coding its name.
    api = request.getfixturevalue(api_fixture.__name__)
    api.reset()
    collection = api.get_or_create_collection("test", embedding_function=TestEF())
    collection.add(ids="id1", documents="hello")

    api2 = request.getfixturevalue("local_persist_api_cache_bust")
    collection = api2.get_or_create_collection("test", embedding_function=TestEF())

    includes = ["embeddings", "documents", "metadatas", "distances"]
    nn = collection.query(
        query_texts="hello",
        n_results=1,
        include=includes,
    )

    for key in nn.keys():
        if (key in includes) or (key == "ids"):
            assert len(nn[key]) == 1
        elif key == "included":
            assert set(nn[key]) == set(includes)
        else:
            assert nn[key] is None

    assert nn["ids"] == [["id1"]]
    assert nn["embeddings"][0][0].tolist() == [1, 2, 3]
    assert nn["documents"] == [["hello"]]
    assert nn["distances"] == [[0]]


@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist(api_fixture, request):
    """Add, re-fetch and delete a collection on a persistent client."""
    fixture_name = api_fixture.__name__
    # NOTE(review): pytest caches fixture values within one test, so every
    # getfixturevalue(fixture_name) below returns the same client object and
    # nothing is reloaded from disk; consider local_persist_api_cache_bust.
    api = request.getfixturevalue(fixture_name)
    api.reset()

    space = api.create_collection("testspace")
    space.add(**batch_records)
    assert space.count() == 2

    api = request.getfixturevalue(fixture_name)
    space = api.get_collection("testspace")
    assert space.count() == 2

    api.delete_collection("testspace")

    api = request.getfixturevalue(fixture_name)
    assert api.list_collections() == []


def test_heartbeat(client):
    """heartbeat() returns an int nanosecond timestamp from the last 10 seconds."""
    beat = client.heartbeat()
    assert isinstance(beat, int)

    beat_time = datetime.fromtimestamp(beat // 10**9)
    earliest_expected = datetime.now() - timedelta(seconds=10)
    assert beat_time > earliest_expected


def test_max_batch_size(client):
    """The client reports a positive maximum batch size."""
    assert client.get_max_batch_size() > 0


def test_supports_base64_encoding(client):
    """A FastAPI (HTTP) client reports that the server supports base64 encoding."""
    if isinstance(client, FastAPI):
        client.reset()
        assert client.supports_base64_encoding() is True
    else:
        pytest.skip("Not a FastAPI instance")


def test_supports_base64_encoding_legacy(client):
    """Pre-flight data without ``supports_base64_encoding`` reads as False,
    while ``max_batch_size`` is still honoured."""
    if not isinstance(client, FastAPI):
        pytest.skip("Not a FastAPI instance")
    client.reset()

    # Simulate a legacy server whose pre-flight response lacks the base64 key.
    legacy_checks = {"max_batch_size": 100}
    client.pre_flight_checks = legacy_checks

    assert client.supports_base64_encoding() is False
    assert client.get_max_batch_size() == 100


def test_pre_flight_checks(client):
    """The /pre-flight-checks endpoint returns JSON with the expected keys."""
    if not isinstance(client, FastAPI):
        pytest.skip("Not a FastAPI instance")

    response = httpx.get(f"{client._api_url}/pre-flight-checks")
    assert response.status_code == 200

    payload = response.json()
    assert payload is not None
    for required in ("max_batch_size", "supports_base64_encoding"):
        assert required in payload.keys()


# Two records: 3-dimensional embeddings keyed by URL-style ids, with no
# documents or metadatas. Shared by many tests in this module.
batch_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["https://example.com/1", "https://example.com/2"],
}


def test_add(client):
    """Adding the two batch records yields a count of 2."""
    client.reset()
    coll = client.create_collection("testspace")
    coll.add(**batch_records)
    assert coll.count() == 2


def test_collection_add_with_invalid_collection_throws(client):
    """add() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.add(**batch_records)


def test_get_or_create(client):
    """create_collection rejects a duplicate name, while get_or_create_collection
    returns the existing collection together with its records."""
    client.reset()

    original = client.create_collection("testspace")
    original.add(**batch_records)
    assert original.count() == 2

    with pytest.raises(Exception):
        client.create_collection("testspace")

    reopened = client.get_or_create_collection("testspace")
    assert reopened.count() == 2


# Records carrying only embeddings and ids (no documents or metadatas).
# Currently identical in content to batch_records.
minimal_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["https://example.com/1", "https://example.com/2"],
}


def test_add_minimal(client):
    """Records with only embeddings and ids can be added."""
    client.reset()
    coll = client.create_collection("testspace")
    coll.add(**minimal_records)
    assert coll.count() == 2


def test_get_from_db(client):
    """get() with an include list populates exactly the requested fields."""
    client.reset()
    collection = client.create_collection("testspace")
    collection.add(**batch_records)

    includes = ["embeddings", "documents", "metadatas"]
    records = collection.get(include=includes)
    for field, value in records.items():
        if field == "included":
            assert set(value) == set(includes)
        elif field == "ids" or field in includes:
            assert len(value) == 2
        else:
            assert value is None


def test_collection_get_with_invalid_collection_throws(client):
    """get() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.get()


def test_reset_db(client):
    """reset() removes every collection, including ones holding records."""
    client.reset()

    populated = client.create_collection("testspace")
    populated.add(**batch_records)
    assert populated.count() == 2

    client.reset()
    assert len(client.list_collections()) == 0


def test_get_nearest_neighbors(client):
    """query() accepts one flat embedding or a list of embeddings and returns
    one result row per query embedding, with only the included fields set."""
    client.reset()
    collection = client.create_collection("testspace")
    collection.add(**batch_records)

    includes = ["embeddings", "documents", "metadatas", "distances"]
    # (query_embeddings argument, expected number of result rows)
    cases = [
        ([1.1, 2.3, 3.2], 1),
        ([[1.1, 2.3, 3.2]], 1),
        ([[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]], 2),
    ]
    for query_embeddings, expected_rows in cases:
        nn = collection.query(
            query_embeddings=query_embeddings,
            n_results=1,
            include=includes,
        )
        for field, value in nn.items():
            if field == "included":
                assert set(value) == set(includes)
            elif field == "ids" or field in includes:
                assert len(value) == expected_rows
            else:
                assert value is None


def test_delete(client):
    """delete() called with no ids and no filters raises."""
    client.reset()
    target = client.create_collection("testspace")
    target.add(**batch_records)
    assert target.count() == 2

    with pytest.raises(Exception):
        target.delete()


def test_delete_returns_none(client):
    """delete(ids=...) returns None."""
    client.reset()
    target = client.create_collection("testspace")
    target.add(**batch_records)
    assert target.count() == 2

    outcome = target.delete(ids=batch_records["ids"])
    assert outcome is None


def test_delete_with_index(client):
    """Querying right after add() completes without raising.

    NOTE(review): despite the name, nothing is deleted and the query result
    is not checked.
    """
    client.reset()
    target = client.create_collection("testspace")
    target.add(**batch_records)
    assert target.count() == 2
    target.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1)


def test_collection_delete_with_invalid_collection_throws(client):
    """delete() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.delete(ids=["id1"])


def test_count(client):
    """count() is 0 for a new collection and 2 after adding two records."""
    client.reset()
    coll = client.create_collection("testspace")
    assert coll.count() == 0

    coll.add(**batch_records)
    assert coll.count() == 2


def test_collection_count_with_invalid_collection_throws(client):
    """count() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.count()


def test_modify(client):
    """modify(name=...) renames the collection, and the handle reflects it."""
    client.reset()
    renamed = client.create_collection("testspace")
    renamed.modify(name="testspace2")

    # The local handle carries the new name.
    assert renamed.name == "testspace2"


def test_collection_modify_with_invalid_collection_throws(client):
    """modify() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.modify(name="test2")


def test_modify_error_on_existing_name(client):
    """Renaming a collection onto a name that is already taken raises."""
    client.reset()

    client.create_collection("testspace")
    second = client.create_collection("testspace2")

    with pytest.raises(Exception):
        second.modify(name="testspace")


def test_modify_warn_on_DF_change(client, caplog):
    """Changing "hnsw:space" through modify() raises an error matching
    "not supported".

    NOTE(review): despite the name, this asserts an exception rather than a
    logged warning; ``caplog`` is unused.
    """
    client.reset()

    target = client.create_collection("testspace")
    with pytest.raises(Exception, match="not supported"):
        target.modify(metadata={"hnsw:space": "cosine"})


def test_metadata_cru(client):
    """Collection metadata is visible after create, get, modify, get_or_create
    and list_collections."""

    def assert_metadata(metadata, expected):
        # Metadata must be present, then each expected key must match.
        assert metadata is not None
        for field, wanted in expected.items():
            assert metadata[field] == wanted

    client.reset()

    # Create with metadata.
    created = client.create_collection("testspace", metadata={"a": 1, "b": 2})
    assert_metadata(created.metadata, {"a": 1, "b": 2})

    # Read it back through a fresh handle.
    fetched = client.get_collection("testspace")
    assert_metadata(fetched.metadata, {"a": 1, "b": 2})

    # modify() replaces the metadata: "b" is gone afterwards.
    fetched.modify(metadata={"a": 2, "c": 3})
    assert fetched.metadata["a"] == 2
    assert fetched.metadata["c"] == 3
    assert "b" not in fetched.metadata

    # The replacement is visible through another fresh handle.
    refetched = client.get_collection("testspace")
    assert_metadata(refetched.metadata, {"a": 2, "c": 3})
    assert "b" not in refetched.metadata

    # get_or_create on an existing name returns the stored metadata.
    existing = client.get_or_create_collection("testspace")
    assert_metadata(existing.metadata, {"a": 2, "c": 3})

    # get_or_create on a new name without metadata leaves it None.
    assert client.get_or_create_collection("testspace2").metadata is None

    for listed in client.list_collections():
        if listed.name == "testspace":
            assert_metadata(listed.metadata, {"a": 2, "c": 3})
        elif listed.name == "testspace2":
            assert listed.metadata is None


def test_increment_index_on(client):
    """A query after add() returns one row with only the included fields set."""
    client.reset()
    collection = client.create_collection("testspace")
    collection.add(**batch_records)
    assert collection.count() == 2

    includes = ["embeddings", "documents", "metadatas", "distances"]
    result = collection.query(
        query_embeddings=[[1.1, 2.3, 3.2]],
        n_results=1,
        include=includes,
    )
    for field, value in result.items():
        if field == "included":
            assert set(value) == set(includes)
        elif field == "ids" or field in includes:
            assert len(value) == 1
        else:
            assert value is None


def test_add_a_collection(client):
    """get_collection finds a created collection and raises for an unknown one."""
    client.reset()
    client.create_collection("testspace")

    fetched = client.get_collection("testspace")
    assert fetched.name == "testspace"

    with pytest.raises(Exception):
        client.get_collection("testspace2")


def test_error_includes_trace_id(http_client):
    """A ChromaError raised through the HTTP client carries a trace id."""
    http_client.reset()

    with pytest.raises(ChromaError) as excinfo:
        http_client.get_collection("testspace2")

    assert excinfo.value.trace_id is not None


def test_list_collections(client):
    """list_collections returns every created collection."""
    client.reset()
    for name in ("testspace", "testspace2"):
        client.create_collection(name)

    assert len(client.list_collections()) == 2


def test_reset(client):
    """reset() drops all previously created collections."""
    client.reset()
    for name in ("testspace", "testspace2"):
        client.create_collection(name)
    assert len(client.list_collections()) == 2

    client.reset()
    assert len(client.list_collections()) == 0


def test_peek(client):
    """peek() returns both records with embeddings, documents and metadatas
    included and all other fields set to None."""
    client.reset()
    collection = client.create_collection("testspace")
    collection.add(**batch_records)
    assert collection.count() == 2

    # Fields peek() is expected to include; kept in one place so the length
    # check and the "included" check cannot drift apart.
    peek_includes = ["embeddings", "documents", "metadatas"]
    peek = collection.peek()
    for key in peek.keys():
        if key in peek_includes or key == "ids":
            assert len(peek[key]) == 2
        elif key == "included":
            assert set(peek[key]) == set(peek_includes)
        else:
            assert peek[key] is None


def test_collection_peek_with_invalid_collection_throws(client):
    """peek() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.peek()


def test_collection_query_with_invalid_collection_throws(client):
    """query() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.query(query_texts=["test"])


def test_collection_update_with_invalid_collection_throws(client):
    """update() on a handle whose collection was deleted raises NotFoundError."""
    client.reset()
    stale = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale.update(ids=["id1"], documents=["test"])


# TEST METADATA AND METADATA FILTERING
# region

# Two records with metadata. "id1" has int, string and float values; "id2" has
# only int_value, so filters on string_value or float_value can match at most
# "id1".
metadata_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2},
    ],
}


def test_metadata_add_get_int_float(client):
    """Int and float metadata values come back from get() with their values and
    Python types intact."""
    client.reset()
    collection = client.create_collection("test_int")
    collection.add(**metadata_records)

    metas = collection.get(ids=["id1", "id2"])["metadatas"]
    assert metas[0]["int_value"] == 1
    assert metas[0]["float_value"] == 1.001
    assert metas[1]["int_value"] == 2
    assert isinstance(metas[0]["int_value"], int)
    assert isinstance(metas[0]["float_value"], float)


def test_metadata_add_query_int_float(client):
    """Int and float metadata values come back from query() with their values
    and Python types intact."""
    client.reset()
    collection = client.create_collection("test_int")
    collection.add(**metadata_records)

    result: QueryResult = collection.query(
        query_embeddings=[[1.1, 2.3, 3.2]], n_results=1
    )
    assert result["metadatas"] is not None

    top = result["metadatas"][0][0]
    assert top["int_value"] == 1
    assert top["float_value"] == 1.001
    assert isinstance(top["int_value"], int)
    assert isinstance(top["float_value"], float)


def test_metadata_get_where_string(client):
    """An equality filter on a string metadata field selects the matching record."""
    client.reset()
    collection = client.create_collection("test_int")
    collection.add(**metadata_records)

    metas = collection.get(where={"string_value": "one"})["metadatas"]
    assert metas[0]["int_value"] == 1
    assert metas[0]["string_value"] == "one"


def test_metadata_get_where_int(client):
    """An equality filter on an int metadata field selects the matching record."""
    client.reset()
    collection = client.create_collection("test_int")
    collection.add(**metadata_records)

    metas = collection.get(where={"int_value": 1})["metadatas"]
    assert metas[0]["int_value"] == 1
    assert metas[0]["string_value"] == "one"


def test_metadata_get_where_float(client):
    """An equality filter on a float metadata field selects the matching record."""
    client.reset()
    collection = client.create_collection("test_int")
    collection.add(**metadata_records)

    metas = collection.get(where={"float_value": 1.001})["metadatas"]
    assert metas[0]["int_value"] == 1
    assert metas[0]["string_value"] == "one"
    assert metas[0]["float_value"] == 1.001


def test_metadata_update_get_int_float(client):
    """update() overwrites int, string and float metadata values."""
    client.reset()
    collection = client.create_collection("test_int")
    collection.add(**metadata_records)

    replacement = {"int_value": 2, "string_value": "two", "float_value": 2.002}
    collection.update(ids=["id1"], metadatas=[replacement])

    updated = collection.get(ids=["id1"])["metadatas"][0]
    assert updated["int_value"] == 2
    assert updated["string_value"] == "two"
    assert updated["float_value"] == 2.002


# Metadata values that are a nested dict and a list. test_metadata_validation_add
# expects add() to reject these with a ValueError mentioning "metadata".
bad_metadata_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["id1", "id2"],
    "metadatas": [{"value": {"nested": "5"}}, {"value": [1, 2, 3]}],
}


def test_metadata_validation_add(client):
    """add() rejects nested-dict and list metadata values."""
    client.reset()
    target = client.create_collection("test_metadata_validation")

    with pytest.raises(ValueError, match="metadata"):
        target.add(**bad_metadata_records)


def test_metadata_validation_update(client):
    """update() rejects a nested-dict metadata value."""
    client.reset()
    target = client.create_collection("test_metadata_validation")
    target.add(**metadata_records)

    nested = {"value": {"nested": "5"}}
    with pytest.raises(ValueError, match="metadata"):
        target.update(ids=["id1"], metadatas=nested)


def test_where_validation_get(client):
    """get() rejects a where clause whose operator dict uses a non-operator key."""
    client.reset()
    target = client.create_collection("test_where_validation")

    bad_where = {"value": {"nested": "5"}}
    with pytest.raises(ValueError, match="where"):
        target.get(where=bad_where)


def test_where_validation_query(client):
    """query() rejects a where clause whose operator dict uses a non-operator key."""
    client.reset()
    target = client.create_collection("test_where_validation")

    bad_where = {"value": {"nested": "5"}}
    with pytest.raises(ValueError, match="where"):
        target.query(query_embeddings=[0, 0, 0], where=bad_where)


# Two records with int_value 1/2, float_value 1.001/2.002 and
# string_value "one"/"two", used by the comparison-operator tests.
operator_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2, "float_value": 2.002, "string_value": "two"},
    ],
}


def test_where_lt(client):
    """$lt 2 on int_value matches only the record with int_value 1."""
    client.reset()
    collection = client.create_collection("test_where_lt")
    collection.add(**operator_records)

    matched = collection.get(where={"int_value": {"$lt": 2}})["metadatas"]
    assert len(matched) == 1


def test_where_lte(client):
    """$lte with a float bound (2.0) on int_value matches both records."""
    client.reset()
    collection = client.create_collection("test_where_lte")
    collection.add(**operator_records)

    matched = collection.get(where={"int_value": {"$lte": 2.0}})["metadatas"]
    assert len(matched) == 2


def test_where_gt(client):
    """$gt with a negative bound on float_value matches both records."""
    client.reset()
    # Collection name matches the test (it was copy-pasted as "test_where_lte").
    collection = client.create_collection("test_where_gt")
    collection.add(**operator_records)
    items = collection.get(where={"float_value": {"$gt": -1.4}})
    assert len(items["metadatas"]) == 2


def test_where_gte(client):
    """$gte 2.002 on float_value matches only the record with float_value 2.002."""
    client.reset()
    # Collection name matches the test (it was copy-pasted as "test_where_lte").
    collection = client.create_collection("test_where_gte")
    collection.add(**operator_records)
    items = collection.get(where={"float_value": {"$gte": 2.002}})
    assert len(items["metadatas"]) == 1


def test_where_ne_string(client):
    """$ne "two" on string_value matches only the record with "one"."""
    client.reset()
    # Collection name matches the test (it was copy-pasted as "test_where_lte").
    collection = client.create_collection("test_where_ne_string")
    collection.add(**operator_records)
    items = collection.get(where={"string_value": {"$ne": "two"}})
    assert len(items["metadatas"]) == 1


def test_where_ne_eq_number(client):
    """$ne on an int field and $eq on a float field each match one record."""
    client.reset()
    # Collection name matches the test (it was copy-pasted as "test_where_lte").
    collection = client.create_collection("test_where_ne_eq_number")
    collection.add(**operator_records)
    items = collection.get(where={"int_value": {"$ne": 1}})
    assert len(items["metadatas"]) == 1
    items = collection.get(where={"float_value": {"$eq": 2.002}})
    assert len(items["metadatas"]) == 1


def test_where_valid_operators(client):
    """Each malformed metadata where clause below is rejected with ValueError."""
    client.reset()
    collection = client.create_collection("test_where_valid_operators")
    collection.add(**operator_records)

    invalid_wheres = [
        {"int_value": {"$invalid": 2}},  # unknown operator
        {"int_value": {"$lt": "2"}},  # string operand for $lt
        {"int_value": {"$lt": 2, "$gt": 1}},  # two operators on one field
        # Malformed $and / $or usage.
        {"$and": {"int_value": {"$lt": 2}}},  # $and given a dict, not a list
        {"int_value": {"$lt": 2}, "$or": {"int_value": {"$gt": 1}}},  # $or given a dict
        {"$gt": [{"int_value": {"$lt": 2}}, {"int_value": {"$gt": 1}}]},  # $gt at top level
        {"$or": [{"int_value": {"$lt": 2}}]},  # $or with a single clause
        {"$or": []},  # empty $or
        {"a": {"$contains": "test"}},  # $contains on a metadata field
        {
            "$or": [
                {"a": {"$contains": "first"}},  # invalid
                {"$contains": "second"},  # valid
            ]
        },
    ]
    for where in invalid_wheres:
        with pytest.raises(ValueError):
            collection.get(where=where)


# TODO: Define the dimensionality of these embeddings in terms of the default record
# Records with 4-dimensional embeddings: one more dimension than the
# 3-dimensional embeddings in minimal_records.
bad_dimensionality_records = {
    "embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]],
    "ids": ["id1", "id2"],
}

# Query embeddings with the same wrong (4-dimensional) shape.
bad_dimensionality_query = {
    "query_embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]],
}

# Correctly shaped 3-dimensional queries that request 100 results.
# NOTE(review): not referenced in the visible part of this file.
bad_number_of_results_query = {
    "query_embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "n_results": 100,
}


def test_dimensionality_validation_add(client):
    """Adding 4-d embeddings to a collection already holding 3-d ones fails
    with an error mentioning "dimension"."""
    client.reset()
    collection = client.create_collection("test_dimensionality_validation")
    collection.add(**minimal_records)

    with pytest.raises(Exception) as excinfo:
        collection.add(**bad_dimensionality_records)
    assert "dimension" in str(excinfo.value)


def test_dimensionality_validation_query(client):
    """Querying a 3-d collection with 4-d embeddings fails with an error
    mentioning "dimension"."""
    client.reset()
    collection = client.create_collection("test_dimensionality_validation_query")
    collection.add(**minimal_records)

    with pytest.raises(Exception) as excinfo:
        collection.query(**bad_dimensionality_query)
    assert "dimension" in str(excinfo.value)


def test_query_document_valid_operators(client):
    """Malformed where_document filters are rejected with ValueError."""
    client.reset()
    collection = client.create_collection("test_where_valid_operators")
    collection.add(**operator_records)

    with pytest.raises(ValueError, match="where document"):
        collection.get(where_document={"$lt": {"$nested": 2}})

    with pytest.raises(ValueError, match="where document"):
        collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2})

    # Bad operands for $contains / $not_contains; the error names "where document".
    for bad_filter in (
        {"$contains": []},
        {"$contains": {"text": "hello"}},
        {"$not_contains": {"text": "hello"}},
    ):
        with pytest.raises(ValueError, match="where document"):
            collection.get(where_document=bad_filter)

    # Malformed $and / $or usage; only the exception type is checked.
    for bad_filter in (
        {"$and": {"$unsupported": "doc"}},
        {"$or": [{"$unsupported": "doc"}, {"$unsupported": "doc"}]},
        {"$or": [{"$contains": "doc"}]},
        {"$or": []},
        {"$or": [{"$and": [{"$contains": "doc"}]}, {"$contains": "doc"}]},
    ):
        with pytest.raises(ValueError):
            collection.get(where_document=bad_filter)


# Two records with documents: "great" and "doc" appear in both documents,
# "doc1" only in the first. Metadata matches operator_records.
contains_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "documents": ["this is doc1 and it's great!", "doc2 is also great!"],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2, "float_value": 2.002, "string_value": "two"},
    ],
}


def test_get_where_document(client):
    """get() with a $contains document filter returns only matching records."""
    client.reset()
    collection = client.create_collection("test_get_where_document")
    collection.add(**contains_records)

    # (substring, expected number of matching records)
    for needle, expected in (("doc1", 1), ("great", 2), ("bad", 0)):
        matched = collection.get(where_document={"$contains": needle})
        assert len(matched["metadatas"]) == expected


def test_query_where_document(client):
    """query() honours $contains document filters; a filter matching no document
    is expected to raise an error whose message mentions "datapoints"."""
    client.reset()
    collection = client.create_collection("test_query_where_document")
    collection.add(**contains_records)

    items = collection.query(
        query_embeddings=[1, 0, 0], where_document={"$contains": "doc1"}, n_results=1
    )
    assert len(items["metadatas"][0]) == 1

    items = collection.query(
        query_embeddings=[0, 0, 0], where_document={"$contains": "great"}, n_results=2
    )
    assert len(items["metadatas"][0]) == 2

    with pytest.raises(Exception) as e:
        collection.query(
            query_embeddings=[0, 0, 0], where_document={"$contains": "bad"}, n_results=1
        )
    # Must sit outside the `with`: the query raises, so any statement after it
    # inside the block would never run.
    assert "datapoints" in str(e.value)


def test_delete_where_document(client):
    """delete() with a $contains document filter removes only matching records."""
    client.reset()
    collection = client.create_collection("test_delete_where_document")
    collection.add(**contains_records)

    # (substring to delete by, records remaining afterwards)
    for needle, remaining in (("doc1", 1), ("bad", 1), ("great", 0)):
        collection.delete(where_document={"$contains": needle})
        assert collection.count() == remaining


# Four records (id1..id4) whose int_value/float_value/string_value encode their
# position, all sharing metadata "is": "doc". Every document contains
# "document" and "great" plus its ordinal word ("first".."fourth").
logical_operator_records = {
    "embeddings": [
        [1.1, 2.3, 3.2],
        [1.2, 2.24, 3.2],
        [1.3, 2.25, 3.2],
        [1.4, 2.26, 3.2],
    ],
    "ids": ["id1", "id2", "id3", "id4"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001, "is": "doc"},
        {"int_value": 2, "float_value": 2.002, "string_value": "two", "is": "doc"},
        {"int_value": 3, "float_value": 3.003, "string_value": "three", "is": "doc"},
        {"int_value": 4, "float_value": 4.004, "string_value": "four", "is": "doc"},
    ],
    "documents": [
        "this document is first and great",
        "this document is second and great",
        "this document is third and great",
        "this document is fourth and great",
    ],
}


def test_where_logical_operators(client):
    """Nested ``$and`` / ``$or`` metadata filters select the expected records."""
    client.reset()
    collection = client.create_collection("test_logical_operators")
    collection.add(**logical_operator_records)

    # $and wrapping an $or: int_value >= 3 (id3, id4) or float_value < 1.9 (id1).
    and_of_or = {
        "$and": [
            {"$or": [{"int_value": {"$gte": 3}}, {"float_value": {"$lt": 1.9}}]},
            {"is": "doc"},
        ]
    }
    # $or of two $and clauses, each pinning exactly one record (id3, id4).
    or_of_ands = {
        "$or": [
            {"$and": [{"int_value": {"$eq": 3}}, {"string_value": {"$eq": "three"}}]},
            {"$and": [{"int_value": {"$eq": 4}}, {"string_value": {"$eq": "four"}}]},
        ]
    }
    # $and of two $or clauses; id1 and id2 each satisfy both.
    and_of_ors = {
        "$and": [
            {"$or": [{"int_value": {"$eq": 1}}, {"string_value": {"$eq": "two"}}]},
            {"$or": [{"int_value": {"$eq": 2}}, {"string_value": {"$eq": "one"}}]},
        ]
    }

    for where_filter, expected in ((and_of_or, 3), (or_of_ands, 2), (and_of_ors, 2)):
        matched = collection.get(where=where_filter)
        assert len(matched["metadatas"]) == expected


def test_where_document_logical_operators(client):
    """``$and`` / ``$or`` document filters, alone and combined with ``where``."""
    client.reset()
    collection = client.create_collection("test_document_logical_operators")
    collection.add(**logical_operator_records)

    # Both substrings must appear: only the "first" document qualifies.
    result = collection.get(
        where_document={"$and": [{"$contains": "first"}, {"$contains": "doc"}]}
    )
    assert len(result["metadatas"]) == 1

    # Either substring: the "first" and "second" documents.
    result = collection.get(
        where_document={"$or": [{"$contains": "first"}, {"$contains": "second"}]}
    )
    assert len(result["metadatas"]) == 2

    # Same document filter, narrowed by metadata to drop int_value == 2.
    result = collection.get(
        where_document={"$or": [{"$contains": "first"}, {"$contains": "second"}]},
        where={"int_value": {"$ne": 2}},
    )
    assert len(result["metadatas"]) == 1


# endregion

# Two-record fixture shared by many of the tests below. id1 sits at the origin
# with int/string/float metadata. id2 sits at [1.2, 2.24, 3.2] and carries only
# int_value.
records = {
    "embeddings": [[0, 0, 0], [1.2, 2.24, 3.2]],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2},
    ],
    "documents": ["this document is first", "this document is second"],
}


def test_query_include(client):
    """``include`` controls which fields ``query`` populates."""
    client.reset()
    collection = client.create_collection("test_query_include")
    collection.add(**records)

    # Everything except embeddings.
    requested = ["metadatas", "documents", "distances"]
    res = collection.query(query_embeddings=[0, 0, 0], n_results=1, include=requested)
    assert res["embeddings"] is None
    assert res["ids"][0][0] == "id1"
    assert res["metadatas"][0][0]["int_value"] == 1
    assert set(res["included"]) == set(requested)

    # Everything except metadatas.
    requested = ["embeddings", "documents", "distances"]
    res = collection.query(query_embeddings=[0, 0, 0], n_results=1, include=requested)
    assert res["metadatas"] is None
    assert res["ids"][0][0] == "id1"
    assert set(res["included"]) == set(requested)

    # An empty include list leaves only the ids populated.
    res = collection.query(
        query_embeddings=[[0, 0, 0], [1, 2, 1.2]], n_results=2, include=[]
    )
    for omitted in ("documents", "metadatas", "embeddings", "distances"):
        assert res[omitted] is None
    first_row = res["ids"][0]
    assert (first_row[0], first_row[1]) == ("id1", "id2")


def test_get_include(client):
    """``include`` controls which fields ``get`` populates; bad values are rejected."""
    client.reset()
    collection = client.create_collection("test_get_include")
    collection.add(**records)

    fields = ["metadatas", "documents"]
    got = collection.get(include=fields, where={"int_value": 1})
    assert got["embeddings"] is None
    assert got["ids"][0] == "id1"
    assert got["metadatas"][0]["int_value"] == 1
    assert got["documents"][0] == "this document is first"
    assert set(got["included"]) == set(fields)

    fields = ["embeddings", "documents"]
    got = collection.get(include=fields)
    assert got["metadatas"] is None
    assert got["ids"][0] == "id1"
    assert approx_equal(got["embeddings"][1][0], 1.2)
    assert set(got["included"]) == set(fields)

    # Empty include: only ids come back.
    got = collection.get(include=[])
    for absent in ("documents", "metadatas", "embeddings"):
        assert got[absent] is None
    assert got["ids"][0] == "id1"
    assert got["included"] == []

    # An unknown field name and None are both invalid include values.
    for bad_include in (["metadatas", "undefined"], None):
        with pytest.raises(ValueError, match="include"):
            collection.get(include=bad_include)


# make sure query results are returned in the right order


def test_query_order(client):
    """Query results are ordered nearest-first."""
    client.reset()
    collection = client.create_collection("test_query_order")
    collection.add(**records)

    # The query point equals id2's embedding, so id2 must rank ahead of id1.
    result = collection.query(
        query_embeddings=[1.2, 2.24, 3.2],
        include=["metadatas", "documents", "distances"],
        n_results=2,
    )
    ranked_docs = result["documents"][0]
    assert ranked_docs[0] == "this document is second"
    assert ranked_docs[1] == "this document is first"


# test to make sure add, get, delete error on invalid id input


def test_invalid_id(client):
    """add / get / delete reject ids that are not a list of strings."""
    client.reset()
    collection = client.create_collection("test_invalid_id")

    bad_calls = [
        # add with a non-string id
        lambda: collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}]),
        # get with an id that is not wrapped in a list
        lambda: collection.get(ids=1),
        # delete with a list mixing a string id and a non-string id
        lambda: collection.delete(ids=["valid", 0]),
    ]
    for call in bad_calls:
        with pytest.raises(ValueError) as excinfo:
            call()
        assert "ID" in str(excinfo.value)


def test_index_params(client):
    """The ``hnsw:*`` collection metadata changes the distances queries report."""
    EPS = 1e-12

    def nearest_distance(metadata=None):
        # Fresh collection per space; the probe is the midpoint of id1 and id2.
        client.reset()
        if metadata is None:
            coll = client.create_collection(name="test_index_params")
        else:
            coll = client.create_collection(name="test_index_params", metadata=metadata)
        coll.add(**records)
        hit = coll.query(query_embeddings=[0.6, 1.12, 1.6], n_results=1)
        return hit["distances"][0][0]

    # Default space.
    assert nearest_distance() > 4

    # cosine, with explicit construction parameters.
    cosine = nearest_distance(
        {"hnsw:space": "cosine", "hnsw:construction_ef": 20, "hnsw:M": 5}
    )
    assert cosine > 0 - EPS
    assert cosine < 1 + EPS

    # inner product.
    assert nearest_distance({"hnsw:space": "ip"}) < -5


def test_invalid_index_params(client):
    """An unknown ``hnsw:space`` value is rejected with InvalidArgumentError."""
    client.reset()

    # Both the creation and the first add sit inside the ``raises`` block,
    # presumably because the check may fire at either point depending on the
    # backend. TODO confirm which call raises.
    with pytest.raises(InvalidArgumentError):
        bogus = client.create_collection(
            name="test_index_params", metadata={"hnsw:space": "foobar"}
        )
        bogus.add(**records)


def test_persist_index_loading_params(client, request):
    """The HNSW space stored in collection metadata survives a persistent reload."""
    # The incoming ``client`` is replaced by the persistent local client.
    client = request.getfixturevalue("local_persist_api")
    client.reset()
    collection = client.create_collection("test", metadata={"hnsw:space": "ip"})
    collection.add(ids="id1", documents="hello")

    # A second client over the same persist directory must see the same settings.
    second_client = request.getfixturevalue("local_persist_api_cache_bust")
    reloaded = second_client.get_collection("test")
    assert reloaded.metadata["hnsw:space"] == "ip"

    includes = ["embeddings", "documents", "metadatas", "distances"]
    nn = reloaded.query(query_texts="hello", n_results=1, include=includes)
    for key, value in nn.items():
        if key == "included":
            assert set(value) == set(includes)
        elif key == "ids" or key in includes:
            assert len(value) == 1
        else:
            assert value is None


def test_add_large(client):
    """A 2000 x 512 batch of random embeddings can be added in a single call."""
    client.reset()
    collection = client.create_collection("testspace")

    vectors = np.random.rand(2000, 512).astype(np.float32).tolist()
    ids = [f"http://example.com/{n}" for n in range(len(vectors))]
    collection.add(embeddings=vectors, ids=ids)

    assert collection.count() == len(vectors)


# test get_version
def test_get_version(client):
    """``get_version`` returns a string that starts with a dotted x.y.z version."""
    import re

    client.reset()
    version = client.get_version()
    assert re.match(r"\d+\.\d+\.\d+", version)


# test delete_collection
def test_delete_collection(client):
    """Deleting a collection removes it from ``list_collections``."""
    client.reset()
    name = "test_delete_collection"
    client.create_collection(name).add(**records)

    assert len(client.list_collections()) == 1
    client.delete_collection(name)
    assert len(client.list_collections()) == 0


# test default embedding function
def test_default_embedding():
    """The default embedding function returns one embedding per input document."""
    embed = DefaultEmbeddingFunction()
    batch = ["this is a test"] * 64
    assert len(embed(batch)) == 64


def test_multiple_collections(client):
    """Two collections keep their data apart; each query finds its own vectors."""
    vectors_a = np.random.rand(10, 512).astype(np.float32).tolist()
    vectors_b = np.random.rand(10, 512).astype(np.float32).tolist()
    ids_a = [f"http://example.com/1/{n}" for n in range(len(vectors_a))]
    ids_b = [f"http://example.com/2/{n}" for n in range(len(vectors_b))]

    client.reset()
    coll_a = client.create_collection("coll1")
    coll_a.add(embeddings=vectors_a, ids=ids_a)
    coll_b = client.create_collection("coll2")
    coll_b.add(embeddings=vectors_b, ids=ids_b)

    assert len(client.list_collections()) == 2
    assert coll_a.count() == len(vectors_a)
    assert coll_b.count() == len(vectors_b)

    hits_a = coll_a.query(query_embeddings=vectors_a[0], n_results=1)
    hits_b = coll_b.query(query_embeddings=vectors_b[0], n_results=1)

    # Check level by level so a flaky failure points at exactly what broke.
    assert len(hits_a["ids"]) > 0
    assert len(hits_b["ids"]) > 0
    assert len(hits_a["ids"][0]) > 0
    assert len(hits_b["ids"][0]) > 0

    # Querying with a stored vector returns that vector's own id first.
    assert hits_a["ids"][0][0] == ids_a[0]
    assert hits_b["ids"][0][0] == ids_b[0]


def test_update_query(client):
    """After ``update``, a query returns the new embedding, document and metadata."""
    client.reset()
    collection = client.create_collection("test_update_query")
    collection.add(**records)

    target_id = records["ids"][0]
    new_embeddings = [[0.1, 0.2, 0.3]]
    new_document = "updated document"
    collection.update(
        ids=[target_id],
        embeddings=new_embeddings,
        documents=[new_document],
        metadatas=[{"foo": "bar"}],
    )

    hit = collection.query(
        query_embeddings=new_embeddings,
        n_results=1,
        include=["embeddings", "documents", "metadatas"],
    )
    assert len(hit["ids"][0]) == 1
    assert hit["ids"][0][0] == target_id
    assert hit["documents"][0][0] == new_document
    assert hit["metadatas"][0][0]["foo"] == "bar"
    assert vector_approx_equal(hit["embeddings"][0][0], new_embeddings[0])


def test_get_nearest_neighbors_where_n_results_more_than_element(client):
    """Asking for more neighbours than stored records returns every record."""
    client.reset()
    collection = client.create_collection("testspace")
    collection.add(**records)

    includes = ["embeddings", "documents", "metadatas", "distances"]
    results = collection.query(
        query_embeddings=[[1.1, 2.3, 3.2]], n_results=5, include=includes
    )
    for key, value in results.items():
        if key == "included":
            assert set(value) == set(includes)
        elif key == "ids" or key in includes:
            # Only the two stored records can come back, even though 5 were asked for.
            assert len(value[0]) == 2
        else:
            assert value is None


def test_invalid_n_results_param(client):
    """Negative and non-integer ``n_results`` values are rejected."""
    client.reset()
    collection = client.create_collection("testspace")
    collection.add(**records)

    # A negative count raises TypeError ...
    with pytest.raises(TypeError) as exc:
        collection.query(
            query_embeddings=[[1.1, 2.3, 3.2]],
            n_results=-1,
            include=["embeddings", "documents", "metadatas", "distances"],
        )
    expected = "Number of requested results -1, cannot be negative, or zero."
    assert expected in str(exc.value)
    assert exc.type == TypeError  # the exact type, not merely a subclass

    # ... while a non-integer count raises ValueError.
    with pytest.raises(ValueError) as exc:
        collection.query(
            query_embeddings=[[1.1, 2.3, 3.2]],
            n_results="one",
            include=["embeddings", "documents", "metadatas", "distances"],
        )
    assert "int" in str(exc.value)
    assert exc.type == ValueError


# Three starting records for the upsert tests. id1 has full metadata, id2 only
# int_value, and id3 only string_value.
initial_records = {
    "embeddings": [[0, 0, 0], [1.2, 2.24, 3.2], [2.2, 3.24, 4.2]],
    "ids": ["id1", "id2", "id3"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2},
        {"string_value": "three"},
    ],
    "documents": [
        "this document is first",
        "this document is second",
        "this document is third",
    ],
}

# Upsert payload: "id1" overlaps initial_records and replaces that record;
# "id4" is a new id.
new_records = {
    "embeddings": [[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]],
    "ids": ["id1", "id4"],
    "metadatas": [
        {"int_value": 1, "string_value": "one_of_one", "float_value": 1.001},
        {"int_value": 4},
    ],
    "documents": [
        "this document is even more first",
        "this document is new and fourth",
    ],
}


def test_upsert(client):
    """``upsert`` overwrites existing ids, inserts new ones and re-adds deleted ids."""
    client.reset()
    collection = client.create_collection("test")

    collection.add(**initial_records)
    assert collection.count() == 3

    # id1 already exists and is replaced; id4 is new, so the count grows by one.
    collection.upsert(**new_records)
    assert collection.count() == 4

    replaced_id = new_records["ids"][0]
    expected_embedding = new_records["embeddings"][0]
    expected_metadata = new_records["metadatas"][0]
    expected_document = new_records["documents"][0]

    fetched = collection.get(
        include=["embeddings", "metadatas", "documents"], ids=replaced_id
    )
    assert vector_approx_equal(fetched["embeddings"][0], expected_embedding)
    assert fetched["metadatas"][0] == expected_metadata
    assert fetched["documents"][0] == expected_document

    # Querying with the stored embedding returns the replaced record first.
    nearest = collection.query(
        query_embeddings=fetched["embeddings"],
        n_results=1,
        include=["embeddings", "metadatas", "documents"],
    )
    assert vector_approx_equal(nearest["embeddings"][0][0], expected_embedding)
    assert nearest["metadatas"][0][0] == expected_metadata
    assert nearest["documents"][0][0] == expected_document

    # Upserting an id that was deleted re-inserts it, here without a document.
    removed_id = initial_records["ids"][2]
    collection.delete(ids=removed_id)
    collection.upsert(
        ids=removed_id,
        embeddings=[[1.1, 0.99, 2.21]],
        metadatas=[{"string_value": "a new string value"}],
    )
    assert collection.count() == 4

    reinserted = collection.get(
        include=["embeddings", "metadatas", "documents"], ids=["id3"]
    )
    assert vector_approx_equal(reinserted["embeddings"][0], [1.1, 0.99, 2.21])
    assert reinserted["metadatas"][0] == {"string_value": "a new string value"}
    assert reinserted["documents"][0] is None


def test_collection_upsert_with_invalid_collection_throws(client):
    """Upserting through a handle to a deleted collection raises NotFoundError."""
    client.reset()
    stale_handle = client.create_collection("test")
    client.delete_collection("test")

    with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
        stale_handle.upsert(**initial_records)


# test to make sure add, query, update, upsert error on invalid embeddings input


def test_invalid_embeddings(client):
    """add / query / update / upsert reject malformed embeddings with ValueError."""
    client.reset()
    collection = client.create_collection("test_invalid_embeddings")

    attempts = [
        # add: the embedding components are strings
        lambda: collection.add(
            embeddings=[["0", "0", "0"], ["1.2", "2.24", "3.2"]],
            ids=["id1", "id2"],
        ),
        # query: string components in the query embedding
        lambda: collection.query(
            query_embeddings=[["1.1", "2.3", "3.2"]],
            n_results=1,
        ),
        # update: every component is itself wrapped in a list
        lambda: collection.update(
            embeddings=[[[0], [0], [0]], [[1.2], [2.24], [3.2]]],
            ids=["id1", "id2"],
        ),
        # upsert: one nesting level too many
        lambda: collection.upsert(
            embeddings=[[[1.1, 2.3, 3.2]], [[1.2, 2.24, 3.2]]],
            ids=["id1", "id2"],
        ),
    ]
    for attempt in attempts:
        with pytest.raises(ValueError) as err:
            attempt()
        assert "embedding" in str(err.value)


# test to make sure update shows exception for bad dimensionality


def test_dimensionality_exception_update(client):
    """``update`` with embeddings of the wrong dimension reports a dimension error."""
    client.reset()
    coll = client.create_collection("test_dimensionality_update_exception")
    coll.add(**minimal_records)

    with pytest.raises(Exception) as raised:
        coll.update(**bad_dimensionality_records)
    assert "dimension" in str(raised.value)


# test to make sure upsert shows exception for bad dimensionality


def test_dimensionality_exception_upsert(client):
    """``upsert`` with embeddings of the wrong dimension reports a dimension error."""
    client.reset()
    coll = client.create_collection("test_dimensionality_upsert_exception")
    coll.add(**minimal_records)

    with pytest.raises(Exception) as raised:
        coll.upsert(**bad_dimensionality_records)
    assert "dimension" in str(raised.value)


# This may be flaky on Windows, so it is rerun there.
@pytest.mark.flaky(reruns=3, condition=sys.platform.startswith("win32"))
def test_ssl_self_signed(client_ssl):
    """A heartbeat through the SSL-configured test client succeeds."""
    if not os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"):
        client_ssl.heartbeat()
        return
    pytest.skip("Skipping test for integration test")


# This may be flaky on Windows, so it is rerun there.
@pytest.mark.flaky(reruns=3, condition=sys.platform.startswith("win32"))
def test_ssl_self_signed_without_ssl_verify(client_ssl):
    """A plain ``HttpClient(ssl=True)`` fails certificate verification on this server."""
    if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"):
        pytest.skip("Skipping test for integration test")
    client_ssl.heartbeat()
    port = client_ssl._server._settings.chroma_server_http_port

    with pytest.raises(ValueError) as raised:
        chromadb.HttpClient(ssl=True, port=port)
    err = raised.value
    # Search the whole formatted traceback (including chained causes), not just
    # the top-level message, for the certificate failure marker.
    formatted = "".join(traceback.format_exception(type(err), err, err.__traceback__))
    client_ssl.clear_system_cache()
    assert "CERTIFICATE_VERIFY_FAILED" in formatted


def test_query_id_filtering_small_dataset(client):
    """Querying a small collection with an ``ids`` filter returns only those ids."""
    client.reset()
    collection = client.create_collection("test_query_id_filtering_small")

    num_vectors = 100
    dim = 512
    # Build the data from the declared sizes (previously the literals 100/512
    # were repeated here), so data, ids and query vector always agree.
    small_records = np.random.rand(num_vectors, dim).astype(np.float32).tolist()
    ids = [f"{i}" for i in range(num_vectors)]

    collection.add(
        embeddings=small_records,
        ids=ids,
    )

    # Every 10th id is allowed through the filter.
    query_ids = [f"{i}" for i in range(0, num_vectors, 10)]
    query_embedding = np.random.rand(dim).astype(np.float32).tolist()
    results = collection.query(
        query_embeddings=query_embedding,
        ids=query_ids,
        n_results=num_vectors,
        include=[],
    )

    allowed = set(query_ids)  # O(1) membership instead of scanning the list
    all_returned_ids = [item for sublist in results["ids"] for item in sublist]
    assert all(returned_id in allowed for returned_id in all_returned_ids)


def test_query_id_filtering_medium_dataset(client):
    """``ids`` filtering holds for single and batched queries over 1000 vectors."""
    client.reset()
    collection = client.create_collection("test_query_id_filtering_medium")

    num_vectors, dim = 1000, 512
    collection.add(
        embeddings=np.random.rand(num_vectors, dim).astype(np.float32).tolist(),
        ids=[str(n) for n in range(num_vectors)],
    )

    # Every 10th id passes the filter.
    query_ids = [str(n) for n in range(0, num_vectors, 10)]

    single = collection.query(
        query_embeddings=np.random.rand(dim).astype(np.float32).tolist(),
        ids=query_ids,
        n_results=num_vectors,
        include=[],
    )
    flattened = [rid for row in single["ids"] for rid in row]
    assert all(rid in query_ids for rid in flattened)

    # Three query vectors at once; every result row must respect the filter.
    batch = [np.random.rand(dim).astype(np.float32).tolist() for _ in range(3)]
    batched = collection.query(
        query_embeddings=batch,
        ids=query_ids,
        n_results=10,
        include=[],
    )
    for row in batched["ids"]:
        assert all(rid in query_ids for rid in row)


def test_query_id_filtering_e2e(client):
    """``ids`` filtering reflects deletes and upserts and rejects deleted ids."""
    client.reset()
    collection = client.create_collection("test_query_id_filtering_e2e")

    dim = 512
    num_vectors = 100
    collection.add(
        embeddings=np.random.rand(num_vectors, dim).astype(np.float32).tolist(),
        ids=[str(n) for n in range(num_vectors)],
        metadatas=[{"index": n} for n in range(num_vectors)],
    )

    # Remove ids 10..29.
    ids_to_delete = [str(n) for n in range(10, 30)]
    collection.delete(ids=ids_to_delete)

    # Overwrite ids 30..49 and add 20 new ids past the original range; every
    # upserted record carries an "upserted" flag in its metadata.
    ids_to_upsert_existing = [str(n) for n in range(30, 50)]
    new_num_vectors = num_vectors + 20
    ids_to_upsert_new = [str(n) for n in range(num_vectors, new_num_vectors)]
    upsert_ids = ids_to_upsert_existing + ids_to_upsert_new
    collection.upsert(
        embeddings=np.random.rand(len(upsert_ids), dim).astype(np.float32).tolist(),
        ids=upsert_ids,
        metadatas=[{"index": n, "upserted": True} for n in range(len(upsert_ids))],
    )

    valid_query_ids = (
        [str(n) for n in range(5, 10)]  # untouched original ids
        + [str(n) for n in range(35, 45)]  # original ids that were upserted
        + [str(n) for n in range(num_vectors + 5, num_vectors + 15)]  # new ids
    )

    includes = ["metadatas"]
    query_embedding = np.random.rand(dim).astype(np.float32).tolist()
    results = collection.query(
        query_embeddings=query_embedding,
        ids=valid_query_ids,
        n_results=new_num_vectors,
        include=includes,
    )

    returned = [rid for row in results["ids"] for rid in row]
    assert all(rid in valid_query_ids for rid in returned)

    # Any upserted id must come back with the metadata written by the upsert.
    upserted = set(upsert_ids)
    for row, id_row in enumerate(results["ids"]):
        for col, rid in enumerate(id_row):
            if rid in upserted:
                assert results["metadatas"][row][col]["upserted"]

    # A single (non-list) id works as the filter.
    results = collection.query(
        query_embeddings=query_embedding,
        ids=ids_to_upsert_existing[0],
        n_results=1,
        include=includes,
    )
    assert results["metadatas"][0][0]["upserted"]

    # Filtering on a deleted id is an error.
    with pytest.raises(Exception) as error:
        collection.query(
            query_embeddings=query_embedding,
            ids=ids_to_delete[0],
            n_results=1,
            include=includes,
        )
    assert "Error finding id" in str(error.value)


def test_validate_sparse_vector():
    """SparseVector's constructor accepts well-formed input and rejects the rest."""
    from chromadb.base_types import SparseVector

    # Well-formed (indices, values) pairs: construction must not raise.
    valid_cases = [
        ([0, 2, 5], [0.1, 0.5, 0.9]),
        ([], []),  # empty vector
        ([0, 1, 2], [1, 2, 3]),  # integer values
        ([0, 1, 2], [1, 2.5, 3]),  # mixed int / float values
        ([100, 1000, 10000], [0.1, 0.2, 0.3]),  # large indices
        ([42], [3.14]),  # single element
        ([0, 1], [True, False]),  # bool is a subclass of int, so this is accepted
    ]
    for indices, values in valid_cases:
        SparseVector(indices=indices, values=values)

    # Malformed (indices, values, expected error pattern) triples.
    invalid_cases = [
        ("not_a_list", [0.1, 0.2], "Expected SparseVector indices to be a list"),
        ([0, 1], "not_a_list", "Expected SparseVector values to be a list"),
        ([0, 1, 2], [0.1, 0.2], "indices and values must have the same length"),
        ([0, "not_int", 2], [0.1, 0.2, 0.3], "SparseVector indices must be integers"),
        ([0, -1, 2], [0.1, 0.2, 0.3], "SparseVector indices must be non-negative"),
        ([0, 1, 2], [0.1, "not_number", 0.3], "SparseVector values must be numbers"),
        ([0.0, 1.0, 2.0], [0.1, 0.2, 0.3], "SparseVector indices must be integers"),
        ([0, 1], [0.1, None], "SparseVector values must be numbers"),
        ([0, None], [0.1, 0.2], "SparseVector indices must be integers"),
        # unsorted, duplicate and descending indices
        ([0, 2, 1], [0.1, 0.2, 0.3], "indices must be sorted in strictly ascending order"),
        ([0, 1, 1, 2], [0.1, 0.2, 0.3, 0.4], "indices must be sorted in strictly ascending order"),
        ([5, 3, 1], [0.5, 0.3, 0.1], "indices must be sorted in strictly ascending order"),
    ]
    for indices, values, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            SparseVector(indices=indices, values=values)


def test_sparse_vector_in_metadata_validation():
    """validate_metadata accepts SparseVector values and rejects plain dicts."""
    from chromadb.api.types import validate_metadata
    from chromadb.base_types import SparseVector

    def sv(indices, values):
        return SparseVector(indices=indices, values=values)

    # Metadata that must validate cleanly.
    accepted = [
        {
            "text": "document 1",
            "sparse_embedding": sv([0, 2, 5], [0.1, 0.5, 0.9]),
            "score": 0.5,
        },
        {
            "text": "document 2",
            "sparse_embedding": sv([1, 3, 4], [0.2, 0.4, 0.6]),
            "score": 0.8,
        },
        {"text": "empty sparse", "sparse_vec": sv([], [])},
        {
            "text": "multiple sparse vectors",
            "sparse_1": sv([0, 1], [0.1, 0.2]),
            "sparse_2": sv([2, 3, 4], [0.3, 0.4, 0.5]),
            "regular_field": 42,
        },
        {
            "text": "large sparse",
            "large_sparse_vec": sv(
                list(range(1000)), [float(i) * 0.001 for i in range(1000)]
            ),
        },
    ]
    for metadata in accepted:
        validate_metadata(metadata)

    # Malformed SparseVectors fail at construction, before any metadata exists.
    construction_failures = [
        ([0, 1], [0.1], "indices and values must have the same length"),
        ([0, -1, 2], [0.1, 0.2, 0.3], "SparseVector indices must be non-negative"),
        ([0, 1], [0.1, "not_a_number"], "SparseVector values must be numbers"),
    ]
    for indices, values, message in construction_failures:
        with pytest.raises(ValueError, match=message):
            sv(indices, values)

    # Plain dicts are not SparseVector instances and are rejected by validation.
    rejected_message = (
        "Expected metadata value to be a str, int, float, bool, SparseVector, or None"
    )
    rejected = [
        {"text": "missing indices", "sparse_embedding": {"values": [0.1, 0.2]}},
        {
            "config": "some_config",
            "sparse_vector": {"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]},
        },
    ]
    for metadata in rejected:
        with pytest.raises(ValueError, match=rejected_message):
            validate_metadata(metadata)


def test_sparse_vector_dict_format_normalization():
    """Check that ``normalize_metadata`` turns tagged dicts into SparseVector objects.

    Covers: a tagged dict becoming a SparseVector, an existing SparseVector
    being passed through by identity, rejection of malformed tagged dicts,
    untagged dicts being left alone, empty vectors, and several sparse
    vectors inside one metadata mapping.
    """
    from chromadb.api.types import normalize_metadata, validate_metadata
    from chromadb.base_types import SparseVector

    def tagged(indices, values):
        # Dict representation of a sparse vector, marked with TYPE_KEY so that
        # normalize_metadata recognises it and converts it.
        return {
            TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
            "indices": indices,
            "values": values,
        }

    # Case 1: a tagged dict is converted, keeps its data, and then validates.
    converted = normalize_metadata(
        {"text": "test document", "sparse": tagged([0, 2, 5], [1.0, 2.0, 3.0])}
    )
    vector = converted["sparse"]
    assert isinstance(vector, SparseVector)
    assert vector.indices == [0, 2, 5]
    assert vector.values == [1.0, 2.0, 3.0]
    validate_metadata(converted)

    # Case 2: an existing SparseVector is returned as the very same object.
    original_vector = SparseVector(indices=[1, 3, 4], values=[0.5, 1.5, 2.5])
    passthrough = normalize_metadata(
        {"text": "test document", "sparse": original_vector}
    )
    assert passthrough["sparse"] is original_vector
    validate_metadata(passthrough)

    # Cases 3-6: malformed tagged dicts are rejected while normalizing.
    not_ascending = "indices must be sorted in strictly ascending order"
    malformed = (
        ("unsorted", [5, 0, 2], [3.0, 1.0, 2.0], not_ascending),
        ("duplicates", [0, 2, 2], [1.0, 2.0, 3.0], not_ascending),
        ("negative", [-1, 0, 2], [1.0, 2.0, 3.0], "indices must be non-negative"),
        (
            "mismatch",
            [0, 2],
            [1.0, 2.0, 3.0],
            "indices and values must have the same length",
        ),
    )
    for label, indices, values, expected_error in malformed:
        with pytest.raises(ValueError, match=expected_error):
            normalize_metadata({"text": label, "sparse": tagged(indices, values)})

    # Case 7: a dict without the type tag is left as a plain dict.
    untouched = normalize_metadata({"text": "regular", "config": {"key": "value"}})
    assert isinstance(untouched["config"], dict)
    assert untouched["config"]["key"] == "value"

    # Case 8: an empty tagged dict still yields a SparseVector.
    empty = normalize_metadata({"text": "empty", "sparse": tagged([], [])})
    assert isinstance(empty["sparse"], SparseVector)
    assert empty["sparse"].indices == []
    assert empty["sparse"].values == []

    # Case 9: every tagged field is converted; other values are preserved.
    several = normalize_metadata(
        {
            "sparse1": tagged([0, 1], [1.0, 2.0]),
            "sparse2": tagged([2, 3], [3.0, 4.0]),
            "regular": 42,
        }
    )
    for field in ("sparse1", "sparse2"):
        assert isinstance(several[field], SparseVector)
    assert several["regular"] == 42


def test_sparse_vector_dict_format_in_record_set():
    """Check that ``normalize_insert_record_set`` normalizes sparse fields per record.

    The batch mixes a tagged sparse dict, a ready-made SparseVector, and a
    record with no sparse field. Afterwards both sparse fields must be
    SparseVector instances holding the original data, and the record set
    must pass ``validate_insert_record_set``.
    """
    from chromadb.api.types import (
        normalize_insert_record_set,
        validate_insert_record_set,
    )
    from chromadb.base_types import SparseVector

    tagged_sparse = {
        TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
        "indices": [0, 2],
        "values": [1.0, 2.0],
    }
    record_set = normalize_insert_record_set(
        ids=["doc1", "doc2", "doc3"],
        embeddings=None,
        metadatas=[
            {"text": "test1", "sparse": tagged_sparse},
            {
                "text": "test2",
                "sparse": SparseVector(indices=[1, 3], values=[1.5, 2.5]),
            },
            {"text": "test3"},  # this record has no sparse field at all
        ],
        documents=["doc one", "doc two", "doc three"],
    )

    # Both input flavours end up as SparseVector; the third record gains nothing.
    for position in (0, 1):
        assert isinstance(record_set["metadatas"][position]["sparse"], SparseVector)
    assert "sparse" not in record_set["metadatas"][2]

    validate_insert_record_set(record_set)

    # The normalized vectors carry the original indices and values.
    expected_contents = {0: ([0, 2], [1.0, 2.0]), 1: ([1, 3], [1.5, 2.5])}
    for position, (indices, values) in expected_contents.items():
        assert record_set["metadatas"][position]["sparse"].indices == indices
        assert record_set["metadatas"][position]["sparse"].values == values


def test_search_result_rows() -> None:
    """Test SearchResult.rows(), which converts column-major results to row-major.

    rows() returns one list of row dicts per search payload. Each row always
    has an "id". It has "document", "embedding", "metadata" or "score" only
    when that column holds a non-None value for the row (see Tests 2, 4 and 5).
    SearchResult must also keep behaving like a plain dict (Test 6).
    """
    from chromadb.api.types import SearchResult

    # Test 1: Basic single payload with all fields
    result = SearchResult(
        {
            "ids": [["id1", "id2", "id3"]],
            "documents": [["doc1", "doc2", "doc3"]],
            "embeddings": [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]],
            "metadatas": [[{"key": "a"}, {"key": "b"}, {"key": "c"}]],
            "scores": [[0.9, 0.8, 0.7]],
            "select": [["document", "score", "metadata"]],
        }
    )

    rows = result.rows()
    assert len(rows) == 1  # One payload
    assert len(rows[0]) == 3  # Three results

    # Compare every row exactly. Only checking key presence for the later rows
    # would miss values being taken from the wrong column position.
    assert rows[0] == [
        {
            "id": "id1",
            "document": "doc1",
            "embedding": [1.0, 2.0],
            "metadata": {"key": "a"},
            "score": 0.9,
        },
        {
            "id": "id2",
            "document": "doc2",
            "embedding": [3.0, 4.0],
            "metadata": {"key": "b"},
            "score": 0.8,
        },
        {
            "id": "id3",
            "document": "doc3",
            "embedding": [5.0, 6.0],
            "metadata": {"key": "c"},
            "score": 0.7,
        },
    ]

    # Test 2: Multiple payloads
    result = SearchResult(
        {
            "ids": [["a1", "a2"], ["b1", "b2", "b3"]],
            "documents": [["doc_a1", "doc_a2"], ["doc_b1", "doc_b2", "doc_b3"]],
            "embeddings": [
                None,
                [[1.0], [2.0], [3.0]],
            ],  # First payload has no embeddings
            "metadatas": [[{"x": 1}, {"x": 2}], None],  # Second payload has no metadata
            "scores": [[0.5, 0.4], [0.9, 0.8, 0.7]],
            "select": [["document", "score"], ["embedding", "score"]],
        }
    )

    rows = result.rows()
    assert len(rows) == 2  # Two payloads
    assert len(rows[0]) == 2  # First payload has 2 results
    assert len(rows[1]) == 3  # Second payload has 3 results

    # First payload - has docs, metadata, scores but no embeddings
    assert rows[0][0] == {
        "id": "a1",
        "document": "doc_a1",
        "metadata": {"x": 1},
        "score": 0.5,
    }
    assert rows[0][1] == {
        "id": "a2",
        "document": "doc_a2",
        "metadata": {"x": 2},
        "score": 0.4,
    }

    # Second payload - has docs, embeddings, scores but no metadata
    assert rows[1][0] == {
        "id": "b1",
        "document": "doc_b1",
        "embedding": [1.0],
        "score": 0.9,
    }
    assert rows[1][1] == {
        "id": "b2",
        "document": "doc_b2",
        "embedding": [2.0],
        "score": 0.8,
    }
    assert rows[1][2] == {
        "id": "b3",
        "document": "doc_b3",
        "embedding": [3.0],
        "score": 0.7,
    }

    # Test 3: Empty result
    result = SearchResult(
        {
            "ids": [],
            "documents": [],
            "embeddings": [],
            "metadatas": [],
            "scores": [],
            "select": [],
        }
    )

    rows = result.rows()
    assert rows == []

    # Test 4: Sparse data with None values in lists
    result = SearchResult(
        {
            "ids": [["id1", "id2", "id3"]],
            "documents": [[None, "doc2", None]],  # Sparse documents
            "embeddings": None,  # No embeddings at all
            "metadatas": [[{"a": 1}, None, {"c": 3}]],  # Sparse metadata
            "scores": [[0.9, None, 0.7]],  # Sparse scores
            "select": [["document", "metadata", "score"]],
        }
    )

    rows = result.rows()
    assert len(rows) == 1
    assert len(rows[0]) == 3

    # First row - only has metadata and score
    assert rows[0][0] == {"id": "id1", "metadata": {"a": 1}, "score": 0.9}

    # Second row - only has document
    assert rows[0][1] == {"id": "id2", "document": "doc2"}

    # Third row - has metadata and score
    assert rows[0][2] == {"id": "id3", "metadata": {"c": 3}, "score": 0.7}

    # Test 5: Only IDs (minimal result)
    result = SearchResult(
        {
            "ids": [["id1", "id2"]],
            "documents": None,
            "embeddings": None,
            "metadatas": None,
            "scores": None,
            "select": [[]],
        }
    )

    rows = result.rows()
    assert len(rows) == 1
    assert len(rows[0]) == 2
    assert rows[0][0] == {"id": "id1"}
    assert rows[0][1] == {"id": "id2"}

    # Test 6: SearchResult works as dict (backward compatibility)
    result = SearchResult(
        {
            "ids": [["test"]],
            "documents": [["test doc"]],
            "metadatas": [[{"test": True}]],
            "embeddings": [[[0.1, 0.2]]],
            "scores": [[0.99]],
            "select": [["all"]],
        }
    )

    # Should work as dict
    assert result["ids"] == [["test"]]
    assert result.get("documents") == [["test doc"]]
    assert "metadatas" in result
    assert len(result) == 6  # Should have 6 keys

    # Should also have rows() method
    rows = result.rows()
    assert len(rows[0]) == 1
    assert rows[0][0]["id"] == "test"


def test_rrf_to_dict() -> None:
    """Test the Rrf (Reciprocal Rank Fusion) to_dict conversion.

    Rrf serializes the score -sum_i(weight_i / (k + rank_i)) as a nested
    $mul / $sum / $div expression tree. These tests pin that tree for default,
    custom and normalized weights, and check the argument-validation errors.

    Note: "sparse_embedding" is just an example metadata field name. Users can
    store any data in metadata fields and reference them by name (without a
    "#" prefix). The "#embedding" key refers to the special main embedding field.
    """
    from chromadb.execution.expression.operator import Rrf, Knn, Val

    def knn_dict(query, key="#embedding"):
        # Serialized Knn(query=..., key=..., return_rank=True) with the default
        # limit of 16, exactly as the expected trees below spell it out.
        return {
            "$knn": {
                "query": query,
                "key": key,
                "limit": 16,
                "return_rank": True,
            }
        }

    def expected_rrf(k, weighted_ranks):
        # Expected to_dict() output for -sum(weight / (k + rank)) over the
        # given (weight, serialized rank) pairs, in order.
        return {
            "$mul": [
                {"$val": -1},
                {
                    "$sum": [
                        {
                            "$div": {
                                "left": {"$val": weight},
                                "right": {"$sum": [{"$val": k}, rank]},
                            }
                        }
                        for weight, rank in weighted_ranks
                    ]
                },
            ]
        }

    # Test 1: Basic RRF with two KNN rankings (equal weight)
    rrf = Rrf(
        [
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
        ]
    )
    # With default k=60 and equal weights (1.0 each):
    # -(1.0/(60 + knn1) + 1.0/(60 + knn2))
    assert rrf.to_dict() == expected_rrf(
        60,
        [
            (1.0, knn_dict([0.1, 0.2])),
            (1.0, knn_dict([0.3, 0.4], key="sparse_embedding")),
        ],
    )

    # Test 2: RRF with custom weights and k
    rrf_weighted = Rrf(
        ranks=[
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
        ],
        weights=[2.0, 1.0],  # Dense is 2x more important
        k=100,
    )
    # Expected: -(2.0/(100 + knn1) + 1.0/(100 + knn2))
    assert rrf_weighted.to_dict() == expected_rrf(
        100,
        [
            (2.0, knn_dict([0.1, 0.2])),
            (1.0, knn_dict([0.3, 0.4], key="sparse_embedding")),
        ],
    )

    # Test 3: RRF with three rankings (a constant Val can also act as a rank)
    rrf_three = Rrf(
        [
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
            Val(5.0),
        ]
    )
    result_three = rrf_three.to_dict()
    assert "$mul" in result_three
    assert "$sum" in result_three["$mul"][1]
    assert len(result_three["$mul"][1]["$sum"]) == 3  # Three ranking strategies

    # Error cases (Tests 4-8, 11): each raises block covers both construction
    # and to_dict(), so the test holds wherever the validation runs.

    # Test 4: Error case - mismatched weights
    with pytest.raises(
        ValueError, match="Number of weights .* must match number of ranks"
    ):
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=[0.3, 0.4], return_rank=True),
            ],
            weights=[1.0],  # Only one weight for two ranks
        ).to_dict()

    # Test 5: Error case - negative weights
    with pytest.raises(ValueError, match="All weights must be non-negative"):
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=[0.3, 0.4], return_rank=True),
            ],
            weights=[1.0, -1.0],  # Negative weight
        ).to_dict()

    # Test 6: Error case - empty ranks list
    with pytest.raises(ValueError, match="RRF requires at least one rank"):
        Rrf([]).to_dict()

    # Test 7: Error case - negative k value
    with pytest.raises(ValueError, match="k must be positive"):
        Rrf([Val(1.0)], k=-5).to_dict()

    # Test 8: Error case - zero k value
    with pytest.raises(ValueError, match="k must be positive"):
        Rrf([Val(1.0)], k=0).to_dict()

    # Test 9: Normalize flag with weights
    rrf_normalized = Rrf(
        ranks=[
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
        ],
        weights=[3.0, 1.0],  # Will be normalized to [0.75, 0.25]
        normalize=True,
        k=100,
    )
    # Expected: -(0.75/(100 + knn1) + 0.25/(100 + knn2))
    assert rrf_normalized.to_dict() == expected_rrf(
        100,
        [
            (0.75, knn_dict([0.1, 0.2])),
            (0.25, knn_dict([0.3, 0.4], key="sparse_embedding")),
        ],
    )

    # Test 10: Normalize flag without weights (should work with defaults)
    rrf_normalize_defaults = Rrf(
        ranks=[
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=[0.3, 0.4], return_rank=True),
        ],
        normalize=True,  # Will normalize [1.0, 1.0] to [0.5, 0.5]
    )
    # Both weights are 0.5 after normalization; default k=60.
    assert rrf_normalize_defaults.to_dict() == expected_rrf(
        60,
        [
            (0.5, knn_dict([0.1, 0.2])),
            (0.5, knn_dict([0.3, 0.4])),
        ],
    )

    # Test 11: Error case - normalize with all zero weights
    with pytest.raises(ValueError, match="Sum of weights must be positive"):
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=[0.3, 0.4], return_rank=True),
            ],
            weights=[0.0, 0.0],
            normalize=True,
        ).to_dict()


# Expression API Tests - Testing dict support and from_dict methods
class TestSearchDictSupport:
    """Test Search class dict input support.

    Search accepts plain dicts (and, where applicable, ints, lists and sets)
    in place of expression objects, and rejects other input types with
    TypeError.
    """

    def test_search_with_dict_where(self):
        """Test Search accepts dict for where parameter."""
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import Where

        # Simple equality
        search = Search(where={"status": "active"})
        assert search._where is not None
        assert isinstance(search._where, Where)

        # Complex where with operators is converted to a Where as well
        search = Search(where={"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
        assert search._where is not None
        assert isinstance(search._where, Where)

    def test_search_with_dict_rank(self):
        """Test Search accepts dict for rank parameter."""
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import Rank

        # KNN ranking
        search = Search(rank={"$knn": {"query": [0.1, 0.2]}})
        assert search._rank is not None
        assert isinstance(search._rank, Rank)

        # Val ranking is converted to a Rank as well
        search = Search(rank={"$val": 0.5})
        assert search._rank is not None
        assert isinstance(search._rank, Rank)

    def test_search_with_dict_limit(self):
        """Test Search accepts dict and int for limit parameter."""
        from chromadb.execution.expression.plan import Search

        # Dict limit
        search = Search(limit={"limit": 10, "offset": 5})
        assert search._limit.limit == 10
        assert search._limit.offset == 5

        # Int limit (creates Limit with offset=0)
        search = Search(limit=10)
        assert search._limit.limit == 10
        assert search._limit.offset == 0

    def test_search_with_dict_select(self):
        """Test Search accepts dict, list, and set for select parameter."""
        from chromadb.execution.expression.plan import Search

        # Dict select
        search = Search(select={"keys": ["#document", "#score"]})
        assert search._select is not None

        # List select
        search = Search(select=["#document", "#metadata"])
        assert search._select is not None

        # Set select
        search = Search(select={"#document", "#embedding"})
        assert search._select is not None

    def test_search_mixed_inputs(self):
        """Test Search with mixed expression and dict inputs."""
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import Key

        search = Search(
            where=Key("status") == "active",  # Expression
            rank={"$knn": {"query": [0.1, 0.2]}},  # Dict
            limit=10,  # Int
            select=["#document"],  # List
        )
        assert search._where is not None
        assert search._rank is not None
        assert search._limit.limit == 10
        assert search._select is not None

    def test_search_builder_methods_with_dicts(self):
        """Test Search builder methods accept dicts and convert them."""
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import Rank, Where

        search = Search().where({"status": "active"}).rank({"$val": 0.5})
        assert search._where is not None
        assert search._rank is not None
        assert isinstance(search._where, Where)
        assert isinstance(search._rank, Rank)

    def test_search_invalid_inputs(self):
        """Test Search rejects invalid input types."""
        # pytest is imported at module level; no local re-import needed.
        from chromadb.execution.expression.plan import Search

        with pytest.raises(TypeError, match="where must be"):
            Search(where="invalid")

        with pytest.raises(TypeError, match="rank must be"):
            Search(rank=0.5)  # Primitive numbers not allowed

        with pytest.raises(TypeError, match="limit must be"):
            Search(limit="10")

        with pytest.raises(TypeError, match="select must be"):
            Search(select=123)


class TestWhereFromDict:
    """Tests for building Where expression nodes from ``$``-operator filter dicts."""

    def test_simple_equality(self):
        """A bare value and an explicit ``$eq`` both produce an Eq node."""
        from chromadb.execution.expression.operator import Where, Eq

        for clause in ({"status": "active"}, {"status": {"$eq": "active"}}):
            assert isinstance(Where.from_dict(clause), Eq)

    def test_comparison_operators(self):
        """Each comparison operator maps onto its dedicated node class."""
        from chromadb.execution.expression.operator import Where, Ne, Gt, Gte, Lt, Lte

        expectations = [
            ({"status": {"$ne": "inactive"}}, Ne),
            ({"score": {"$gt": 0.5}}, Gt),
            ({"score": {"$gte": 0.5}}, Gte),
            ({"score": {"$lt": 1.0}}, Lt),
            ({"score": {"$lte": 1.0}}, Lte),
        ]
        for clause, node_type in expectations:
            assert isinstance(Where.from_dict(clause), node_type)

    def test_membership_operators(self):
        """``$in`` and ``$nin`` map onto In and Nin."""
        from chromadb.execution.expression.operator import Where, In, Nin

        expectations = [
            ({"status": {"$in": ["active", "pending"]}}, In),
            ({"status": {"$nin": ["deleted", "archived"]}}, Nin),
        ]
        for clause, node_type in expectations:
            assert isinstance(Where.from_dict(clause), node_type)

    def test_string_operators(self):
        """Substring and regex operators (and their negations) map onto their nodes."""
        from chromadb.execution.expression.operator import (
            Where,
            Contains,
            NotContains,
            Regex,
            NotRegex,
        )

        expectations = [
            ({"text": {"$contains": "hello"}}, Contains),
            ({"text": {"$not_contains": "spam"}}, NotContains),
            ({"text": {"$regex": "^test.*"}}, Regex),
            ({"text": {"$not_regex": r"\d+"}}, NotRegex),
        ]
        for clause, node_type in expectations:
            assert isinstance(Where.from_dict(clause), node_type)

    def test_logical_operators(self):
        """``$and`` and ``$or`` map onto And and Or."""
        from chromadb.execution.expression.operator import Where, And, Or

        expectations = [
            ({"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]}, And),
            ({"$or": [{"status": "active"}, {"status": "pending"}]}, Or),
        ]
        for clause, node_type in expectations:
            assert isinstance(Where.from_dict(clause), node_type)

    def test_nested_logical_operators(self):
        """An ``$or`` nested inside ``$and`` still yields an And at the root."""
        from chromadb.execution.expression.operator import Where, And

        nested = {
            "$and": [
                {"$or": [{"status": "active"}, {"status": "pending"}]},
                {"score": {"$gte": 0.5}},
            ]
        }
        assert isinstance(Where.from_dict(nested), And)

    def test_special_keys(self):
        """A ``#``-prefixed special key such as ``#id`` is accepted as a field."""
        from chromadb.execution.expression.operator import Where, In

        id_filter = {"#id": {"$in": ["id1", "id2"]}}
        assert isinstance(Where.from_dict(id_filter), In)

    def test_invalid_where_dicts(self):
        """Malformed inputs raise TypeError/ValueError with descriptive messages."""
        import pytest
        from chromadb.execution.expression.operator import Where

        rejected = [
            ("not a dict", TypeError, "Expected dict"),
            ({}, ValueError, "cannot be empty"),
            ({"$and": []}, ValueError, "requires at least one condition"),
        ]
        for bad_input, error_type, message in rejected:
            with pytest.raises(error_type, match=message):
                Where.from_dict(bad_input)


class TestRankFromDict:
    """Exercise Rank.from_dict() on each supported operator form."""

    def test_val_conversion(self):
        """A $val dict becomes a Val carrying the literal number."""
        from chromadb.execution.expression.operator import Rank, Val

        constant = Rank.from_dict({"$val": 0.5})
        assert isinstance(constant, Val)
        assert constant.value == 0.5

    def test_knn_conversion(self):
        """$knn dicts become Knn ranks; omitted fields take their defaults."""
        import numpy as np
        from chromadb.execution.expression.operator import Rank, Knn

        # Only the query supplied: key and limit fall back to defaults.
        knn = Rank.from_dict({"$knn": {"query": [0.1, 0.2]}})
        assert isinstance(knn, Knn)
        # The query may come back either as the plain list or as a numpy
        # array; arrays are compared with a tolerance for dtype conversion.
        if not isinstance(knn.query, np.ndarray):
            assert knn.query == [0.1, 0.2]
        else:
            assert np.allclose(knn.query, np.array([0.1, 0.2]))
        assert knn.key == "#embedding"
        assert knn.limit == 16

        # Every optional field overridden explicitly.
        custom_spec = {
            "query": [0.1, 0.2],
            "key": "sparse_embedding",
            "limit": 256,
            "return_rank": True,
        }
        knn = Rank.from_dict({"$knn": custom_spec})
        assert knn.key == "sparse_embedding"
        assert knn.limit == 256
        assert knn.return_rank

    def test_arithmetic_operators(self):
        """$sum / $sub / $mul / $div map onto Sum / Sub / Mul / Div."""
        from chromadb.execution.expression.operator import Rank, Sum, Sub, Mul, Div

        cases = [
            ({"$sum": [{"$val": 0.5}, {"$val": 0.3}]}, Sum),
            ({"$sub": {"left": {"$val": 1.0}, "right": {"$val": 0.3}}}, Sub),
            ({"$mul": [{"$val": 2.0}, {"$val": 0.5}]}, Mul),
            ({"$div": {"left": {"$val": 1.0}, "right": {"$val": 2.0}}}, Div),
        ]
        for spec, expected_type in cases:
            assert isinstance(Rank.from_dict(spec), expected_type)

    def test_math_functions(self):
        """$abs / $exp / $log map onto Abs / Exp / Log."""
        from chromadb.execution.expression.operator import Rank, Abs, Exp, Log

        cases = [
            ({"$abs": {"$val": -0.5}}, Abs),
            ({"$exp": {"$val": 1.0}}, Exp),
            ({"$log": {"$val": 2.0}}, Log),
        ]
        for spec, expected_type in cases:
            assert isinstance(Rank.from_dict(spec), expected_type)

    def test_aggregation_functions(self):
        """$max / $min map onto Max / Min."""
        from chromadb.execution.expression.operator import Rank, Max, Min

        operands = [{"$val": 0.5}, {"$val": 0.8}]
        for operator_key, expected_type in (("$max", Max), ("$min", Min)):
            parsed = Rank.from_dict({operator_key: list(operands)})
            assert isinstance(parsed, expected_type)

    def test_complex_rank_expression(self):
        """A $sum of weighted $mul terms (one wrapping $knn) parses to Sum."""
        from chromadb.execution.expression.operator import Rank, Sum

        weighted_knn = {"$mul": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.8}]}
        weighted_const = {"$mul": [{"$val": 0.5}, {"$val": 0.2}]}
        combined = Rank.from_dict({"$sum": [weighted_knn, weighted_const]})
        assert isinstance(combined, Sum)

    def test_invalid_rank_dicts(self):
        """Malformed Rank inputs raise TypeError/ValueError with clear messages."""
        import pytest
        from chromadb.execution.expression.operator import Rank

        # (expected exception, message regex, payload) in the order checked.
        rejected = (
            (TypeError, "Expected dict", "not a dict"),
            (ValueError, "cannot be empty", {}),
            (
                ValueError,
                "exactly one operator",
                {"$val": 0.5, "$knn": {"query": [0.1]}},
            ),
            (TypeError, "requires a number", {"$val": "not a number"}),
        )
        for expected_error, pattern, payload in rejected:
            with pytest.raises(expected_error, match=pattern):
                Rank.from_dict(payload)


class TestLimitFromDict:
    """Exercise Limit.from_dict() parsing and validation."""

    def test_limit_only(self):
        """Omitting offset leaves it at its default of 0."""
        from chromadb.execution.expression.operator import Limit

        parsed = Limit.from_dict({"limit": 20})
        assert parsed.limit == 20
        assert parsed.offset == 0

    def test_offset_only(self):
        """Omitting limit leaves it unset (None)."""
        from chromadb.execution.expression.operator import Limit

        parsed = Limit.from_dict({"offset": 10})
        assert parsed.offset == 10
        assert parsed.limit is None

    def test_limit_and_offset(self):
        """Both fields are carried through when supplied together."""
        from chromadb.execution.expression.operator import Limit

        parsed = Limit.from_dict({"limit": 20, "offset": 10})
        assert parsed.limit == 20
        assert parsed.offset == 10

    def test_validation(self):
        """Out-of-range limit/offset values raise ValueError."""
        import pytest
        from chromadb.execution.expression.operator import Limit

        for payload, message in (
            ({"limit": -1}, "must be positive"),  # negative limit
            ({"limit": 0}, "must be positive"),  # zero limit
            ({"offset": -1}, "must be non-negative"),  # negative offset
        ):
            with pytest.raises(ValueError, match=message):
                Limit.from_dict(payload)

    def test_invalid_types(self):
        """Non-dict input and non-integer fields raise TypeError."""
        import pytest
        from chromadb.execution.expression.operator import Limit

        for payload, message in (
            ("not a dict", "Expected dict"),
            ({"limit": "20"}, "must be an integer"),
            ({"offset": 10.5}, "must be an integer"),
        ):
            with pytest.raises(TypeError, match=message):
                Limit.from_dict(payload)

    def test_unexpected_keys(self):
        """Keys other than limit/offset are rejected."""
        import pytest
        from chromadb.execution.expression.operator import Limit

        payload = {"limit": 10, "invalid": "key"}
        with pytest.raises(ValueError, match="Unexpected keys"):
            Limit.from_dict(payload)


class TestSelectFromDict:
    """Exercise Select.from_dict() parsing and validation."""

    def test_special_keys(self):
        """'#'-prefixed names resolve to the predefined Key constants."""
        from chromadb.execution.expression.operator import Select, Key

        wire_names = ["#document", "#embedding", "#metadata", "#score"]
        selected = Select.from_dict({"keys": wire_names}).keys
        for expected in (Key.DOCUMENT, Key.EMBEDDING, Key.METADATA, Key.SCORE):
            assert expected in selected

    def test_metadata_keys(self):
        """Plain names resolve to metadata-field Keys."""
        from chromadb.execution.expression.operator import Select, Key

        selected = Select.from_dict({"keys": ["title", "author", "date"]}).keys
        for field in ("title", "author", "date"):
            assert Key(field) in selected

    def test_mixed_keys(self):
        """Special and metadata names can be combined in one selection."""
        from chromadb.execution.expression.operator import Select, Key

        selected = Select.from_dict({"keys": ["#document", "title", "#score"]}).keys
        for expected in (Key.DOCUMENT, Key("title"), Key.SCORE):
            assert expected in selected

    def test_empty_keys(self):
        """An empty key list yields an empty selection."""
        from chromadb.execution.expression.operator import Select

        selected = Select.from_dict({"keys": []}).keys
        assert len(selected) == 0

    def test_validation(self):
        """Wrong container or element types raise TypeError."""
        import pytest
        from chromadb.execution.expression.operator import Select

        for payload, message in (
            ("not a dict", "Expected dict"),
            ({"keys": "not a list"}, "must be a list/tuple/set"),
            ({"keys": [123]}, "must be a string"),
        ):
            with pytest.raises(TypeError, match=message):
                Select.from_dict(payload)

    def test_unexpected_keys(self):
        """Top-level entries other than 'keys' are rejected."""
        import pytest
        from chromadb.execution.expression.operator import Select

        payload = {"keys": [], "invalid": "key"}
        with pytest.raises(ValueError, match="Unexpected keys"):
            Select.from_dict(payload)


class TestRoundTripConversion:
    """Test that to_dict() and from_dict() round-trip correctly."""

    @staticmethod
    def _serialized_match(left, right):
        """Recursively compare two serialized expression structures.

        Dicts must have identical key sets and lists identical lengths; their
        members are compared recursively. The value stored under a ``"$knn"``
        key is compared with ``_knn_match``, which tolerates float32 rounding
        in the query vector. Every other leaf is compared with ``==``.

        Returns:
            True when the two structures match, False otherwise.
        """
        if isinstance(left, dict) and isinstance(right, dict):
            # Symmetric key check: extra keys on either side are a mismatch.
            if set(left) != set(right):
                return False
            for key, left_value in left.items():
                right_value = right[key]
                if (
                    key == "$knn"
                    and isinstance(left_value, dict)
                    and isinstance(right_value, dict)
                ):
                    matched = TestRoundTripConversion._knn_match(
                        left_value, right_value
                    )
                else:
                    matched = TestRoundTripConversion._serialized_match(
                        left_value, right_value
                    )
                if not matched:
                    return False
            return True
        if isinstance(left, list) and isinstance(right, list):
            return len(left) == len(right) and all(
                TestRoundTripConversion._serialized_match(a, b)
                for a, b in zip(left, right)
            )
        return left == right

    @staticmethod
    def _knn_match(left, right):
        """Compare two serialized ``$knn`` payloads.

        The normalize_embeddings function converts queries to float32, so a
        round-tripped dense query can differ from the original Python floats
        in the low bits. Dense queries are therefore cast to float32 and
        compared with ``np.allclose``; a dict-valued query (presumably a
        sparse-vector payload) and all other fields are compared structurally.

        Returns:
            True when the payloads match, False otherwise.
        """
        if set(left) != set(right):
            return False
        for key, left_value in left.items():
            right_value = right[key]
            if (
                key == "query"
                and not isinstance(left_value, dict)
                and not isinstance(right_value, dict)
            ):
                left_vec = np.asarray(left_value, dtype=np.float32)
                right_vec = np.asarray(right_value, dtype=np.float32)
                # Check shapes first: np.allclose would otherwise broadcast
                # vectors of different lengths and could report a false match.
                if left_vec.shape != right_vec.shape or not np.allclose(
                    left_vec, right_vec
                ):
                    return False
            elif not TestRoundTripConversion._serialized_match(
                left_value, right_value
            ):
                return False
        return True

    def test_where_round_trip(self):
        """Test Where round-trip conversion."""
        from chromadb.execution.expression.operator import Where, And, Key

        original = And([Key("status") == "active", Key("score") > 0.5])
        dict_form = original.to_dict()
        restored = Where.from_dict(dict_form)
        assert restored.to_dict() == dict_form

    def test_rank_round_trip(self):
        """Test Rank round-trip conversion."""
        from chromadb.execution.expression.operator import Rank, Knn, Val

        original = Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2
        dict_form = original.to_dict()
        restored = Rank.from_dict(dict_form)
        # KNN queries are compared with float32 tolerance (see _knn_match).
        assert self._serialized_match(restored.to_dict(), dict_form)

    def test_limit_round_trip(self):
        """Test Limit round-trip conversion."""
        from chromadb.execution.expression.operator import Limit

        original = Limit(limit=20, offset=10)
        dict_form = original.to_dict()
        restored = Limit.from_dict(dict_form)
        assert restored.to_dict() == dict_form

    def test_select_round_trip(self):
        """Test Select round-trip conversion."""
        from chromadb.execution.expression.operator import Select, Key

        original = Select(keys={Key.DOCUMENT, Key("title"), Key.SCORE})
        dict_form = original.to_dict()
        restored = Select.from_dict(dict_form)
        # Note: Set order might differ, so compare sets
        assert set(restored.to_dict()["keys"]) == set(dict_form["keys"])

    def test_search_round_trip(self):
        """Test Search round-trip through dict inputs."""
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import Key, Knn, Limit, Select

        original_search = Search(
            where=Key("status") == "active",
            rank=Knn(query=[0.1, 0.2]),
            limit=Limit(limit=10),
            select=Select(keys={Key.DOCUMENT}),
        )

        # Convert to dict
        search_dict = original_search.to_dict()

        # Create new Search from dicts
        new_search = Search(
            where=search_dict["filter"] if search_dict["filter"] else None,
            rank=search_dict["rank"] if search_dict["rank"] else None,
            limit=search_dict["limit"],
            select=search_dict["select"],
        )

        # Same tolerant comparison as test_rank_round_trip; it also handles
        # a $knn nested anywhere in the dict, not just directly under "rank".
        assert self._serialized_match(new_search.to_dict(), search_dict)
