This commit is contained in:
Daulet Amirkhanov 2025-09-11 14:02:27 +01:00
parent e87b77fda6
commit 5f4c06efd1

View file

@ -1,4 +1,5 @@
import asyncio
import json
from os import path
from uuid import UUID
import lancedb
@ -43,7 +44,7 @@ class IndexSchema(DataPoint):
to include 'text'.
"""
id: UUID
id: str
text: str
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
@ -135,9 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
- payload: Additional data or metadata associated with the data point.
"""
id: UUID
vector: Vector[vector_size]
payload: Dict[str, Any]
id: str
vector: Vector(vector_size) # type: ignore
payload: str # JSON string for LanceDB compatibility
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
@ -173,11 +174,9 @@ class LanceDBAdapter(VectorDBInterface):
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
IdType = TypeVar("IdType")
PayloadSchema = TypeVar("PayloadSchema")
vector_size = self.embedding_engine.get_vector_size()
class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
class LanceDataPoint(LanceModel):
"""
Represents a data point in the Lance model with an ID, vector, and payload.
@ -186,9 +185,9 @@ class LanceDBAdapter(VectorDBInterface):
to the Lance data structure.
"""
id: IdType
vector: Vector[vector_size]
payload: PayloadSchema
id: str
vector: Vector(vector_size) # type: ignore
payload: str # JSON string for LanceDB compatibility
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
properties = get_own_properties(data_point)
@ -224,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
payload=json.loads(result["payload"]),
score=0,
)
for result in results.to_dict("index").values()
@ -266,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
payload=json.loads(result["payload"]),
score=normalized_values[value_index],
)
for value_index, result in enumerate(result_values)
@ -312,7 +311,7 @@ class LanceDBAdapter(VectorDBInterface):
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=data_point.id,
id=str(data_point.id),
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points