mypy: fix LanceDBAdapter mypy errors
This commit is contained in:
parent
4ae41fede3
commit
eebca89855
1 changed file with 42 additions and 36 deletions
|
|
@ -1,12 +1,14 @@
|
|||
import asyncio
|
||||
from os import path
|
||||
from uuid import UUID
|
||||
import lancedb
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.models.DataPoint import MetaData
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||
|
|
@ -30,21 +32,21 @@ class IndexSchema(DataPoint):
|
|||
to include 'text'.
|
||||
"""
|
||||
|
||||
id: str
|
||||
id: UUID
|
||||
text: str
|
||||
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
|
||||
|
||||
class LanceDBAdapter(VectorDBInterface):
|
||||
name = "LanceDB"
|
||||
url: str
|
||||
api_key: str
|
||||
url: Optional[str]
|
||||
api_key: Optional[str]
|
||||
connection: lancedb.AsyncConnection = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: Optional[str],
|
||||
url: Optional[str], # TODO: consider if we want to make this required and/or api_key
|
||||
api_key: Optional[str],
|
||||
embedding_engine: EmbeddingEngine,
|
||||
):
|
||||
|
|
@ -53,7 +55,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
self.embedding_engine = embedding_engine
|
||||
self.VECTOR_DB_LOCK = asyncio.Lock()
|
||||
|
||||
async def get_connection(self):
|
||||
async def get_connection(self) -> lancedb.AsyncConnection:
|
||||
"""
|
||||
Establishes and returns a connection to the LanceDB.
|
||||
|
||||
|
|
@ -107,12 +109,9 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
collection_names = await connection.table_names()
|
||||
return collection_name in collection_names
|
||||
|
||||
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
|
||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
payload_schema = self.get_data_point_schema(payload_schema)
|
||||
data_point_types = get_type_hints(payload_schema)
|
||||
|
||||
class LanceDataPoint(LanceModel):
|
||||
"""
|
||||
Represents a data point in the Lance model with an ID, vector, and associated payload.
|
||||
|
|
@ -123,28 +122,28 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
- payload: Additional data or metadata associated with the data point.
|
||||
"""
|
||||
|
||||
id: data_point_types["id"]
|
||||
vector: Vector(vector_size)
|
||||
payload: payload_schema
|
||||
id: UUID
|
||||
vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
|
||||
payload: Dict[str, Any]
|
||||
|
||||
if not await self.has_collection(collection_name):
|
||||
async with self.VECTOR_DB_LOCK:
|
||||
if not await self.has_collection(collection_name):
|
||||
connection = await self.get_connection()
|
||||
return await connection.create_table(
|
||||
await connection.create_table(
|
||||
name=collection_name,
|
||||
schema=LanceDataPoint,
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
async def get_collection(self, collection_name: str) -> Any:
    """Open the table that backs ``collection_name`` and return it.

    Args:
        collection_name: Name of the collection (LanceDB table) to open.

    Returns:
        Whatever ``connection.open_table(collection_name)`` resolves to.

    Raises:
        CollectionNotFoundError: If ``has_collection`` reports that the
            collection does not exist.
    """
    collection_exists = await self.has_collection(collection_name)

    if collection_exists:
        db_connection = await self.get_connection()
        table = await db_connection.open_table(collection_name)
        return table

    raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
|
||||
payload_schema = type(data_points[0])
|
||||
|
||||
if not await self.has_collection(collection_name):
|
||||
|
|
@ -175,14 +174,14 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
"""
|
||||
|
||||
id: IdType
|
||||
vector: Vector(vector_size)
|
||||
vector: Vector[vector_size]
|
||||
payload: PayloadSchema
|
||||
|
||||
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
|
||||
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
|
||||
properties = get_own_properties(data_point)
|
||||
properties["id"] = str(properties["id"])
|
||||
|
||||
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
|
||||
return LanceDataPoint(
|
||||
id=str(data_point.id),
|
||||
vector=vector,
|
||||
payload=properties,
|
||||
|
|
@ -201,7 +200,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
.execute(lance_data_points)
|
||||
)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if len(data_point_ids) == 1:
|
||||
|
|
@ -221,12 +220,12 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
raise MissingQueryParameterError()
|
||||
|
||||
|
|
@ -264,9 +263,9 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
self,
|
||||
collection_name: str,
|
||||
query_texts: List[str],
|
||||
limit: int = None,
|
||||
limit: Optional[int] = None,
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
) -> List[List[ScoredResult]]:
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
||||
return await asyncio.gather(
|
||||
|
|
@ -274,40 +273,44 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
self.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
limit=limit or 15,
|
||||
with_vector=with_vectors,
|
||||
)
|
||||
for query_vector in query_vectors
|
||||
]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
    """Delete the rows whose ``id`` matches any of ``data_point_ids``.

    Args:
        collection_name: Name of the collection to delete from.
        data_point_ids: Ids of the rows to remove. Each id is rendered with
            ``str()`` and embedded in a quoted filter literal.

    Raises:
        Whatever ``get_collection`` raises when the collection is missing.
    """
    collection = await self.get_collection(collection_name)

    # Delete one at a time to avoid commit conflicts
    for data_point_id in data_point_ids:
        # Double any single quote so the id cannot terminate the quoted
        # literal early and change the meaning of the filter expression.
        escaped_id = str(data_point_id).replace("'", "''")
        await collection.delete(f"id = '{escaped_id}'")
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
    """Create the collection that serves as the index for one property.

    The collection is named ``f"{index_name}_{index_property_name}"`` and is
    created through ``create_collection`` with ``IndexSchema`` as its payload
    schema.

    Args:
        index_name: First component of the index collection name.
        index_property_name: Second component of the index collection name.
    """
    index_collection_name = f"{index_name}_{index_property_name}"
    await self.create_collection(index_collection_name, payload_schema=IndexSchema)
|
||||
|
||||
async def index_data_points(
    self, index_name: str, index_property_name: str, data_points: List[DataPoint]
) -> None:
    """Write one ``IndexSchema`` row per indexable data point.

    Rows are stored in the collection named
    ``f"{index_name}_{index_property_name}"`` via ``create_data_points``.
    A data point is indexable when its ``metadata`` is truthy and contains a
    non-empty ``"index_fields"`` entry. The attribute named by the first
    index field becomes the row's ``text``, and the row reuses the data
    point's ``id``.

    Args:
        index_name: First component of the index collection name.
        index_property_name: Second component of the index collection name.
        data_points: Candidate data points; those without index fields are
            skipped.
    """
    index_rows = [
        IndexSchema(
            id=data_point.id,
            text=getattr(data_point, data_point.metadata["index_fields"][0]),
        )
        for data_point in data_points
        if data_point.metadata and data_point.metadata.get("index_fields")
    ]

    # create_data_points inspects data_points[0] to choose the payload schema,
    # so forwarding an empty batch would fail; nothing to index is a no-op.
    if not index_rows:
        return

    await self.create_data_points(f"{index_name}_{index_property_name}", index_rows)
|
||||
|
||||
async def prune(self):
|
||||
async def prune(self) -> None:
|
||||
connection = await self.get_connection()
|
||||
collection_names = await connection.table_names()
|
||||
|
||||
|
|
@ -316,12 +319,15 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
await collection.delete("id IS NOT NULL")
|
||||
await connection.drop_table(collection_name)
|
||||
|
||||
if self.url.startswith("/"):
|
||||
if self.url and self.url.startswith("/"):
|
||||
db_dir_path = path.dirname(self.url)
|
||||
db_file_name = path.basename(self.url)
|
||||
await get_file_storage(db_dir_path).remove_all(db_file_name)
|
||||
|
||||
def get_data_point_schema(self, model_type: BaseModel):
|
||||
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
|
||||
if model_type is None:
|
||||
return DataPoint
|
||||
|
||||
related_models_fields = []
|
||||
|
||||
for field_name, field_config in model_type.model_fields.items():
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue