From e988a67466cb613494e2299548f76db057c9070e Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 11 Nov 2024 19:28:17 +0100 Subject: [PATCH] Fixes LanceDB datapoint add --- .../databases/graph/neo4j_driver/adapter.py | 3 --- .../databases/vector/lancedb/LanceDBAdapter.py | 16 ++++++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 26bbb5819..1121a24d5 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -27,9 +27,6 @@ class Neo4jAdapter(GraphDBInterface): max_connection_lifetime = 120 ) - async def close(self) -> None: - await self.driver.close() - @asynccontextmanager async def get_session(self) -> AsyncSession: async with self.driver.session() as session: diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index d883a29e7..96f026b4f 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -112,10 +112,18 @@ class LanceDBAdapter(VectorDBInterface): for (data_point_index, data_point) in enumerate(data_points) ] - await collection.merge_insert("id") \ - .when_matched_update_all() \ - .when_not_matched_insert_all() \ - .execute(lance_data_points) + # TODO: This enables us to work with pydantic version but shouldn't + # stay like this, existing rows should be updated + + await collection.delete("id IS NOT NULL") + + original_size = await collection.count_rows() + await collection.add(lance_data_points) + new_size = await collection.count_rows() + + if new_size <= original_size: + raise ValueError( + "LanceDB create_datapoints error: data points did not get added.") async def retrieve(self, collection_name: str, data_point_ids: list[str]):