mypy fix: Fix ChromaDBAdapter mypy errors

This commit is contained in:
Daulet Amirkhanov 2025-09-03 18:10:01 +01:00
parent 26f5ab4f0f
commit 4ae41fede3
2 changed files with 32 additions and 25 deletions

View file

@ -1,12 +1,13 @@
import json import json
import asyncio import asyncio
from uuid import UUID from uuid import UUID
from typing import List, Optional from typing import List, Optional, Dict, Any
from chromadb import AsyncHttpClient, Settings from chromadb import AsyncHttpClient, Settings
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
text: str text: str
metadata: dict = {"index_fields": ["text"]} metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
def model_dump(self): def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
""" """
Serialize the instance data for storage. Serialize the instance data for storage.
@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
A dictionary containing serialized data processed for ChromaDB storage. A dictionary containing serialized data processed for ChromaDB storage.
""" """
data = super().model_dump() data = super().model_dump(**kwargs)
return process_data_for_chroma(data) return process_data_for_chroma(data)
def process_data_for_chroma(data): def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
""" """
Convert complex data types to a format suitable for ChromaDB storage. Convert complex data types to a format suitable for ChromaDB storage.
@ -73,7 +74,7 @@ def process_data_for_chroma(data):
A dictionary containing the processed key-value pairs suitable for ChromaDB storage. A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
""" """
processed_data = {} processed_data: Dict[str, Any] = {}
for key, value in data.items(): for key, value in data.items():
if isinstance(value, UUID): if isinstance(value, UUID):
processed_data[key] = str(value) processed_data[key] = str(value)
@ -90,7 +91,7 @@ def process_data_for_chroma(data):
return processed_data return processed_data
def restore_data_from_chroma(data): def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
""" """
Restore original data structure from ChromaDB storage format. Restore original data structure from ChromaDB storage format.
@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
""" """
name = "ChromaDB" name = "ChromaDB"
url: str url: str | None
api_key: str api_key: str | None
connection: AsyncHttpClient = None connection: AsyncHttpClient = None
def __init__( def __init__(
@ -216,7 +217,7 @@ class ChromaDBAdapter(VectorDBInterface):
collections = await self.get_collection_names() collections = await self.get_collection_names()
return collection_name in collections return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema=None): async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
""" """
Create a new collection in ChromaDB if it does not already exist. Create a new collection in ChromaDB if it does not already exist.
@ -254,7 +255,7 @@ class ChromaDBAdapter(VectorDBInterface):
client = await self.get_connection() client = await self.get_connection()
return await client.get_collection(collection_name) return await client.get_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
""" """
Create and upsert data points into the specified collection in ChromaDB. Create and upsert data points into the specified collection in ChromaDB.
@ -282,7 +283,7 @@ class ChromaDBAdapter(VectorDBInterface):
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
) )
async def create_vector_index(self, index_name: str, index_property_name: str): async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
""" """
Create a vector index as a ChromaDB collection based on provided names. Create a vector index as a ChromaDB collection based on provided names.
@ -296,7 +297,7 @@ class ChromaDBAdapter(VectorDBInterface):
async def index_data_points( async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint] self, index_name: str, index_property_name: str, data_points: list[DataPoint]
): ) -> None:
""" """
Index the provided data points based on the specified index property in ChromaDB. Index the provided data points based on the specified index property in ChromaDB.
@ -312,13 +313,17 @@ class ChromaDBAdapter(VectorDBInterface):
[ [
IndexSchema( IndexSchema(
id=data_point.id, id=data_point.id,
text=getattr(data_point, data_point.metadata["index_fields"][0]), text=getattr(
data_point,
data_point.metadata["index_fields"][0]
),
) )
for data_point in data_points for data_point in data_points
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
], ],
) )
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
""" """
Retrieve data points by their IDs from a ChromaDB collection. Retrieve data points by their IDs from a ChromaDB collection.
@ -350,12 +355,12 @@ class ChromaDBAdapter(VectorDBInterface):
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: Optional[str] = None,
query_vector: List[float] = None, query_vector: Optional[List[float]] = None,
limit: int = 15, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ) -> List[ScoredResult]:
""" """
Search for items in a collection using either a text or a vector query. Search for items in a collection using either a text or a vector query.
@ -437,7 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str], query_texts: List[str],
limit: int = 5, limit: int = 5,
with_vectors: bool = False, with_vectors: bool = False,
): ) -> List[List[ScoredResult]]:
""" """
Perform multiple searches in a single request for efficiency, returning results for each Perform multiple searches in a single request for efficiency, returning results for each
query. query.
@ -507,7 +512,7 @@ class ChromaDBAdapter(VectorDBInterface):
return all_results return all_results
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
""" """
Remove data points from a collection based on their IDs. Remove data points from a collection based on their IDs.
@ -528,7 +533,7 @@ class ChromaDBAdapter(VectorDBInterface):
await collection.delete(ids=data_point_ids) await collection.delete(ids=data_point_ids)
return True return True
async def prune(self): async def prune(self) -> bool:
""" """
Delete all collections in the ChromaDB database. Delete all collections in the ChromaDB database.
@ -538,12 +543,12 @@ class ChromaDBAdapter(VectorDBInterface):
Returns True upon successful deletion of all collections. Returns True upon successful deletion of all collections.
""" """
client = await self.get_connection() client = await self.get_connection()
collections = await self.list_collections() collection_names = await self.get_collection_names()
for collection_name in collections: for collection_name in collection_names:
await client.delete_collection(collection_name) await client.delete_collection(collection_name)
return True return True
async def get_collection_names(self): async def get_collection_names(self) -> List[str]:
""" """
Retrieve the names of all collections in the ChromaDB database. Retrieve the names of all collections in the ChromaDB database.

View file

@ -1,4 +1,4 @@
from typing import Any, Dict from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
better outcome. better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as - payload (Dict[str, Any]): Additional information related to the score, stored as
key-value pairs in a dictionary. key-value pairs in a dictionary.
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
""" """
id: UUID id: UUID
score: float # Lower score is better score: float # Lower score is better
payload: Dict[str, Any] payload: Dict[str, Any]
vector: Optional[List[float]] = None