cognee/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py
Vasilije bb7eaa017b
feat: Group DataPoints into NodeSets (#680)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
2025-04-19 20:21:04 +02:00

365 lines
13 KiB
Python

import json
from uuid import UUID
from typing import List, Optional
from chromadb import AsyncHttpClient, Settings
from cognee.exceptions import InvalidValueError
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
# Module-level logger shared by the conversion helpers and the adapter in this file.
logger = get_logger("ChromaDBAdapter")
class IndexSchema(DataPoint):
    """Minimal data point stored in an index collection: an id plus the indexed text."""

    text: str

    # Declares ``text`` as the index field, following the ``metadata["index_fields"]``
    # convention that ``index_data_points`` reads from source data points.
    metadata: dict = {"index_fields": ["text"]}

    def model_dump(self, **kwargs):
        """Dump the model and flatten the result into ChromaDB-compatible metadata.

        Keyword arguments are forwarded unchanged to the parent ``model_dump`` so
        this override stays call-compatible with the method it replaces; calling
        it without arguments behaves exactly as before.

        Returns:
            dict: The dumped fields after ``process_data_for_chroma`` has converted
            UUIDs to strings and JSON-encoded nested dicts/lists.
        """
        return process_data_for_chroma(super().model_dump(**kwargs))
def process_data_for_chroma(data):
    """Convert complex data types to a format suitable for ChromaDB storage.

    The conversion is flat, one entry per input key:

    * ``UUID`` values become their string form.
    * ``dict`` values are JSON-encoded and stored under ``"<key>__dict"``.
    * ``list`` values are JSON-encoded and stored under ``"<key>__list"``.
    * ``str``/``int``/``float``/``bool``/``None`` values are stored unchanged.
    * Anything else is stored as ``str(value)``.

    Values nested inside a dict or list that JSON cannot encode natively (for
    example a UUID) are stringified as well, mirroring the top-level fallback,
    instead of aborting the whole conversion with ``TypeError``.

    Args:
        data: Mapping of field names to arbitrary values.

    Returns:
        dict: A new mapping; ``restore_data_from_chroma`` reverses the
        ``__dict``/``__list`` encoding.
    """
    processed_data = {}
    for key, value in data.items():
        if isinstance(value, UUID):
            processed_data[key] = str(value)
        elif isinstance(value, dict):
            # Store dictionaries as JSON strings under a suffixed key.
            processed_data[f"{key}__dict"] = json.dumps(value, default=str)
        elif isinstance(value, list):
            # Store lists as JSON strings under a suffixed key.
            processed_data[f"{key}__list"] = json.dumps(value, default=str)
        elif value is None or isinstance(value, (str, int, float, bool)):
            processed_data[key] = value
        else:
            processed_data[key] = str(value)
    return processed_data
def restore_data_from_chroma(data):
    """Restore original data structure from ChromaDB storage format.

    Keys ending in ``__dict`` or ``__list`` are JSON-decoded and stored under the
    key with the suffix removed; every other entry is copied unchanged.  If a
    suffixed value cannot be decoded, the raw value is kept under its suffixed
    key and the problem is logged at debug level.

    Args:
        data: Metadata mapping as stored in ChromaDB, or ``None``/empty when a
            record carries no metadata.

    Returns:
        dict: The restored mapping (empty when ``data`` is ``None`` or empty).
    """
    if not data:
        # A record stored without metadata has nothing to restore.
        return {}

    encoded_suffixes = ("__dict", "__list")

    # Plain entries first, so decoded entries override a same-named plain key.
    restored_data = {
        key: value for key, value in data.items() if not key.endswith(encoded_suffixes)
    }

    # Dictionaries are decoded before lists, so a "__list" entry wins if both
    # suffixes exist for the same base key.
    for suffix, label in (("__dict", "dictionary"), ("__list", "list")):
        for key, raw_value in data.items():
            if not key.endswith(suffix):
                continue
            try:
                restored_data[key.removesuffix(suffix)] = json.loads(raw_value)
            except (TypeError, ValueError) as e:
                # json.loads raises TypeError for non-string input and
                # JSONDecodeError (a ValueError) for malformed JSON.
                logger.debug(f"Error restoring {label} from JSON: {e}")
                restored_data[key] = raw_value

    return restored_data
class ChromaDBAdapter(VectorDBInterface):
    """Vector database adapter that talks to a ChromaDB server over HTTP.

    For every data point the adapter stores its embedding, the embedded text (as
    the ChromaDB document) and the data point's own properties flattened with
    ``process_data_for_chroma`` (as metadata).  The client connection is created
    lazily on first use and cached on the instance.
    """

    name = "ChromaDB"
    url: Optional[str]
    api_key: Optional[str]
    connection: AsyncHttpClient = None

    def __init__(
        self, url: Optional[str], api_key: Optional[str], embedding_engine: EmbeddingEngine
    ):
        """Store connection parameters; no network traffic happens here.

        Args:
            url: Passed as ``host`` to the ChromaDB HTTP client.
            api_key: Passed as the client's auth credentials.
            embedding_engine: Engine used to turn text into vectors.
        """
        self.embedding_engine = embedding_engine
        self.url = url
        self.api_key = api_key

    async def get_connection(self) -> AsyncHttpClient:
        """Return the cached ChromaDB client, creating it on first call."""
        if self.connection is None:
            settings = Settings(
                chroma_client_auth_provider="token",
                chroma_client_auth_credentials=self.api_key,
            )
            self.connection = await AsyncHttpClient(host=self.url, settings=settings)

        return self.connection

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """Embed a list of texts with the configured embedding engine."""
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """Return whether ``collection_name`` is among the server's collections."""
        client = await self.get_connection()
        collections = await client.list_collections()
        # In ChromaDB v0.6.0, list_collections returns collection names directly
        return collection_name in collections

    async def create_collection(self, collection_name: str, payload_schema=None):
        """Create ``collection_name`` with cosine distance unless it already exists.

        ``payload_schema`` is accepted but not used by this adapter.
        """
        client = await self.get_connection()

        if not await self.has_collection(collection_name):
            await client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
        """Embed and upsert ``data_points`` into ``collection_name``.

        The collection is created first if it does not exist.  An empty
        ``data_points`` list only ensures the collection exists and then returns.
        """
        client = await self.get_connection()

        if not await self.has_collection(collection_name):
            await self.create_collection(collection_name)

        if not data_points:
            # Nothing to embed or store; skip the embedding call and the upsert
            # (ChromaDB is expected to reject an empty id list — see its API docs).
            return

        collection = await client.get_collection(collection_name)

        texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
        embeddings = await self.embed_data(texts)
        ids = [str(data_point.id) for data_point in data_points]
        metadatas = [
            process_data_for_chroma(get_own_properties(data_point)) for data_point in data_points
        ]

        await collection.upsert(
            ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
        )

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create a vector index as a ChromaDB collection."""
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        """Index data points using the specified index property.

        Each data point is reduced to an ``IndexSchema`` holding its id and the
        value of its first declared index field
        (``data_point.metadata["index_fields"][0]``), then stored in the
        ``"{index_name}_{index_property_name}"`` collection.
        """
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=data_point.id,
                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                )
                for data_point in data_points
            ],
        )

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        """Retrieve data points by their IDs from a collection.

        Returns:
            list[ScoredResult]: One result per returned record, with the restored
            metadata as payload and a fixed score of 0.
        """
        client = await self.get_connection()
        collection = await client.get_collection(collection_name)
        results = await collection.get(ids=data_point_ids, include=["metadatas"])

        return [
            ScoredResult(
                id=parse_id(point_id),
                payload=restore_data_from_chroma(metadata),
                score=0,
            )
            for point_id, metadata in zip(results["ids"], results["metadatas"])
        ]

    @staticmethod
    def _rows_from_query(results, query_index: int, with_vectors: bool) -> list[dict]:
        """Flatten the hits of one query of a ChromaDB query result into row dicts.

        Each row has ``id``, ``payload`` and ``_distance`` keys, plus ``vector``
        when ``with_vectors`` is set and the result contains embeddings.
        """
        hits = zip(
            results["ids"][query_index],
            results["metadatas"][query_index],
            results["distances"][query_index],
        )
        include_vectors = with_vectors and "embeddings" in results

        rows = []
        for position, (point_id, metadata, distance) in enumerate(hits):
            row = {
                "id": parse_id(point_id),
                "payload": restore_data_from_chroma(metadata),
                "_distance": distance,
            }
            if include_vectors:
                row["vector"] = results["embeddings"][query_index][position]
            rows.append(row)
        return rows

    async def get_distance_from_collection_elements(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
    ):
        """Calculate distance between query and all elements in a collection.

        Args:
            collection_name: Collection to compare against.
            query_text: Text to embed when no ``query_vector`` is given.
            query_vector: Pre-computed query embedding.

        Returns:
            list[ScoredResult]: One result per collection element, scored with
            the output of ``normalize_distances``.  An empty list is returned for
            an empty collection or when the lookup fails (the failure is logged).

        Raises:
            InvalidValueError: If both ``query_text`` and ``query_vector`` are None.
        """
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        client = await self.get_connection()

        try:
            collection = await client.get_collection(collection_name)
            collection_count = await collection.count()

            if collection_count == 0:
                # Nothing to compare against; avoid issuing a query for 0 results.
                return []

            results = await collection.query(
                query_embeddings=[query_vector],
                include=["metadatas", "distances"],
                n_results=collection_count,
            )

            rows = self._rows_from_query(results, 0, with_vectors=False)
            if not rows:
                return []

            normalized_values = normalize_distances(rows)

            return [
                ScoredResult(id=row["id"], payload=row["payload"], score=score)
                for row, score in zip(rows, normalized_values)
            ]
        except Exception as e:
            # Best-effort lookup: report the failure instead of hiding it, then
            # fall back to "no results" as before.
            logger.error(f"Error in get_distance_from_collection_elements: {str(e)}")
            return []

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = 5,
        with_vector: bool = False,
        normalized: bool = True,
    ):
        """Search for similar items in a collection using text or vector query.

        Args:
            collection_name: Collection to search.
            query_text: Text to embed when no ``query_vector`` is given.
            query_vector: Pre-computed query embedding.
            limit: Maximum number of hits requested from ChromaDB.
            with_vector: Also request and return each hit's embedding.
            normalized: Currently unused; scores always come from
                ``normalize_distances``.

        Returns:
            list[ScoredResult]: The hits, or an empty list when nothing matched
            or the search failed (the failure is logged).

        Raises:
            InvalidValueError: If both ``query_text`` and ``query_vector`` are None.
        """
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        try:
            client = await self.get_connection()
            collection = await client.get_collection(collection_name)

            include = ["metadatas", "distances"]
            if with_vector:
                include.append("embeddings")

            results = await collection.query(
                query_embeddings=[query_vector],
                include=include,
                n_results=limit,
            )

            rows = self._rows_from_query(results, 0, with_vector)
            if not rows:
                return []

            # Normalize vector distance
            normalized_values = normalize_distances(rows)

            return [
                ScoredResult(
                    id=row["id"],
                    payload=row["payload"],
                    score=score,
                    vector=row.get("vector") if with_vector else None,
                )
                for row, score in zip(rows, normalized_values)
            ]
        except Exception as e:
            logger.error(f"Error in search: {str(e)}")
            return []

    async def batch_search(
        self,
        collection_name: str,
        query_texts: List[str],
        limit: int = 5,
        with_vectors: bool = False,
    ):
        """Perform multiple searches in a single request for efficiency.

        Returns:
            list[list[ScoredResult]]: One list of hits per entry of
            ``query_texts``, in the same order.  Empty ``query_texts`` yields an
            empty list without contacting the embedding engine or the server.
        """
        if not query_texts:
            return []

        query_vectors = await self.embed_data(query_texts)

        client = await self.get_connection()
        collection = await client.get_collection(collection_name)

        include = ["metadatas", "distances"]
        if with_vectors:
            include.append("embeddings")

        results = await collection.query(
            query_embeddings=query_vectors,
            include=include,
            n_results=limit,
        )

        attach_vectors = with_vectors and "embeddings" in results

        all_results = []
        for query_index, _ in enumerate(query_texts):
            rows = self._rows_from_query(results, query_index, with_vectors)

            # A query without hits has nothing to normalize; keep its slot as an
            # empty list rather than handing the helper an empty input.
            normalized_values = normalize_distances(rows) if rows else []

            query_results = []
            for row, score in zip(rows, normalized_values):
                result = ScoredResult(
                    id=row["id"],
                    payload=row["payload"],
                    score=score,
                )

                if attach_vectors:
                    result.vector = row.get("vector")

                query_results.append(result)

            all_results.append(query_results)

        return all_results

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """Remove data points from a collection by their IDs. Returns True."""
        client = await self.get_connection()
        collection = await client.get_collection(collection_name)
        await collection.delete(ids=data_point_ids)
        return True

    async def prune(self):
        """Delete all collections in the ChromaDB database. Returns True."""
        client = await self.get_connection()
        collections = await client.list_collections()
        for collection_name in collections:
            await client.delete_collection(collection_name)
        return True

    async def get_collection_names(self):
        """Get a list of all collection names in the database."""
        client = await self.get_connection()
        return await client.list_collections()