diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 3766dba7..86cd8311 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -26,7 +26,7 @@ from pydantic import BaseModel from typing_extensions import Any from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge -from graphiti_core.helpers import semaphore_gather +from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.models.edges.edge_db_queries import ( ENTITY_EDGE_SAVE_BULK, @@ -95,7 +95,7 @@ async def add_nodes_and_edges_bulk( entity_nodes: list[EntityNode], entity_edges: list[EntityEdge], ): - async with driver.session() as session: + async with driver.session(database=DEFAULT_DATABASE) as session: await session.execute_write( add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges ) diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 226d8f41..7ea2c205 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -42,7 +42,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo driver.execute_query( """DROP INDEX $name""", name=name, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) for name in index_names ] @@ -87,7 +87,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo *[ driver.execute_query( query, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) for query in index_queries ] @@ -95,7 +95,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None): - async with driver.session() as session: + async with driver.session(database=DEFAULT_DATABASE) as session: async def delete_all(tx): await tx.run('MATCH (n) DETACH DELETE n') @@ -150,7 +150,7 @@ async def retrieve_episodes( reference_time=reference_time, num_episodes=last_n, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) episodes = [ EpisodicNode(