Fixes LanceDB datapoint add
This commit is contained in:
parent
3e7df33c15
commit
e988a67466
2 changed files with 12 additions and 7 deletions
|
|
@ -27,9 +27,6 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
max_connection_lifetime = 120
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.driver.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncSession:
|
||||
async with self.driver.session() as session:
|
||||
|
|
|
|||
|
|
@ -112,10 +112,18 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
for (data_point_index, data_point) in enumerate(data_points)
|
||||
]
|
||||
|
||||
await collection.merge_insert("id") \
|
||||
.when_matched_update_all() \
|
||||
.when_not_matched_insert_all() \
|
||||
.execute(lance_data_points)
|
||||
# TODO: This enables us to work with pydantic version but shouldn't
|
||||
# stay like this, existing rows should be updated
|
||||
|
||||
await collection.delete("id IS NOT NULL")
|
||||
|
||||
original_size = await collection.count_rows()
|
||||
await collection.add(lance_data_points)
|
||||
new_size = await collection.count_rows()
|
||||
|
||||
if new_size <= original_size:
|
||||
raise ValueError(
|
||||
"LanceDB create_datapoints error: data points did not get added.")
|
||||
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue