use async function

This commit is contained in:
prestonrasmussen 2025-09-12 16:40:01 -04:00
parent 7d06888de2
commit bdc6fbb540
6 changed files with 15 additions and 15 deletions

View file

@ -39,8 +39,8 @@ logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10 DEFAULT_SIZE = 10
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities') ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes') EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities') COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges') ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
@ -279,7 +279,7 @@ class GraphDriver(ABC):
else: else:
logger.warning(f"Index '{index_name}' does not exist") logger.warning(f"Index '{index_name}' does not exist")
def save_to_aoss(self, name: str, data: list[dict]) -> int: async def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client client = self.aoss_client
if not client or not helpers: if not client or not helpers:
logger.warning('No OpenSearch client found') logger.warning('No OpenSearch client found')

View file

@ -325,7 +325,7 @@ class EntityEdge(Edge):
edge_data.update(self.attributes or {}) edge_data.update(self.attributes or {})
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_edge_save_query(driver.provider), get_entity_edge_save_query(driver.provider),

View file

@ -1045,7 +1045,7 @@ class Graphiti:
updated_edge.fact, updated_edge.fact,
group_ids=[updated_edge.group_id], group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF, config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]), search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
) )
).edges ).edges
existing_edges = ( existing_edges = (

View file

@ -101,7 +101,7 @@ class Node(BaseModel, ABC):
async def delete(self, driver: GraphDriver): async def delete(self, driver: GraphDriver):
match driver.provider: match driver.provider:
case GraphProvider.NEO4J: case GraphProvider.NEO4J:
result = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-() MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-()
WITH collect(r.uuid) AS edge_uuids, n WITH collect(r.uuid) AS edge_uuids, n
@ -112,8 +112,8 @@ class Node(BaseModel, ABC):
) )
edge_uuids: list[str] = [] edge_uuids: list[str] = []
if result and result[0].get('edge_uuids'): if records and records.get('edge_uuids'):
edge_uuids = result[0]['edge_uuids'] edge_uuids = records['edge_uuids']
if driver.aoss_client: if driver.aoss_client:
# Delete the node from OpenSearch indices # Delete the node from OpenSearch indices
@ -374,7 +374,7 @@ class EpisodicNode(Node):
} }
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episodes', 'episodes',
[episode_args], [episode_args],
) )
@ -567,7 +567,7 @@ class EntityNode(Node):
labels = ':'.join(self.labels + ['Entity']) labels = ':'.join(self.labels + ['Entity'])
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
@ -665,7 +665,7 @@ class CommunityNode(Node):
async def save(self, driver: GraphDriver): async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'communities', 'communities',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}], [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
) )

View file

@ -210,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
) )
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
async def extract_nodes_and_edges_bulk( async def extract_nodes_and_edges_bulk(

View file

@ -274,7 +274,7 @@ async def resolve_extracted_edges(
extracted_edge.fact, extracted_edge.fact,
group_ids=[extracted_edge.group_id], group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF, config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids), search_filter=SearchFilters(edge_uuids=valid_uuids),
) )
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True) for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
] ]