use async function
This commit is contained in:
parent
7d06888de2
commit
bdc6fbb540
6 changed files with 15 additions and 15 deletions
|
|
@ -39,8 +39,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_SIZE = 10
|
DEFAULT_SIZE = 10
|
||||||
|
|
||||||
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities')
|
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
||||||
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes')
|
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
||||||
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
||||||
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
||||||
|
|
||||||
|
|
@ -279,7 +279,7 @@ class GraphDriver(ABC):
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Index '{index_name}' does not exist")
|
logger.warning(f"Index '{index_name}' does not exist")
|
||||||
|
|
||||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
client = self.aoss_client
|
client = self.aoss_client
|
||||||
if not client or not helpers:
|
if not client or not helpers:
|
||||||
logger.warning('No OpenSearch client found')
|
logger.warning('No OpenSearch client found')
|
||||||
|
|
|
||||||
|
|
@ -325,7 +325,7 @@ class EntityEdge(Edge):
|
||||||
edge_data.update(self.attributes or {})
|
edge_data.update(self.attributes or {})
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_edge_save_query(driver.provider),
|
get_entity_edge_save_query(driver.provider),
|
||||||
|
|
|
||||||
|
|
@ -1045,7 +1045,7 @@ class Graphiti:
|
||||||
updated_edge.fact,
|
updated_edge.fact,
|
||||||
group_ids=[updated_edge.group_id],
|
group_ids=[updated_edge.group_id],
|
||||||
config=EDGE_HYBRID_SEARCH_RRF,
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]),
|
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
||||||
)
|
)
|
||||||
).edges
|
).edges
|
||||||
existing_edges = (
|
existing_edges = (
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,7 @@ class Node(BaseModel, ABC):
|
||||||
async def delete(self, driver: GraphDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
match driver.provider:
|
match driver.provider:
|
||||||
case GraphProvider.NEO4J:
|
case GraphProvider.NEO4J:
|
||||||
result = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-()
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-()
|
||||||
WITH collect(r.uuid) AS edge_uuids, n
|
WITH collect(r.uuid) AS edge_uuids, n
|
||||||
|
|
@ -112,8 +112,8 @@ class Node(BaseModel, ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
edge_uuids: list[str] = []
|
edge_uuids: list[str] = []
|
||||||
if result and result[0].get('edge_uuids'):
|
if records and records.get('edge_uuids'):
|
||||||
edge_uuids = result[0]['edge_uuids']
|
edge_uuids = records['edge_uuids']
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
# Delete the node from OpenSearch indices
|
# Delete the node from OpenSearch indices
|
||||||
|
|
@ -374,7 +374,7 @@ class EpisodicNode(Node):
|
||||||
}
|
}
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||||
'episodes',
|
'episodes',
|
||||||
[episode_args],
|
[episode_args],
|
||||||
)
|
)
|
||||||
|
|
@ -567,7 +567,7 @@ class EntityNode(Node):
|
||||||
labels = ':'.join(self.labels + ['Entity'])
|
labels = ':'.join(self.labels + ['Entity'])
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||||
|
|
@ -665,7 +665,7 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
async def save(self, driver: GraphDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||||
'communities',
|
'communities',
|
||||||
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -210,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
)
|
)
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
||||||
driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
||||||
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
|
|
|
||||||
|
|
@ -274,7 +274,7 @@ async def resolve_extracted_edges(
|
||||||
extracted_edge.fact,
|
extracted_edge.fact,
|
||||||
group_ids=[extracted_edge.group_id],
|
group_ids=[extracted_edge.group_id],
|
||||||
config=EDGE_HYBRID_SEARCH_RRF,
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
search_filter=SearchFilters(uuids=valid_uuids),
|
search_filter=SearchFilters(edge_uuids=valid_uuids),
|
||||||
)
|
)
|
||||||
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
|
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue