From c5ec136b1d24386ea5c34822010cb46a0649d7af Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Thu, 2 Oct 2025 14:49:24 -0400 Subject: [PATCH] update --- graphiti_core/edges.py | 2 +- graphiti_core/nodes.py | 46 +++++++++++++++++++----------------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index c2744e05..bceaaacb 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -649,5 +649,5 @@ async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[En # filter out falsey values from edges filtered_edges = [edge for edge in edges if edge.fact] fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges]) - for edge, fact_embedding in zip(edges, fact_embeddings, strict=True): + for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True): edge.fact_embedding = fact_embedding diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 0da67f09..bd1c1a92 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -96,8 +96,7 @@ class Node(BaseModel, ABC): created_at: datetime = Field(default_factory=lambda: utc_now()) @abstractmethod - async def save(self, driver: GraphDriver): - ... + async def save(self, driver: GraphDriver): ... async def delete(self, driver: GraphDriver): match driver.provider: @@ -345,12 +344,10 @@ class Node(BaseModel, ABC): await driver.aoss_client.bulk(body=actions) @classmethod - async def get_by_uuid(cls, driver: GraphDriver, uuid: str): - ... + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... @classmethod - async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): - ... + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ... class EpisodicNode(Node): @@ -438,11 +435,11 @@ class EpisodicNode(Node): @classmethod async def get_by_group_ids( - cls, - driver: GraphDriver, - group_ids: list[str], - limit: int | None = None, - uuid_cursor: str | None = None, + cls, + driver: GraphDriver, + group_ids: list[str], + limit: int | None = None, + uuid_cursor: str | None = None, ): cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' @@ -572,8 +569,7 @@ class EntityNode(Node): labels = ':'.join(self.labels + ['Entity']) if driver.aoss_client: - await driver.save_to_aoss(ENTITY_INDEX_NAME, - [entity_data]) # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), @@ -622,12 +618,12 @@ class EntityNode(Node): @classmethod async def get_by_group_ids( - cls, - driver: GraphDriver, - group_ids: list[str], - limit: int | None = None, - uuid_cursor: str | None = None, - with_embeddings: bool = False, + cls, + driver: GraphDriver, + group_ids: list[str], + limit: int | None = None, + uuid_cursor: str | None = None, + with_embeddings: bool = False, ): cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' @@ -767,11 +763,11 @@ class CommunityNode(Node): @classmethod async def get_by_group_ids( - cls, - driver: GraphDriver, - group_ids: list[str], - limit: int | None = None, - uuid_cursor: str | None = None, + cls, + driver: GraphDriver, + group_ids: list[str], + limit: int | None = None, + uuid_cursor: str | None = None, ): cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' @@ -878,5 +874,5 @@ async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[En # filter out falsey values from nodes filtered_nodes = [node for node in nodes if node.name] name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes]) - for node, name_embedding in zip(nodes, name_embeddings, strict=True): + for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True): node.name_embedding = name_embedding