add group_id routing
This commit is contained in:
parent
58c1f7e395
commit
a5e69dc8a7
4 changed files with 20 additions and 13 deletions
|
|
@ -231,19 +231,24 @@ class GraphDriver(ABC):
|
|||
return resp
|
||||
return {}
|
||||
|
||||
from opensearchpy import helpers
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
for d in data:
|
||||
item = {'_index': name}
|
||||
item = {
|
||||
'_index': name,
|
||||
'_routing': d.get('group_id'), # shard routing
|
||||
}
|
||||
for p in index['body']['mappings']['properties']:
|
||||
item[p] = d[p]
|
||||
if p in d: # protect against missing fields
|
||||
item[p] = d[p]
|
||||
to_index.append(item)
|
||||
|
||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||
if failed > 0:
|
||||
return success
|
||||
else:
|
||||
return 0
|
||||
|
||||
return success if failed == 0 else success
|
||||
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -262,6 +262,7 @@ class EntityEdge(Edge):
|
|||
'size': 1,
|
||||
},
|
||||
index='entity_edges',
|
||||
routing=self.group_id,
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
|
|
@ -313,7 +314,7 @@ class EntityEdge(Edge):
|
|||
edge_data.update(self.attributes or {})
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
|
|
|
|||
|
|
@ -287,7 +287,7 @@ class EpisodicNode(Node):
|
|||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'episode_content',
|
||||
'episodes',
|
||||
[episode_args],
|
||||
)
|
||||
|
||||
|
|
@ -432,6 +432,7 @@ class EntityNode(Node):
|
|||
'size': 1,
|
||||
},
|
||||
index='entities',
|
||||
routing=self.group_id,
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
|
|
@ -478,7 +479,7 @@ class EntityNode(Node):
|
|||
labels = ':'.join(self.labels + ['Entity'])
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_node_save_query(driver.provider, labels),
|
||||
|
|
@ -577,7 +578,7 @@ class CommunityNode(Node):
|
|||
async def save(self, driver: GraphDriver):
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'community_name',
|
||||
'communities',
|
||||
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
||||
)
|
||||
result = await driver.execute_query(
|
||||
|
|
|
|||
|
|
@ -195,9 +195,9 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss('episode_content', episodes)
|
||||
driver.save_to_aoss('node_name_and_summary', nodes)
|
||||
driver.save_to_aoss('edge_name_and_summary', edges)
|
||||
driver.save_to_aoss('episodes', episodes)
|
||||
driver.save_to_aoss('entities', nodes)
|
||||
driver.save_to_aoss('entity_edges', edges)
|
||||
|
||||
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue