add group_id routing

This commit is contained in:
prestonrasmussen 2025-09-07 13:04:04 -04:00
parent 58c1f7e395
commit a5e69dc8a7
4 changed files with 20 additions and 13 deletions

View file

@ -231,19 +231,24 @@ class GraphDriver(ABC):
return resp
return {}
from opensearchpy import helpers
def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices:
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {'_index': name}
item = {
'_index': name,
'_routing': d.get('group_id'), # shard routing
}
for p in index['body']['mappings']['properties']:
item[p] = d[p]
if p in d: # protect against missing fields
item[p] = d[p]
to_index.append(item)
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
if failed > 0:
return success
else:
return 0
return success if failed == 0 else success
return 0

View file

@ -262,6 +262,7 @@ class EntityEdge(Edge):
'size': 1,
},
index='entity_edges',
routing=self.group_id,
)
if resp['hits']['hits']:
@ -313,7 +314,7 @@ class EntityEdge(Edge):
edge_data.update(self.attributes or {})
if driver.aoss_client:
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),

View file

@ -287,7 +287,7 @@ class EpisodicNode(Node):
if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episode_content',
'episodes',
[episode_args],
)
@ -432,6 +432,7 @@ class EntityNode(Node):
'size': 1,
},
index='entities',
routing=self.group_id,
)
if resp['hits']['hits']:
@ -478,7 +479,7 @@ class EntityNode(Node):
labels = ':'.join(self.labels + ['Entity'])
if driver.aoss_client:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),
@ -577,7 +578,7 @@ class CommunityNode(Node):
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'community_name',
'communities',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
)
result = await driver.execute_query(

View file

@ -195,9 +195,9 @@ async def add_nodes_and_edges_bulk_tx(
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
if driver.aoss_client:
driver.save_to_aoss('episode_content', episodes)
driver.save_to_aoss('node_name_and_summary', nodes)
driver.save_to_aoss('edge_name_and_summary', edges)
driver.save_to_aoss('episodes', episodes)
driver.save_to_aoss('entities', nodes)
driver.save_to_aoss('entity_edges', edges)
async def extract_nodes_and_edges_bulk(