From bdc6fbb540d3e180663f5f8d8a25d44dd1c7fd77 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 12 Sep 2025 16:40:01 -0400 Subject: [PATCH] use async function --- graphiti_core/driver/driver.py | 6 +++--- graphiti_core/edges.py | 2 +- graphiti_core/graphiti.py | 2 +- graphiti_core/nodes.py | 12 ++++++------ graphiti_core/utils/bulk_utils.py | 6 +++--- graphiti_core/utils/maintenance/edge_operations.py | 2 +- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 1f47822c..52858582 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -39,8 +39,8 @@ logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 -ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities') -EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes') +ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities') +EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes') COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities') ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges') @@ -279,7 +279,7 @@ class GraphDriver(ABC): else: logger.warning(f"Index '{index_name}' does not exist") - def save_to_aoss(self, name: str, data: list[dict]) -> int: + async def save_to_aoss(self, name: str, data: list[dict]) -> int: client = self.aoss_client if not client or not helpers: logger.warning('No OpenSearch client found') diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index cda88595..90a49762 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -325,7 +325,7 @@ class EntityEdge(Edge): edge_data.update(self.attributes or {}) if driver.aoss_client: - driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_edge_save_query(driver.provider), diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index f168deda..bce0c326 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -1045,7 +1045,7 @@ class Graphiti: updated_edge.fact, group_ids=[updated_edge.group_id], config=EDGE_HYBRID_SEARCH_RRF, - search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]), + search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]), ) ).edges existing_edges = ( diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index a11d7400..09673c1c 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -101,7 +101,7 @@ class Node(BaseModel, ABC): async def delete(self, driver: GraphDriver): match driver.provider: case GraphProvider.NEO4J: - result = await driver.execute_query( + records, _, _ = await driver.execute_query( """ MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-() WITH collect(r.uuid) AS edge_uuids, n @@ -112,8 +112,8 @@ class Node(BaseModel, ABC): ) edge_uuids: list[str] = [] - if result and result[0].get('edge_uuids'): - edge_uuids = result[0]['edge_uuids'] + if records and records.get('edge_uuids'): + edge_uuids = records['edge_uuids'] if driver.aoss_client: # Delete the node from OpenSearch indices @@ -374,7 +374,7 @@ class EpisodicNode(Node): } if driver.aoss_client: - driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue 'episodes', [episode_args], ) @@ -567,7 +567,7 @@ class EntityNode(Node): labels = ':'.join(self.labels + ['Entity']) if driver.aoss_client: - driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), @@ -665,7 +665,7 @@ class CommunityNode(Node): async def save(self, driver: GraphDriver): if driver.provider == GraphProvider.NEPTUNE: - driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue 'communities', [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}], ) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index e9b0a31d..78397e87 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -210,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx( ) if driver.aoss_client: - driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) - driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) - driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) + await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) + await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) + await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) async def extract_nodes_and_edges_bulk( diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 1a38bd55..72e3dde1 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -274,7 +274,7 @@ async def resolve_extracted_edges( extracted_edge.fact, group_ids=[extracted_edge.group_id], config=EDGE_HYBRID_SEARCH_RRF, - search_filter=SearchFilters(uuids=valid_uuids), + search_filter=SearchFilters(edge_uuids=valid_uuids), ) for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True) ]