mypy fix: Fix ChromaDBAdapter mypy errors

This commit is contained in:
Daulet Amirkhanov 2025-09-03 18:10:01 +01:00
parent 2992a38acd
commit c8dbe0ee38
2 changed files with 32 additions and 25 deletions

View file

@ -1,12 +1,13 @@
import json
import asyncio
from uuid import UUID
from typing import List, Optional
from typing import List, Optional, Dict, Any
from chromadb import AsyncHttpClient, Settings
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
def model_dump(self):
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
"""
Serialize the instance data for storage.
@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
A dictionary containing serialized data processed for ChromaDB storage.
"""
data = super().model_dump()
data = super().model_dump(**kwargs)
return process_data_for_chroma(data)
def process_data_for_chroma(data):
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert complex data types to a format suitable for ChromaDB storage.
@ -73,7 +74,7 @@ def process_data_for_chroma(data):
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
"""
processed_data = {}
processed_data: Dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UUID):
processed_data[key] = str(value)
@ -90,7 +91,7 @@ def process_data_for_chroma(data):
return processed_data
def restore_data_from_chroma(data):
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Restore original data structure from ChromaDB storage format.
@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
"""
name = "ChromaDB"
url: str
api_key: str
url: str | None
api_key: str | None
connection: AsyncHttpClient = None
def __init__(
@ -216,7 +217,7 @@ class ChromaDBAdapter(VectorDBInterface):
collections = await self.get_collection_names()
return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema=None):
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
"""
Create a new collection in ChromaDB if it does not already exist.
@ -254,7 +255,7 @@ class ChromaDBAdapter(VectorDBInterface):
client = await self.get_connection()
return await client.get_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
"""
Create and upsert data points into the specified collection in ChromaDB.
@ -282,7 +283,7 @@ class ChromaDBAdapter(VectorDBInterface):
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
)
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
"""
Create a vector index as a ChromaDB collection based on provided names.
@ -296,7 +297,7 @@ class ChromaDBAdapter(VectorDBInterface):
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
) -> None:
"""
Index the provided data points based on the specified index property in ChromaDB.
@ -312,13 +313,17 @@ class ChromaDBAdapter(VectorDBInterface):
[
IndexSchema(
id=data_point.id,
text=getattr(data_point, data_point.metadata["index_fields"][0]),
text=getattr(
data_point,
data_point.metadata["index_fields"][0]
),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
],
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
"""
Retrieve data points by their IDs from a ChromaDB collection.
@ -350,12 +355,12 @@ class ChromaDBAdapter(VectorDBInterface):
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
) -> List[ScoredResult]:
"""
Search for items in a collection using either a text or a vector query.
@ -437,7 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: int = 5,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
"""
Perform multiple searches in a single request for efficiency, returning results for each
query.
@ -507,7 +512,7 @@ class ChromaDBAdapter(VectorDBInterface):
return all_results
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
"""
Remove data points from a collection based on their IDs.
@ -528,7 +533,7 @@ class ChromaDBAdapter(VectorDBInterface):
await collection.delete(ids=data_point_ids)
return True
async def prune(self):
async def prune(self) -> bool:
"""
Delete all collections in the ChromaDB database.
@ -538,12 +543,12 @@ class ChromaDBAdapter(VectorDBInterface):
Returns True upon successful deletion of all collections.
"""
client = await self.get_connection()
collections = await client.list_collections()
for collection_name in collections:
collection_names = await self.get_collection_names()
for collection_name in collection_names:
await client.delete_collection(collection_name)
return True
async def get_collection_names(self):
async def get_collection_names(self) -> List[str]:
"""
Retrieve the names of all collections in the ChromaDB database.

View file

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel
@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as
key-value pairs in a dictionary.
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
"""
id: UUID
score: float # Lower score is better
payload: Dict[str, Any]
vector: Optional[List[float]] = None