mypy fix: Fix ChromaDBAdapter mypy errors
This commit is contained in:
parent
26f5ab4f0f
commit
4ae41fede3
2 changed files with 32 additions and 25 deletions
|
|
@@ -1,12 +1,13 @@
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Dict, Any
|
||||||
from chromadb import AsyncHttpClient, Settings
|
from chromadb import AsyncHttpClient, Settings
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.storage.utils import get_own_properties
|
from cognee.modules.storage.utils import get_own_properties
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import MetaData
|
||||||
from cognee.infrastructure.engine.utils import parse_id
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||||
|
|
@@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
|
||||||
|
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||||
|
|
||||||
def model_dump(self):
|
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Serialize the instance data for storage.
|
Serialize the instance data for storage.
|
||||||
|
|
||||||
|
|
@@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
|
||||||
|
|
||||||
A dictionary containing serialized data processed for ChromaDB storage.
|
A dictionary containing serialized data processed for ChromaDB storage.
|
||||||
"""
|
"""
|
||||||
data = super().model_dump()
|
data = super().model_dump(**kwargs)
|
||||||
return process_data_for_chroma(data)
|
return process_data_for_chroma(data)
|
||||||
|
|
||||||
|
|
||||||
def process_data_for_chroma(data):
|
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert complex data types to a format suitable for ChromaDB storage.
|
Convert complex data types to a format suitable for ChromaDB storage.
|
||||||
|
|
||||||
|
|
@@ -73,7 +74,7 @@ def process_data_for_chroma(data):
|
||||||
|
|
||||||
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
|
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
|
||||||
"""
|
"""
|
||||||
processed_data = {}
|
processed_data: Dict[str, Any] = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if isinstance(value, UUID):
|
if isinstance(value, UUID):
|
||||||
processed_data[key] = str(value)
|
processed_data[key] = str(value)
|
||||||
|
|
@@ -90,7 +91,7 @@ def process_data_for_chroma(data):
|
||||||
return processed_data
|
return processed_data
|
||||||
|
|
||||||
|
|
||||||
def restore_data_from_chroma(data):
|
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Restore original data structure from ChromaDB storage format.
|
Restore original data structure from ChromaDB storage format.
|
||||||
|
|
||||||
|
|
@@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "ChromaDB"
|
name = "ChromaDB"
|
||||||
url: str
|
url: str | None
|
||||||
api_key: str
|
api_key: str | None
|
||||||
connection: AsyncHttpClient = None
|
connection: AsyncHttpClient = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@@ -216,7 +217,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
collections = await self.get_collection_names()
|
collections = await self.get_collection_names()
|
||||||
return collection_name in collections
|
return collection_name in collections
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Create a new collection in ChromaDB if it does not already exist.
|
Create a new collection in ChromaDB if it does not already exist.
|
||||||
|
|
||||||
|
|
@@ -254,7 +255,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
client = await self.get_connection()
|
client = await self.get_connection()
|
||||||
return await client.get_collection(collection_name)
|
return await client.get_collection(collection_name)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Create and upsert data points into the specified collection in ChromaDB.
|
Create and upsert data points into the specified collection in ChromaDB.
|
||||||
|
|
||||||
|
|
@@ -282,7 +283,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
|
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
Create a vector index as a ChromaDB collection based on provided names.
|
Create a vector index as a ChromaDB collection based on provided names.
|
||||||
|
|
||||||
|
|
@@ -296,7 +297,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
async def index_data_points(
|
async def index_data_points(
|
||||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Index the provided data points based on the specified index property in ChromaDB.
|
Index the provided data points based on the specified index property in ChromaDB.
|
||||||
|
|
||||||
|
|
@@ -312,13 +313,17 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
[
|
[
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id=data_point.id,
|
id=data_point.id,
|
||||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
text=getattr(
|
||||||
|
data_point,
|
||||||
|
data_point.metadata["index_fields"][0]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
|
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
|
||||||
"""
|
"""
|
||||||
Retrieve data points by their IDs from a ChromaDB collection.
|
Retrieve data points by their IDs from a ChromaDB collection.
|
||||||
|
|
||||||
|
|
@@ -350,12 +355,12 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: str = None,
|
query_text: Optional[str] = None,
|
||||||
query_vector: List[float] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: int = 15,
|
limit: int = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
):
|
) -> List[ScoredResult]:
|
||||||
"""
|
"""
|
||||||
Search for items in a collection using either a text or a vector query.
|
Search for items in a collection using either a text or a vector query.
|
||||||
|
|
||||||
|
|
@@ -437,7 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
) -> List[List[ScoredResult]]:
|
||||||
"""
|
"""
|
||||||
Perform multiple searches in a single request for efficiency, returning results for each
|
Perform multiple searches in a single request for efficiency, returning results for each
|
||||||
query.
|
query.
|
||||||
|
|
@@ -507,7 +512,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
|
||||||
"""
|
"""
|
||||||
Remove data points from a collection based on their IDs.
|
Remove data points from a collection based on their IDs.
|
||||||
|
|
||||||
|
|
@@ -528,7 +533,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
await collection.delete(ids=data_point_ids)
|
await collection.delete(ids=data_point_ids)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete all collections in the ChromaDB database.
|
Delete all collections in the ChromaDB database.
|
||||||
|
|
||||||
|
|
@@ -538,12 +543,12 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
Returns True upon successful deletion of all collections.
|
Returns True upon successful deletion of all collections.
|
||||||
"""
|
"""
|
||||||
client = await self.get_connection()
|
client = await self.get_connection()
|
||||||
collections = await self.list_collections()
|
collection_names = await self.get_collection_names()
|
||||||
for collection_name in collections:
|
for collection_name in collection_names:
|
||||||
await client.delete_collection(collection_name)
|
await client.delete_collection(collection_name)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_collection_names(self):
|
async def get_collection_names(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Retrieve the names of all collections in the ChromaDB database.
|
Retrieve the names of all collections in the ChromaDB database.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
|
||||||
better outcome.
|
better outcome.
|
||||||
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
||||||
key-value pairs in a dictionary.
|
key-value pairs in a dictionary.
|
||||||
|
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
score: float # Lower score is better
|
score: float # Lower score is better
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
|
vector: Optional[List[float]] = None
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue