temp
This commit is contained in:
parent
e87b77fda6
commit
5f4c06efd1
1 changed files with 12 additions and 13 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
from os import path
|
from os import path
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import lancedb
|
import lancedb
|
||||||
|
|
@ -43,7 +44,7 @@ class IndexSchema(DataPoint):
|
||||||
to include 'text'.
|
to include 'text'.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: str
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||||
|
|
@ -135,9 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
- payload: Additional data or metadata associated with the data point.
|
- payload: Additional data or metadata associated with the data point.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: str
|
||||||
vector: Vector[vector_size]
|
vector: Vector(vector_size) # type: ignore
|
||||||
payload: Dict[str, Any]
|
payload: str # JSON string for LanceDB compatibility
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
async with self.VECTOR_DB_LOCK:
|
async with self.VECTOR_DB_LOCK:
|
||||||
|
|
@ -173,11 +174,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
|
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
|
||||||
)
|
)
|
||||||
|
|
||||||
IdType = TypeVar("IdType")
|
|
||||||
PayloadSchema = TypeVar("PayloadSchema")
|
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
|
class LanceDataPoint(LanceModel):
|
||||||
"""
|
"""
|
||||||
Represents a data point in the Lance model with an ID, vector, and payload.
|
Represents a data point in the Lance model with an ID, vector, and payload.
|
||||||
|
|
||||||
|
|
@ -186,9 +185,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
to the Lance data structure.
|
to the Lance data structure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: IdType
|
id: str
|
||||||
vector: Vector[vector_size]
|
vector: Vector(vector_size) # type: ignore
|
||||||
payload: PayloadSchema
|
payload: str # JSON string for LanceDB compatibility
|
||||||
|
|
||||||
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
|
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
|
||||||
properties = get_own_properties(data_point)
|
properties = get_own_properties(data_point)
|
||||||
|
|
@ -224,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=parse_id(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
payload=result["payload"],
|
payload=json.loads(result["payload"]),
|
||||||
score=0,
|
score=0,
|
||||||
)
|
)
|
||||||
for result in results.to_dict("index").values()
|
for result in results.to_dict("index").values()
|
||||||
|
|
@ -266,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=parse_id(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
payload=result["payload"],
|
payload=json.loads(result["payload"]),
|
||||||
score=normalized_values[value_index],
|
score=normalized_values[value_index],
|
||||||
)
|
)
|
||||||
for value_index, result in enumerate(result_values)
|
for value_index, result in enumerate(result_values)
|
||||||
|
|
@ -312,7 +311,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
f"{index_name}_{index_property_name}",
|
f"{index_name}_{index_property_name}",
|
||||||
[
|
[
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id=data_point.id,
|
id=str(data_point.id),
|
||||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue