From d0b1b2e5dbbcd7bf113fffcfa6d5ef9aa415eac4 Mon Sep 17 00:00:00 2001 From: neonconsultingllc <146017103+neonconsultingllc@users.noreply.github.com> Date: Fri, 18 Apr 2025 12:06:31 -0500 Subject: [PATCH] Fix for using non default neo4j database (#329) Pass database_ correctly to driver.session to fix using non default database --- graphiti_core/utils/bulk_utils.py | 4 ++-- graphiti_core/utils/maintenance/graph_data_operations.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 3766dba7..86cd8311 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -26,7 +26,7 @@ from pydantic import BaseModel from typing_extensions import Any from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge -from graphiti_core.helpers import semaphore_gather +from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.models.edges.edge_db_queries import ( ENTITY_EDGE_SAVE_BULK, @@ -95,7 +95,7 @@ async def add_nodes_and_edges_bulk( entity_nodes: list[EntityNode], entity_edges: list[EntityEdge], ): - async with driver.session() as session: + async with driver.session(database=DEFAULT_DATABASE) as session: await session.execute_write( add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges ) diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 226d8f41..7ea2c205 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -42,7 +42,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo driver.execute_query( """DROP INDEX $name""", name=name, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) for name in index_names ] @@ -87,7 +87,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo *[ driver.execute_query( query, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) for query in index_queries ] @@ -95,7 +95,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None): - async with driver.session() as session: + async with driver.session(database=DEFAULT_DATABASE) as session: async def delete_all(tx): await tx.run('MATCH (n) DETACH DELETE n') @@ -150,7 +150,7 @@ async def retrieve_episodes( reference_time=reference_time, num_episodes=last_n, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) episodes = [ EpisodicNode(