From c8dbe0ee38ca5236266a037bf9c0686ead6e043d Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Wed, 3 Sep 2025 18:10:01 +0100 Subject: [PATCH] mypy fix: Fix ChromaDBAdapter mypy errors --- .../vector/chromadb/ChromaDBAdapter.py | 53 ++++++++++--------- .../databases/vector/models/ScoredResult.py | 4 +- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index aec33abe2..e89ac3193 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -1,12 +1,13 @@ import json import asyncio from uuid import UUID -from typing import List, Optional +from typing import List, Optional, Dict, Any from chromadb import AsyncHttpClient, Settings from cognee.shared.logging_utils import get_logger from cognee.modules.storage.utils import get_own_properties from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.DataPoint import MetaData from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult @@ -35,9 +36,9 @@ class IndexSchema(DataPoint): text: str - metadata: dict = {"index_fields": ["text"]} + metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"} - def model_dump(self): + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: """ Serialize the instance data for storage. @@ -49,11 +50,11 @@ class IndexSchema(DataPoint): A dictionary containing serialized data processed for ChromaDB storage. """ - data = super().model_dump() + data = super().model_dump(**kwargs) return process_data_for_chroma(data) -def process_data_for_chroma(data): +def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]: """ Convert complex data types to a format suitable for ChromaDB storage. @@ -73,7 +74,7 @@ def process_data_for_chroma(data): A dictionary containing the processed key-value pairs suitable for ChromaDB storage. """ - processed_data = {} + processed_data: Dict[str, Any] = {} for key, value in data.items(): if isinstance(value, UUID): processed_data[key] = str(value) @@ -90,7 +91,7 @@ def process_data_for_chroma(data): return processed_data -def restore_data_from_chroma(data): +def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]: """ Restore original data structure from ChromaDB storage format. @@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface): """ name = "ChromaDB" - url: str - api_key: str + url: str | None + api_key: str | None connection: AsyncHttpClient = None def __init__( @@ -216,7 +217,7 @@ class ChromaDBAdapter(VectorDBInterface): collections = await self.get_collection_names() return collection_name in collections - async def create_collection(self, collection_name: str, payload_schema=None): + async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None: """ Create a new collection in ChromaDB if it does not already exist. @@ -254,7 +255,7 @@ class ChromaDBAdapter(VectorDBInterface): client = await self.get_connection() return await client.get_collection(collection_name) - async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): + async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None: """ Create and upsert data points into the specified collection in ChromaDB. @@ -282,7 +283,7 @@ class ChromaDBAdapter(VectorDBInterface): ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts ) - async def create_vector_index(self, index_name: str, index_property_name: str): + async def create_vector_index(self, index_name: str, index_property_name: str) -> None: """ Create a vector index as a ChromaDB collection based on provided names. @@ -296,7 +297,7 @@ class ChromaDBAdapter(VectorDBInterface): async def index_data_points( self, index_name: str, index_property_name: str, data_points: list[DataPoint] - ): + ) -> None: """ Index the provided data points based on the specified index property in ChromaDB. @@ -312,13 +313,17 @@ class ChromaDBAdapter(VectorDBInterface): [ IndexSchema( id=data_point.id, - text=getattr(data_point, data_point.metadata["index_fields"][0]), + text=getattr( + data_point, + data_point.metadata["index_fields"][0] + ), ) for data_point in data_points + if data_point.metadata and len(data_point.metadata["index_fields"]) > 0 ], ) - async def retrieve(self, collection_name: str, data_point_ids: list[str]): + async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]: """ Retrieve data points by their IDs from a ChromaDB collection. @@ -350,12 +355,12 @@ class ChromaDBAdapter(VectorDBInterface): async def search( self, collection_name: str, - query_text: str = None, - query_vector: List[float] = None, + query_text: Optional[str] = None, + query_vector: Optional[List[float]] = None, limit: int = 15, with_vector: bool = False, normalized: bool = True, - ): + ) -> List[ScoredResult]: """ Search for items in a collection using either a text or a vector query. @@ -437,7 +442,7 @@ class ChromaDBAdapter(VectorDBInterface): query_texts: List[str], limit: int = 5, with_vectors: bool = False, - ): + ) -> List[List[ScoredResult]]: """ Perform multiple searches in a single request for efficiency, returning results for each query. @@ -507,7 +512,7 @@ class ChromaDBAdapter(VectorDBInterface): return all_results - async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool: """ Remove data points from a collection based on their IDs. @@ -528,7 +533,7 @@ class ChromaDBAdapter(VectorDBInterface): await collection.delete(ids=data_point_ids) return True - async def prune(self): + async def prune(self) -> bool: """ Delete all collections in the ChromaDB database. @@ -538,12 +543,12 @@ class ChromaDBAdapter(VectorDBInterface): Returns True upon successful deletion of all collections. """ client = await self.get_connection() - collections = await client.list_collections() - for collection_name in collections: + collection_names = await self.get_collection_names() + for collection_name in collection_names: await client.delete_collection(collection_name) return True - async def get_collection_names(self): + async def get_collection_names(self) -> List[str]: """ Retrieve the names of all collections in the ChromaDB database. diff --git a/cognee/infrastructure/databases/vector/models/ScoredResult.py b/cognee/infrastructure/databases/vector/models/ScoredResult.py index 0a8cc9888..690547fd2 100644 --- a/cognee/infrastructure/databases/vector/models/ScoredResult.py +++ b/cognee/infrastructure/databases/vector/models/ScoredResult.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Optional from uuid import UUID from pydantic import BaseModel @@ -14,8 +14,10 @@ class ScoredResult(BaseModel): better outcome. - payload (Dict[str, Any]): Additional information related to the score, stored as key-value pairs in a dictionary. + - vector (Optional[List[float]]): Optional vector embedding associated with the result. """ id: UUID score: float # Lower score is better payload: Dict[str, Any] + vector: Optional[List[float]] = None