This commit is contained in:
Daulet Amirkhanov 2025-09-11 14:02:27 +01:00
parent e87b77fda6
commit 5f4c06efd1

View file

@@ -1,4 +1,5 @@
import asyncio import asyncio
import json
from os import path from os import path
from uuid import UUID from uuid import UUID
import lancedb import lancedb
@@ -43,7 +44,7 @@ class IndexSchema(DataPoint):
to include 'text'. to include 'text'.
""" """
id: UUID id: str
text: str text: str
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"} metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
@@ -135,9 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
- payload: Additional data or metadata associated with the data point. - payload: Additional data or metadata associated with the data point.
""" """
id: UUID id: str
vector: Vector[vector_size] vector: Vector(vector_size) # type: ignore
payload: Dict[str, Any] payload: str # JSON string for LanceDB compatibility
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK: async with self.VECTOR_DB_LOCK:
@@ -173,11 +174,9 @@ class LanceDBAdapter(VectorDBInterface):
[DataPoint.get_embeddable_data(data_point) for data_point in data_points] [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
) )
IdType = TypeVar("IdType")
PayloadSchema = TypeVar("PayloadSchema")
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]): class LanceDataPoint(LanceModel):
""" """
Represents a data point in the Lance model with an ID, vector, and payload. Represents a data point in the Lance model with an ID, vector, and payload.
@@ -186,9 +185,9 @@ class LanceDBAdapter(VectorDBInterface):
to the Lance data structure. to the Lance data structure.
""" """
id: IdType id: str
vector: Vector[vector_size] vector: Vector(vector_size) # type: ignore
payload: PayloadSchema payload: str # JSON string for LanceDB compatibility
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any: def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
properties = get_own_properties(data_point) properties = get_own_properties(data_point)
@@ -224,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
return [ return [
ScoredResult( ScoredResult(
id=parse_id(result["id"]), id=parse_id(result["id"]),
payload=result["payload"], payload=json.loads(result["payload"]),
score=0, score=0,
) )
for result in results.to_dict("index").values() for result in results.to_dict("index").values()
@@ -266,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
return [ return [
ScoredResult( ScoredResult(
id=parse_id(result["id"]), id=parse_id(result["id"]),
payload=result["payload"], payload=json.loads(result["payload"]),
score=normalized_values[value_index], score=normalized_values[value_index],
) )
for value_index, result in enumerate(result_values) for value_index, result in enumerate(result_values)
@@ -312,7 +311,7 @@ class LanceDBAdapter(VectorDBInterface):
f"{index_name}_{index_property_name}", f"{index_name}_{index_property_name}",
[ [
IndexSchema( IndexSchema(
id=data_point.id, id=str(data_point.id),
text=getattr(data_point, data_point.metadata["index_fields"][0]), text=getattr(data_point, data_point.metadata["index_fields"][0]),
) )
for data_point in data_points for data_point in data_points