mypy: fix LanceDBAdapter mypy errors

This commit is contained in:
Daulet Amirkhanov 2025-09-03 18:25:09 +01:00
parent 4ae41fede3
commit eebca89855

View file

@ -1,12 +1,14 @@
import asyncio import asyncio
from os import path from os import path
from uuid import UUID
import lancedb import lancedb
from pydantic import BaseModel from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import get_file_storage from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
@ -30,21 +32,21 @@ class IndexSchema(DataPoint):
to include 'text'. to include 'text'.
""" """
id: str id: UUID
text: str text: str
metadata: dict = {"index_fields": ["text"]} metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class LanceDBAdapter(VectorDBInterface): class LanceDBAdapter(VectorDBInterface):
name = "LanceDB" name = "LanceDB"
url: str url: Optional[str]
api_key: str api_key: Optional[str]
connection: lancedb.AsyncConnection = None connection: lancedb.AsyncConnection = None
def __init__( def __init__(
self, self,
url: Optional[str], url: Optional[str], # TODO: consider if we want to make this required and/or api_key
api_key: Optional[str], api_key: Optional[str],
embedding_engine: EmbeddingEngine, embedding_engine: EmbeddingEngine,
): ):
@ -53,7 +55,7 @@ class LanceDBAdapter(VectorDBInterface):
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
self.VECTOR_DB_LOCK = asyncio.Lock() self.VECTOR_DB_LOCK = asyncio.Lock()
async def get_connection(self): async def get_connection(self) -> lancedb.AsyncConnection:
""" """
Establishes and returns a connection to the LanceDB. Establishes and returns a connection to the LanceDB.
@ -107,12 +109,9 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names() collection_names = await connection.table_names()
return collection_name in collection_names return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema: BaseModel): async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
payload_schema = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(payload_schema)
class LanceDataPoint(LanceModel): class LanceDataPoint(LanceModel):
""" """
Represents a data point in the Lance model with an ID, vector, and associated payload. Represents a data point in the Lance model with an ID, vector, and associated payload.
@ -123,28 +122,28 @@ class LanceDBAdapter(VectorDBInterface):
- payload: Additional data or metadata associated with the data point. - payload: Additional data or metadata associated with the data point.
""" """
id: data_point_types["id"] id: UUID
vector: Vector(vector_size) vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
payload: payload_schema payload: Dict[str, Any]
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK: async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
connection = await self.get_connection() connection = await self.get_connection()
return await connection.create_table( await connection.create_table(
name=collection_name, name=collection_name,
schema=LanceDataPoint, schema=LanceDataPoint,
exist_ok=True, exist_ok=True,
) )
async def get_collection(self, collection_name: str): async def get_collection(self, collection_name: str) -> Any:
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!") raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
connection = await self.get_connection() connection = await self.get_connection()
return await connection.open_table(collection_name) return await connection.open_table(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
payload_schema = type(data_points[0]) payload_schema = type(data_points[0])
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
@ -175,14 +174,14 @@ class LanceDBAdapter(VectorDBInterface):
""" """
id: IdType id: IdType
vector: Vector(vector_size) vector: Vector[vector_size]
payload: PayloadSchema payload: PayloadSchema
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint: def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
properties = get_own_properties(data_point) properties = get_own_properties(data_point)
properties["id"] = str(properties["id"]) properties["id"] = str(properties["id"])
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))]( return LanceDataPoint(
id=str(data_point.id), id=str(data_point.id),
vector=vector, vector=vector,
payload=properties, payload=properties,
@ -201,7 +200,7 @@ class LanceDBAdapter(VectorDBInterface):
.execute(lance_data_points) .execute(lance_data_points)
) )
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)
if len(data_point_ids) == 1: if len(data_point_ids) == 1:
@ -221,12 +220,12 @@ class LanceDBAdapter(VectorDBInterface):
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: Optional[str] = None,
query_vector: List[float] = None, query_vector: Optional[List[float]] = None,
limit: int = 15, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ) -> List[ScoredResult]:
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise MissingQueryParameterError() raise MissingQueryParameterError()
@ -264,9 +263,9 @@ class LanceDBAdapter(VectorDBInterface):
self, self,
collection_name: str, collection_name: str,
query_texts: List[str], query_texts: List[str],
limit: int = None, limit: Optional[int] = None,
with_vectors: bool = False, with_vectors: bool = False,
): ) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts) query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather( return await asyncio.gather(
@ -274,40 +273,44 @@ class LanceDBAdapter(VectorDBInterface):
self.search( self.search(
collection_name=collection_name, collection_name=collection_name,
query_vector=query_vector, query_vector=query_vector,
limit=limit, limit=limit or 15,
with_vector=with_vectors, with_vector=with_vectors,
) )
for query_vector in query_vectors for query_vector in query_vectors
] ]
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)
# Delete one at a time to avoid commit conflicts # Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids: for data_point_id in data_point_ids:
await collection.delete(f"id = '{data_point_id}'") await collection.delete(f"id = '{data_point_id}'")
async def create_vector_index(self, index_name: str, index_property_name: str): async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection( await self.create_collection(
f"{index_name}_{index_property_name}", payload_schema=IndexSchema f"{index_name}_{index_property_name}", payload_schema=IndexSchema
) )
async def index_data_points( async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint] self, index_name: str, index_property_name: str, data_points: List[DataPoint]
): ) -> None:
await self.create_data_points( await self.create_data_points(
f"{index_name}_{index_property_name}", f"{index_name}_{index_property_name}",
[ [
IndexSchema( IndexSchema(
id=str(data_point.id), id=data_point.id,
text=getattr(data_point, data_point.metadata["index_fields"][0]), text=getattr(
data_point,
data_point.metadata["index_fields"][0]
),
) )
for data_point in data_points for data_point in data_points
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
], ],
) )
async def prune(self): async def prune(self) -> None:
connection = await self.get_connection() connection = await self.get_connection()
collection_names = await connection.table_names() collection_names = await connection.table_names()
@ -316,12 +319,15 @@ class LanceDBAdapter(VectorDBInterface):
await collection.delete("id IS NOT NULL") await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name) await connection.drop_table(collection_name)
if self.url.startswith("/"): if self.url and self.url.startswith("/"):
db_dir_path = path.dirname(self.url) db_dir_path = path.dirname(self.url)
db_file_name = path.basename(self.url) db_file_name = path.basename(self.url)
await get_file_storage(db_dir_path).remove_all(db_file_name) await get_file_storage(db_dir_path).remove_all(db_file_name)
def get_data_point_schema(self, model_type: BaseModel): def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
if model_type is None:
return DataPoint
related_models_fields = [] related_models_fields = []
for field_name, field_config in model_type.model_fields.items(): for field_name, field_config in model_type.model_fields.items():