diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 1d7fef6f4..560b859fa 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -1,4 +1,5 @@ import asyncio +import json from os import path from uuid import UUID import lancedb @@ -43,7 +44,7 @@ class IndexSchema(DataPoint): to include 'text'. """ - id: UUID + id: str text: str metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"} @@ -135,9 +136,9 @@ class LanceDBAdapter(VectorDBInterface): - payload: Additional data or metadata associated with the data point. """ - id: UUID - vector: Vector[vector_size] - payload: Dict[str, Any] + id: str + vector: Vector(vector_size) # type: ignore + payload: str # JSON string for LanceDB compatibility if not await self.has_collection(collection_name): async with self.VECTOR_DB_LOCK: @@ -173,11 +174,9 @@ class LanceDBAdapter(VectorDBInterface): [DataPoint.get_embeddable_data(data_point) for data_point in data_points] ) - IdType = TypeVar("IdType") - PayloadSchema = TypeVar("PayloadSchema") vector_size = self.embedding_engine.get_vector_size() - class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]): + class LanceDataPoint(LanceModel): """ Represents a data point in the Lance model with an ID, vector, and payload. @@ -186,9 +185,9 @@ class LanceDBAdapter(VectorDBInterface): to the Lance data structure. """ - id: IdType - vector: Vector[vector_size] - payload: PayloadSchema + id: str + vector: Vector(vector_size) # type: ignore + payload: str # JSON string for LanceDB compatibility def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any: properties = get_own_properties(data_point) @@ -224,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface): return [ ScoredResult( id=parse_id(result["id"]), - payload=result["payload"], + payload=json.loads(result["payload"]), score=0, ) for result in results.to_dict("index").values() @@ -266,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface): return [ ScoredResult( id=parse_id(result["id"]), - payload=result["payload"], + payload=json.loads(result["payload"]), score=normalized_values[value_index], ) for value_index, result in enumerate(result_values) @@ -312,7 +311,7 @@ class LanceDBAdapter(VectorDBInterface): f"{index_name}_{index_property_name}", [ IndexSchema( - id=data_point.id, + id=str(data_point.id), text=getattr(data_point, data_point.metadata["index_fields"][0]), ) for data_point in data_points