From a5e69dc8a7df84b60bf9e74ba37dca8b0d8290a0 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Sun, 7 Sep 2025 13:04:04 -0400 Subject: [PATCH] add group_id routing --- graphiti_core/driver/driver.py | 17 +++++++++++------ graphiti_core/edges.py | 3 ++- graphiti_core/nodes.py | 7 ++++--- graphiti_core/utils/bulk_utils.py | 6 +++--- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index dda3506f..d5f61913 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -231,19 +231,24 @@ class GraphDriver(ABC): return resp return {} + from opensearchpy import helpers + def save_to_aoss(self, name: str, data: list[dict]) -> int: for index in aoss_indices: if name.lower() == index['index_name']: to_index = [] for d in data: - item = {'_index': name} + item = { + '_index': name, + '_routing': d.get('group_id'), # shard routing + } for p in index['body']['mappings']['properties']: - item[p] = d[p] + if p in d: # protect against missing fields + item[p] = d[p] to_index.append(item) + success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True) - if failed > 0: - return success - else: - return 0 + + return success if failed == 0 else success return 0 diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 9d77e4d3..fede9b27 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -262,6 +262,7 @@ class EntityEdge(Edge): 'size': 1, }, index='entity_edges', + routing=self.group_id, ) if resp['hits']['hits']: @@ -313,7 +314,7 @@ class EntityEdge(Edge): edge_data.update(self.attributes or {}) if driver.aoss_client: - driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue + driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_edge_save_query(driver.provider), diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index b9d27197..6e83acd0 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -287,7 +287,7 @@ class EpisodicNode(Node): if driver.aoss_client: driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue - 'episode_content', + 'episodes', [episode_args], ) @@ -432,6 +432,7 @@ class EntityNode(Node): 'size': 1, }, index='entities', + routing=self.group_id, ) if resp['hits']['hits']: @@ -478,7 +479,7 @@ class EntityNode(Node): labels = ':'.join(self.labels + ['Entity']) if driver.aoss_client: - driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue + driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels), @@ -577,7 +578,7 @@ class CommunityNode(Node): async def save(self, driver: GraphDriver): if driver.provider == GraphProvider.NEPTUNE: driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue - 'community_name', + 'communities', [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}], ) result = await driver.execute_query( diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 57bbdd45..bc771c43 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -195,9 +195,9 @@ async def add_nodes_and_edges_bulk_tx( await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges) if driver.aoss_client: - driver.save_to_aoss('episode_content', episodes) - driver.save_to_aoss('node_name_and_summary', nodes) - driver.save_to_aoss('edge_name_and_summary', edges) + driver.save_to_aoss('episodes', episodes) + driver.save_to_aoss('entities', nodes) + driver.save_to_aoss('entity_edges', edges) async def extract_nodes_and_edges_bulk(