cognee/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py
2025-09-11 14:06:02 +01:00

565 lines
19 KiB
Python

import json
import asyncio
from uuid import UUID
from typing import List, Optional, Dict, Any
from chromadb import AsyncHttpClient, Settings
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
logger = get_logger("ChromaDBAdapter")
class IndexSchema(DataPoint):
"""
Define a schema for indexing textual data along with metadata.
Public methods:
- model_dump: Serialize the instance data into a format suitable for ChromaDB storage.
Instance variables:
- text: The text content to be indexed.
- metadata: A dictionary containing index-related fields.
"""
text: str
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
"""
Serialize the instance data for storage.
Invoke the superclass method and process the resulting data into a format compatible
with ChromaDB.
Returns:
--------
A dictionary containing serialized data processed for ChromaDB storage.
"""
data = super().model_dump(**kwargs)
return process_data_for_chroma(data)
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert complex data types to a format suitable for ChromaDB storage.
This function processes various data types from the input dictionary, converting UUIDs
to strings, and serializing dictionaries and lists into JSON strings prefixed with the
key type. Other supported types (strings, integers, floats, booleans, and None) are
stored as-is. Unsupported types are converted to their string representation.
Parameters:
-----------
- data: A dictionary containing data with various types including potentially
complex structures.
Returns:
--------
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
"""
processed_data: Dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UUID):
processed_data[key] = str(value)
elif isinstance(value, dict):
# Store dictionaries as JSON strings with special prefix
processed_data[f"{key}__dict"] = json.dumps(value)
elif isinstance(value, list):
# Store lists as JSON strings with special prefix
processed_data[f"{key}__list"] = json.dumps(value)
elif isinstance(value, (str, int, float, bool)) or value is None:
processed_data[key] = value
else:
processed_data[key] = str(value)
return processed_data
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Restore original data structure from ChromaDB storage format.
This function processes a data dictionary, identifies special keys that end with
'__dict' and '__list', and attempts to restore the original data structure by converting
JSON strings back into Python objects. If restoration fails, it logs an error and
retains the original value.
Parameters:
-----------
- data: A dictionary containing data stored in ChromaDB format, which includes
special keys indicating dictionary and list entries.
Returns:
--------
A dictionary representing the restored original data structure, with JSON strings
converted to their appropriate Python types.
"""
restored_data = {}
dict_keys = []
list_keys = []
# First, identify all special keys
for key in data.keys():
if key.endswith("__dict"):
dict_keys.append(key)
elif key.endswith("__list"):
list_keys.append(key)
else:
restored_data[key] = data[key]
# Process dictionary fields
for key in dict_keys:
original_key = key[:-6] # Remove '__dict' suffix
try:
restored_data[original_key] = json.loads(data[key])
except Exception as e:
logger.debug(f"Error restoring dictionary from JSON: {e}")
restored_data[key] = data[key]
# Process list fields
for key in list_keys:
original_key = key[:-6] # Remove '__list' suffix
try:
restored_data[original_key] = json.loads(data[key])
except Exception as e:
logger.debug(f"Error restoring list from JSON: {e}")
restored_data[key] = data[key]
return restored_data
class ChromaDBAdapter(VectorDBInterface):
"""
Manage a connection to the ChromaDB and facilitate operations for embedding, searching,
and managing collections of data points.
"""
name = "ChromaDB"
url: str | None
api_key: str | None
connection: AsyncHttpClient = None
def __init__(
self, url: Optional[str], api_key: Optional[str], embedding_engine: EmbeddingEngine
):
self.embedding_engine = embedding_engine
self.url = url
self.api_key = api_key
self.VECTOR_DB_LOCK = asyncio.Lock()
async def get_connection(self) -> AsyncHttpClient:
"""
Establish and return a connection to the ChromaDB if one doesn't already exist.
Returns:
--------
- AsyncHttpClient: Returns an instance of AsyncHttpClient for interacting with
ChromaDB.
"""
if self.connection is None:
settings = Settings(
chroma_client_auth_provider="token", chroma_client_auth_credentials=self.api_key
)
self.connection = await AsyncHttpClient(host=self.url, settings=settings)
return self.connection
async def embed_data(self, data: list[str]) -> list[list[float]]:
"""
Embed a list of text data into vector representations.
Parameters:
-----------
- data (list[str]): A list of strings to be embedded.
Returns:
--------
- list[list[float]]: Returns a list of lists containing the embedded vector
representations.
"""
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
"""
Check if a collection with the specified name exists in the ChromaDB.
Parameters:
-----------
- collection_name (str): The name of the collection to check for.
Returns:
--------
- bool: Returns True if the collection exists, otherwise False.
"""
collections = await self.get_collection_names()
return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
"""
Create a new collection in ChromaDB if it does not already exist.
Parameters:
-----------
- collection_name (str): The name of the collection to create.
- payload_schema: The schema for the payload; can be None. (default None)
"""
async with self.VECTOR_DB_LOCK:
client = await self.get_connection()
if not await self.has_collection(collection_name):
await client.create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
async def get_collection(self, collection_name: str) -> AsyncHttpClient:
"""
Retrieve a collection by its name from ChromaDB.
Parameters:
-----------
- collection_name (str): The name of the collection to retrieve.
Returns:
--------
- AsyncHttpClient: Returns an AsyncHttpClient representing the requested collection.
"""
if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
client = await self.get_connection()
return await client.get_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
"""
Create and upsert data points into the specified collection in ChromaDB.
Parameters:
-----------
- collection_name (str): The name of the collection where data points will be added.
- data_points (list[DataPoint]): A list of DataPoint instances to be added to the
collection.
"""
await self.create_collection(collection_name)
collection = await self.get_collection(collection_name)
texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
embeddings = await self.embed_data(texts)
ids = [str(data_point.id) for data_point in data_points]
metadatas = []
for data_point in data_points:
metadata = get_own_properties(data_point)
metadatas.append(process_data_for_chroma(metadata))
await collection.upsert(
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
)
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
"""
Create a vector index as a ChromaDB collection based on provided names.
Parameters:
-----------
- index_name (str): The base name for the vector index.
- index_property_name (str): The property name associated with the index.
"""
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
) -> None:
"""
Index the provided data points based on the specified index property in ChromaDB.
Parameters:
-----------
- index_name (str): The index name used for the data points.
- index_property_name (str): The property name to index data points by.
- data_points (list[DataPoint]): A list of DataPoint instances to be indexed.
"""
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=data_point.id,
text=getattr(
data_point,
data_point.metadata["index_fields"][0]
),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
],
)
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
"""
Retrieve data points by their IDs from a ChromaDB collection.
Parameters:
-----------
- collection_name (str): The name of the collection from which to retrieve data
points.
- data_point_ids (list[str]): A list of data point IDs to retrieve.
Returns:
--------
Returns a list of ScoredResult instances containing the retrieved data points and
their metadata.
"""
collection = await self.get_collection(collection_name)
results = await collection.get(ids=data_point_ids, include=["metadatas"])
return [
ScoredResult(
id=parse_id(id),
payload=restore_data_from_chroma(metadata),
score=0,
)
for id, metadata in zip(results["ids"], results["metadatas"])
]
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
) -> List[ScoredResult]:
"""
Search for items in a collection using either a text or a vector query.
Parameters:
-----------
- collection_name (str): The name of the collection in which to perform the search.
- query_text (str): Text query used for search; can be None if query_vector is
provided. (default None)
- query_vector (List[float]): Vector query used for search; can be None if
query_text is provided. (default None)
- limit (int): The maximum number of results to return; defaults to 15. (default 15)
- with_vector (bool): Whether to include vectors in the results. (default False)
- normalized (bool): Whether to normalize the distance scores before returning them.
(default True)
Returns:
--------
Returns a list of ScoredResult instances representing the search results.
"""
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
try:
collection = await self.get_collection(collection_name)
if limit == 0:
limit = await collection.count()
results = await collection.query(
query_embeddings=[query_vector],
include=["metadatas", "distances", "embeddings"]
if with_vector
else ["metadatas", "distances"],
n_results=limit,
)
vector_list = []
for i, (id, metadata, distance) in enumerate(
zip(results["ids"][0], results["metadatas"][0], results["distances"][0])
):
item = {
"id": parse_id(id),
"payload": restore_data_from_chroma(metadata),
"_distance": distance,
}
if with_vector and "embeddings" in results:
item["vector"] = results["embeddings"][0][i]
vector_list.append(item)
# Normalize vector distance
normalized_values = normalize_distances(vector_list)
for i in range(len(normalized_values)):
vector_list[i]["score"] = normalized_values[i]
# Create and return ScoredResult objects
return [
ScoredResult(
id=row["id"],
payload=row["payload"],
score=row["score"],
vector=row.get("vector") if with_vector else None,
)
for row in vector_list
]
except Exception as e:
logger.error(f"Error in search: {str(e)}")
return []
async def batch_search(
self,
collection_name: str,
query_texts: List[str],
limit: int = 5,
with_vectors: bool = False,
) -> List[List[ScoredResult]]:
"""
Perform multiple searches in a single request for efficiency, returning results for each
query.
Parameters:
-----------
- collection_name (str): The name of the collection in which to perform the
searches.
- query_texts (List[str]): A list of text queries to be searched.
- limit (int): The maximum number of results to return for each query; defaults to
5. (default 5)
- with_vectors (bool): Whether to include vectors in the results for each query.
(default False)
Returns:
--------
Returns a list of lists of ScoredResult instances for each query's results.
"""
query_vectors = await self.embed_data(query_texts)
collection = await self.get_collection(collection_name)
results = await collection.query(
query_embeddings=query_vectors,
include=["metadatas", "distances", "embeddings"]
if with_vectors
else ["metadatas", "distances"],
n_results=limit,
)
all_results = []
for i in range(len(query_texts)):
vector_list = []
for j, (id, metadata, distance) in enumerate(
zip(results["ids"][i], results["metadatas"][i], results["distances"][i])
):
item = {
"id": parse_id(id),
"payload": restore_data_from_chroma(metadata),
"_distance": distance,
}
if with_vectors and "embeddings" in results:
item["vector"] = results["embeddings"][i][j]
vector_list.append(item)
normalized_values = normalize_distances(vector_list)
query_results = []
for j, item in enumerate(vector_list):
result = ScoredResult(
id=item["id"],
payload=item["payload"],
score=normalized_values[j],
)
if with_vectors and "embeddings" in results:
result.vector = item.get("vector")
query_results.append(result)
all_results.append(query_results)
return all_results
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
"""
Remove data points from a collection based on their IDs.
Parameters:
-----------
- collection_name (str): The name of the collection from which to delete data
points.
- data_point_ids (list[str]): A list of data point IDs to remove from the
collection.
Returns:
--------
Returns True upon successful deletion of the data points.
"""
collection = await self.get_collection(collection_name)
await collection.delete(ids=data_point_ids)
return True
async def prune(self) -> bool:
"""
Delete all collections in the ChromaDB database.
Returns:
--------
Returns True upon successful deletion of all collections.
"""
client = await self.get_connection()
collection_names = await self.get_collection_names()
for collection_name in collection_names:
await client.delete_collection(collection_name)
return True
async def get_collection_names(self) -> List[str]:
"""
Retrieve the names of all collections in the ChromaDB database.
Returns:
--------
Returns a list of collection names.
"""
client = await self.get_connection()
collections = await client.list_collections()
return [
collection.name if hasattr(collection, "name") else collection["name"]
for collection in collections
]