diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 0184ec3ee..c5f0bf0b6 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -1,12 +1,14 @@ import asyncio from os import path +from uuid import UUID import lancedb from pydantic import BaseModel from lancedb.pydantic import LanceModel, Vector -from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any from cognee.infrastructure.databases.exceptions import MissingQueryParameterError from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.DataPoint import MetaData from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.files.storage import get_file_storage from cognee.modules.storage.utils import copy_model, get_own_properties @@ -30,21 +32,21 @@ class IndexSchema(DataPoint): to include 'text'. """ - id: str + id: UUID text: str - metadata: dict = {"index_fields": ["text"]} + metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"} class LanceDBAdapter(VectorDBInterface): name = "LanceDB" - url: str - api_key: str + url: Optional[str] + api_key: Optional[str] connection: lancedb.AsyncConnection = None def __init__( self, - url: Optional[str], + url: Optional[str], # TODO: consider if we want to make this required and/or api_key api_key: Optional[str], embedding_engine: EmbeddingEngine, ): @@ -53,7 +55,7 @@ class LanceDBAdapter(VectorDBInterface): self.embedding_engine = embedding_engine self.VECTOR_DB_LOCK = asyncio.Lock() - async def get_connection(self): + async def get_connection(self) -> lancedb.AsyncConnection: """ Establishes and returns a connection to the LanceDB. @@ -107,12 +109,9 @@ class LanceDBAdapter(VectorDBInterface): collection_names = await connection.table_names() return collection_name in collection_names - async def create_collection(self, collection_name: str, payload_schema: BaseModel): + async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None: vector_size = self.embedding_engine.get_vector_size() - payload_schema = self.get_data_point_schema(payload_schema) - data_point_types = get_type_hints(payload_schema) - class LanceDataPoint(LanceModel): """ Represents a data point in the Lance model with an ID, vector, and associated payload. @@ -123,28 +122,28 @@ class LanceDBAdapter(VectorDBInterface): - payload: Additional data or metadata associated with the data point. """ - id: data_point_types["id"] - vector: Vector(vector_size) - payload: payload_schema + id: UUID + vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic + payload: Dict[str, Any] if not await self.has_collection(collection_name): async with self.VECTOR_DB_LOCK: if not await self.has_collection(collection_name): connection = await self.get_connection() - return await connection.create_table( + await connection.create_table( name=collection_name, schema=LanceDataPoint, exist_ok=True, ) - async def get_collection(self, collection_name: str): + async def get_collection(self, collection_name: str) -> Any: if not await self.has_collection(collection_name): raise CollectionNotFoundError(f"Collection '{collection_name}' not found!") connection = await self.get_connection() return await connection.open_table(collection_name) - async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): + async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None: payload_schema = type(data_points[0]) if not await self.has_collection(collection_name): @@ -175,14 +174,14 @@ class LanceDBAdapter(VectorDBInterface): """ id: IdType - vector: Vector(vector_size) + vector: Vector[vector_size] payload: PayloadSchema - def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint: + def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any: properties = get_own_properties(data_point) properties["id"] = str(properties["id"]) - return LanceDataPoint[str, self.get_data_point_schema(type(data_point))]( + return LanceDataPoint( id=str(data_point.id), vector=vector, payload=properties, @@ -201,7 +200,7 @@ class LanceDBAdapter(VectorDBInterface): .execute(lance_data_points) ) - async def retrieve(self, collection_name: str, data_point_ids: list[str]): + async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]: collection = await self.get_collection(collection_name) if len(data_point_ids) == 1: @@ -221,12 +220,12 @@ class LanceDBAdapter(VectorDBInterface): async def search( self, collection_name: str, - query_text: str = None, - query_vector: List[float] = None, + query_text: Optional[str] = None, + query_vector: Optional[List[float]] = None, limit: int = 15, with_vector: bool = False, normalized: bool = True, - ): + ) -> List[ScoredResult]: if query_text is None and query_vector is None: raise MissingQueryParameterError() @@ -264,9 +263,9 @@ class LanceDBAdapter(VectorDBInterface): self, collection_name: str, query_texts: List[str], - limit: int = None, + limit: Optional[int] = None, with_vectors: bool = False, - ): + ) -> List[List[ScoredResult]]: query_vectors = await self.embedding_engine.embed_text(query_texts) return await asyncio.gather( @@ -274,40 +273,44 @@ class LanceDBAdapter(VectorDBInterface): self.search( collection_name=collection_name, query_vector=query_vector, - limit=limit, + limit=limit or 15, with_vector=with_vectors, ) for query_vector in query_vectors ] ) - async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None: collection = await self.get_collection(collection_name) # Delete one at a time to avoid commit conflicts for data_point_id in data_point_ids: await collection.delete(f"id = '{data_point_id}'") - async def create_vector_index(self, index_name: str, index_property_name: str): + async def create_vector_index(self, index_name: str, index_property_name: str) -> None: await self.create_collection( f"{index_name}_{index_property_name}", payload_schema=IndexSchema ) async def index_data_points( - self, index_name: str, index_property_name: str, data_points: list[DataPoint] - ): + self, index_name: str, index_property_name: str, data_points: List[DataPoint] + ) -> None: await self.create_data_points( f"{index_name}_{index_property_name}", [ IndexSchema( - id=str(data_point.id), - text=getattr(data_point, data_point.metadata["index_fields"][0]), + id=data_point.id, + text=getattr( + data_point, + data_point.metadata["index_fields"][0] + ), ) for data_point in data_points + if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0 ], ) - async def prune(self): + async def prune(self) -> None: connection = await self.get_connection() collection_names = await connection.table_names() @@ -316,12 +319,15 @@ class LanceDBAdapter(VectorDBInterface): await collection.delete("id IS NOT NULL") await connection.drop_table(collection_name) - if self.url.startswith("/"): + if self.url and self.url.startswith("/"): db_dir_path = path.dirname(self.url) db_file_name = path.basename(self.url) await get_file_storage(db_dir_path).remove_all(db_file_name) - def get_data_point_schema(self, model_type: BaseModel): + def get_data_point_schema(self, model_type: Optional[Any]) -> Any: + if model_type is None: + return DataPoint + related_models_fields = [] for field_name, field_config in model_type.model_fields.items():