mypy: fix LanceDBAdapter mypy errors

This commit is contained in:
Daulet Amirkhanov 2025-09-03 18:25:09 +01:00
parent 4ae41fede3
commit eebca89855

View file

@ -1,12 +1,14 @@
import asyncio
from os import path
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.storage.utils import copy_model, get_own_properties
@ -30,21 +32,21 @@ class IndexSchema(DataPoint):
to include 'text'.
"""
id: str
id: UUID
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class LanceDBAdapter(VectorDBInterface):
name = "LanceDB"
url: str
api_key: str
url: Optional[str]
api_key: Optional[str]
connection: lancedb.AsyncConnection = None
def __init__(
self,
url: Optional[str],
url: Optional[str], # TODO: consider if we want to make this required and/or api_key
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
):
@ -53,7 +55,7 @@ class LanceDBAdapter(VectorDBInterface):
self.embedding_engine = embedding_engine
self.VECTOR_DB_LOCK = asyncio.Lock()
async def get_connection(self):
async def get_connection(self) -> lancedb.AsyncConnection:
"""
Establishes and returns a connection to the LanceDB.
@ -107,12 +109,9 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names()
return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
vector_size = self.embedding_engine.get_vector_size()
payload_schema = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(payload_schema)
class LanceDataPoint(LanceModel):
"""
Represents a data point in the Lance model with an ID, vector, and associated payload.
@ -123,28 +122,28 @@ class LanceDBAdapter(VectorDBInterface):
- payload: Additional data or metadata associated with the data point.
"""
id: data_point_types["id"]
vector: Vector(vector_size)
payload: payload_schema
id: UUID
vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
payload: Dict[str, Any]
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
connection = await self.get_connection()
return await connection.create_table(
await connection.create_table(
name=collection_name,
schema=LanceDataPoint,
exist_ok=True,
)
async def get_collection(self, collection_name: str):
async def get_collection(self, collection_name: str) -> Any:
if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
connection = await self.get_connection()
return await connection.open_table(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
payload_schema = type(data_points[0])
if not await self.has_collection(collection_name):
@ -175,14 +174,14 @@ class LanceDBAdapter(VectorDBInterface):
"""
id: IdType
vector: Vector(vector_size)
vector: Vector[vector_size]
payload: PayloadSchema
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
properties = get_own_properties(data_point)
properties["id"] = str(properties["id"])
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
return LanceDataPoint(
id=str(data_point.id),
vector=vector,
payload=properties,
@ -201,7 +200,7 @@ class LanceDBAdapter(VectorDBInterface):
.execute(lance_data_points)
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
collection = await self.get_collection(collection_name)
if len(data_point_ids) == 1:
@ -221,12 +220,12 @@ class LanceDBAdapter(VectorDBInterface):
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@ -264,9 +263,9 @@ class LanceDBAdapter(VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
@ -274,40 +273,44 @@ class LanceDBAdapter(VectorDBInterface):
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
limit=limit or 15,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
collection = await self.get_collection(collection_name)
# Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids:
await collection.delete(f"id = '{data_point_id}'")
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection(
f"{index_name}_{index_property_name}", payload_schema=IndexSchema
)
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
) -> None:
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=str(data_point.id),
text=getattr(data_point, data_point.metadata["index_fields"][0]),
id=data_point.id,
text=getattr(
data_point,
data_point.metadata["index_fields"][0]
),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
],
)
async def prune(self):
async def prune(self) -> None:
connection = await self.get_connection()
collection_names = await connection.table_names()
@ -316,12 +319,15 @@ class LanceDBAdapter(VectorDBInterface):
await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name)
if self.url.startswith("/"):
if self.url and self.url.startswith("/"):
db_dir_path = path.dirname(self.url)
db_file_name = path.basename(self.url)
await get_file_storage(db_dir_path).remove_all(db_file_name)
def get_data_point_schema(self, model_type: BaseModel):
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
if model_type is None:
return DataPoint
related_models_fields = []
for field_name, field_config in model_type.model_fields.items():