use async function

This commit is contained in:
prestonrasmussen 2025-09-12 16:40:01 -04:00
parent 7d06888de2
commit bdc6fbb540
6 changed files with 15 additions and 15 deletions

View file

@ -39,8 +39,8 @@ logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities')
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes')
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
@ -279,7 +279,7 @@ class GraphDriver(ABC):
else:
logger.warning(f"Index '{index_name}' does not exist")
def save_to_aoss(self, name: str, data: list[dict]) -> int:
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client
if not client or not helpers:
logger.warning('No OpenSearch client found')

View file

@ -325,7 +325,7 @@ class EntityEdge(Edge):
edge_data.update(self.attributes or {})
if driver.aoss_client:
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),

View file

@ -1045,7 +1045,7 @@ class Graphiti:
updated_edge.fact,
group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]),
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
)
).edges
existing_edges = (

View file

@ -101,7 +101,7 @@ class Node(BaseModel, ABC):
async def delete(self, driver: GraphDriver):
match driver.provider:
case GraphProvider.NEO4J:
result = await driver.execute_query(
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-()
WITH collect(r.uuid) AS edge_uuids, n
@ -112,8 +112,8 @@ class Node(BaseModel, ABC):
)
edge_uuids: list[str] = []
if result and result[0].get('edge_uuids'):
edge_uuids = result[0]['edge_uuids']
if records and records.get('edge_uuids'):
edge_uuids = records['edge_uuids']
if driver.aoss_client:
# Delete the node from OpenSearch indices
@ -374,7 +374,7 @@ class EpisodicNode(Node):
}
if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episodes',
[episode_args],
)
@ -567,7 +567,7 @@ class EntityNode(Node):
labels = ':'.join(self.labels + ['Entity'])
if driver.aoss_client:
driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
@ -665,7 +665,7 @@ class CommunityNode(Node):
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'communities',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
)

View file

@ -210,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
)
if driver.aoss_client:
driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
async def extract_nodes_and_edges_bulk(

View file

@ -274,7 +274,7 @@ async def resolve_extracted_edges(
extracted_edge.fact,
group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids),
search_filter=SearchFilters(edge_uuids=valid_uuids),
)
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
]