mypy fix: Fix ChromaDBAdapter mypy errors
This commit is contained in:
parent
2992a38acd
commit
c8dbe0ee38
2 changed files with 32 additions and 25 deletions
|
|
@@ -1,12 +1,13 @@
|
|||
import json
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Any
|
||||
from chromadb import AsyncHttpClient, Settings
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.storage.utils import get_own_properties
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.models.DataPoint import MetaData
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
|
|
@@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
|
|||
|
||||
text: str
|
||||
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
|
||||
def model_dump(self):
|
||||
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize the instance data for storage.
|
||||
|
||||
|
|
@@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
|
|||
|
||||
A dictionary containing serialized data processed for ChromaDB storage.
|
||||
"""
|
||||
data = super().model_dump()
|
||||
data = super().model_dump(**kwargs)
|
||||
return process_data_for_chroma(data)
|
||||
|
||||
|
||||
def process_data_for_chroma(data):
|
||||
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert complex data types to a format suitable for ChromaDB storage.
|
||||
|
||||
|
|
@@ -73,7 +74,7 @@ def process_data_for_chroma(data):
|
|||
|
||||
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
|
||||
"""
|
||||
processed_data = {}
|
||||
processed_data: Dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, UUID):
|
||||
processed_data[key] = str(value)
|
||||
|
|
@@ -90,7 +91,7 @@ def process_data_for_chroma(data):
|
|||
return processed_data
|
||||
|
||||
|
||||
def restore_data_from_chroma(data):
|
||||
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Restore original data structure from ChromaDB storage format.
|
||||
|
||||
|
|
@@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
"""
|
||||
|
||||
name = "ChromaDB"
|
||||
url: str
|
||||
api_key: str
|
||||
url: str | None
|
||||
api_key: str | None
|
||||
connection: AsyncHttpClient = None
|
||||
|
||||
def __init__(
|
||||
|
|
@@ -216,7 +217,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
collections = await self.get_collection_names()
|
||||
return collection_name in collections
|
||||
|
||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
||||
"""
|
||||
Create a new collection in ChromaDB if it does not already exist.
|
||||
|
||||
|
|
@@ -254,7 +255,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
client = await self.get_connection()
|
||||
return await client.get_collection(collection_name)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
|
||||
"""
|
||||
Create and upsert data points into the specified collection in ChromaDB.
|
||||
|
||||
|
|
@@ -282,7 +283,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
|
||||
)
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
|
||||
"""
|
||||
Create a vector index as a ChromaDB collection based on provided names.
|
||||
|
||||
|
|
@@ -296,7 +297,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
|
||||
async def index_data_points(
|
||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Index the provided data points based on the specified index property in ChromaDB.
|
||||
|
||||
|
|
@@ -312,13 +313,17 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
[
|
||||
IndexSchema(
|
||||
id=data_point.id,
|
||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||
text=getattr(
|
||||
data_point,
|
||||
data_point.metadata["index_fields"][0]
|
||||
),
|
||||
)
|
||||
for data_point in data_points
|
||||
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
|
||||
],
|
||||
)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
|
||||
"""
|
||||
Retrieve data points by their IDs from a ChromaDB collection.
|
||||
|
||||
|
|
@@ -350,12 +355,12 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
) -> List[ScoredResult]:
|
||||
"""
|
||||
Search for items in a collection using either a text or a vector query.
|
||||
|
||||
|
|
@@ -437,7 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
query_texts: List[str],
|
||||
limit: int = 5,
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
) -> List[List[ScoredResult]]:
|
||||
"""
|
||||
Perform multiple searches in a single request for efficiency, returning results for each
|
||||
query.
|
||||
|
|
@@ -507,7 +512,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
|
||||
return all_results
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
|
||||
"""
|
||||
Remove data points from a collection based on their IDs.
|
||||
|
||||
|
|
@@ -528,7 +533,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
await collection.delete(ids=data_point_ids)
|
||||
return True
|
||||
|
||||
async def prune(self):
|
||||
async def prune(self) -> bool:
|
||||
"""
|
||||
Delete all collections in the ChromaDB database.
|
||||
|
||||
|
|
@@ -538,12 +543,12 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
Returns True upon successful deletion of all collections.
|
||||
"""
|
||||
client = await self.get_connection()
|
||||
collections = await client.list_collections()
|
||||
for collection_name in collections:
|
||||
collection_names = await self.get_collection_names()
|
||||
for collection_name in collection_names:
|
||||
await client.delete_collection(collection_name)
|
||||
return True
|
||||
|
||||
async def get_collection_names(self):
|
||||
async def get_collection_names(self) -> List[str]:
|
||||
"""
|
||||
Retrieve the names of all collections in the ChromaDB database.
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
|
|||
better outcome.
|
||||
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
||||
key-value pairs in a dictionary.
|
||||
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
score: float # Lower score is better
|
||||
payload: Dict[str, Any]
|
||||
vector: Optional[List[float]] = None
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue