use async function
This commit is contained in:
parent
7d06888de2
commit
bdc6fbb540
6 changed files with 15 additions and 15 deletions
|
|
@ -39,8 +39,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
DEFAULT_SIZE = 10
|
||||
|
||||
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities')
|
||||
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes')
|
||||
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
||||
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
||||
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
||||
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
||||
|
||||
|
|
@ -279,7 +279,7 @@ class GraphDriver(ABC):
|
|||
else:
|
||||
logger.warning(f"Index '{index_name}' does not exist")
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
client = self.aoss_client
|
||||
if not client or not helpers:
|
||||
logger.warning('No OpenSearch client found')
|
||||
|
|
|
|||
|
|
@ -325,7 +325,7 @@ class EntityEdge(Edge):
|
|||
edge_data.update(self.attributes or {})
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
|
|
|
|||
|
|
@ -1045,7 +1045,7 @@ class Graphiti:
|
|||
updated_edge.fact,
|
||||
group_ids=[updated_edge.group_id],
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]),
|
||||
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
||||
)
|
||||
).edges
|
||||
existing_edges = (
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class Node(BaseModel, ABC):
|
|||
async def delete(self, driver: GraphDriver):
|
||||
match driver.provider:
|
||||
case GraphProvider.NEO4J:
|
||||
result = await driver.execute_query(
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-()
|
||||
WITH collect(r.uuid) AS edge_uuids, n
|
||||
|
|
@ -112,8 +112,8 @@ class Node(BaseModel, ABC):
|
|||
)
|
||||
|
||||
edge_uuids: list[str] = []
|
||||
if result and result[0].get('edge_uuids'):
|
||||
edge_uuids = result[0]['edge_uuids']
|
||||
if records and records.get('edge_uuids'):
|
||||
edge_uuids = records['edge_uuids']
|
||||
|
||||
if driver.aoss_client:
|
||||
# Delete the node from OpenSearch indices
|
||||
|
|
@ -374,7 +374,7 @@ class EpisodicNode(Node):
|
|||
}
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'episodes',
|
||||
[episode_args],
|
||||
)
|
||||
|
|
@ -567,7 +567,7 @@ class EntityNode(Node):
|
|||
labels = ':'.join(self.labels + ['Entity'])
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||
|
|
@ -665,7 +665,7 @@ class CommunityNode(Node):
|
|||
|
||||
async def save(self, driver: GraphDriver):
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'communities',
|
||||
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -210,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
||||
driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
||||
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
||||
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
||||
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
||||
|
||||
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
|
|
|
|||
|
|
@ -274,7 +274,7 @@ async def resolve_extracted_edges(
|
|||
extracted_edge.fact,
|
||||
group_ids=[extracted_edge.group_id],
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
search_filter=SearchFilters(uuids=valid_uuids),
|
||||
search_filter=SearchFilters(edge_uuids=valid_uuids),
|
||||
)
|
||||
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue