add group_id routing
This commit is contained in:
parent
58c1f7e395
commit
a5e69dc8a7
4 changed files with 20 additions and 13 deletions
|
|
@ -231,19 +231,24 @@ class GraphDriver(ABC):
|
||||||
return resp
|
return resp
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
from opensearchpy import helpers
|
||||||
|
|
||||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
for index in aoss_indices:
|
for index in aoss_indices:
|
||||||
if name.lower() == index['index_name']:
|
if name.lower() == index['index_name']:
|
||||||
to_index = []
|
to_index = []
|
||||||
for d in data:
|
for d in data:
|
||||||
item = {'_index': name}
|
item = {
|
||||||
|
'_index': name,
|
||||||
|
'_routing': d.get('group_id'), # shard routing
|
||||||
|
}
|
||||||
for p in index['body']['mappings']['properties']:
|
for p in index['body']['mappings']['properties']:
|
||||||
|
if p in d: # protect against missing fields
|
||||||
item[p] = d[p]
|
item[p] = d[p]
|
||||||
to_index.append(item)
|
to_index.append(item)
|
||||||
|
|
||||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||||
if failed > 0:
|
|
||||||
return success
|
return success if failed == 0 else success
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
|
||||||
|
|
@ -262,6 +262,7 @@ class EntityEdge(Edge):
|
||||||
'size': 1,
|
'size': 1,
|
||||||
},
|
},
|
||||||
index='entity_edges',
|
index='entity_edges',
|
||||||
|
routing=self.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if resp['hits']['hits']:
|
if resp['hits']['hits']:
|
||||||
|
|
@ -313,7 +314,7 @@ class EntityEdge(Edge):
|
||||||
edge_data.update(self.attributes or {})
|
edge_data.update(self.attributes or {})
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_edge_save_query(driver.provider),
|
get_entity_edge_save_query(driver.provider),
|
||||||
|
|
|
||||||
|
|
@ -287,7 +287,7 @@ class EpisodicNode(Node):
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||||
'episode_content',
|
'episodes',
|
||||||
[episode_args],
|
[episode_args],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -432,6 +432,7 @@ class EntityNode(Node):
|
||||||
'size': 1,
|
'size': 1,
|
||||||
},
|
},
|
||||||
index='entities',
|
index='entities',
|
||||||
|
routing=self.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if resp['hits']['hits']:
|
if resp['hits']['hits']:
|
||||||
|
|
@ -478,7 +479,7 @@ class EntityNode(Node):
|
||||||
labels = ':'.join(self.labels + ['Entity'])
|
labels = ':'.join(self.labels + ['Entity'])
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_node_save_query(driver.provider, labels),
|
get_entity_node_save_query(driver.provider, labels),
|
||||||
|
|
@ -577,7 +578,7 @@ class CommunityNode(Node):
|
||||||
async def save(self, driver: GraphDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||||
'community_name',
|
'communities',
|
||||||
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
||||||
)
|
)
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
|
|
|
||||||
|
|
@ -195,9 +195,9 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
|
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss('episode_content', episodes)
|
driver.save_to_aoss('episodes', episodes)
|
||||||
driver.save_to_aoss('node_name_and_summary', nodes)
|
driver.save_to_aoss('entities', nodes)
|
||||||
driver.save_to_aoss('edge_name_and_summary', edges)
|
driver.save_to_aoss('entity_edges', edges)
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue