From dd0d42855e9347d89bf9ec2a9f308054c69f26bc Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Thu, 2 Oct 2025 14:47:46 -0400 Subject: [PATCH] filter out falsey values --- graphiti_core/edges.py | 4 +++- graphiti_core/nodes.py | 48 ++++++++++++++++++++++++------------------ pyproject.toml | 2 +- uv.lock | 4 ++-- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 165dee53..c2744e05 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -646,6 +646,8 @@ def get_community_edge_from_record(record: Any): async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]): if len(edges) == 0: return - fact_embeddings = await embedder.create_batch([edge.fact for edge in edges]) + # filter out falsey values from edges + filtered_edges = [edge for edge in edges if edge.fact] + fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges]) for edge, fact_embedding in zip(edges, fact_embeddings, strict=True): edge.fact_embedding = fact_embedding diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 7fafbe4f..0da67f09 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -96,7 +96,8 @@ class Node(BaseModel, ABC): created_at: datetime = Field(default_factory=lambda: utc_now()) @abstractmethod - async def save(self, driver: GraphDriver): ... + async def save(self, driver: GraphDriver): + ... async def delete(self, driver: GraphDriver): match driver.provider: @@ -344,10 +345,12 @@ class Node(BaseModel, ABC): await driver.aoss_client.bulk(body=actions) @classmethod - async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): + ... @classmethod - async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ... + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): + ... class EpisodicNode(Node): @@ -435,11 +438,11 @@ class EpisodicNode(Node): @classmethod async def get_by_group_ids( - cls, - driver: GraphDriver, - group_ids: list[str], - limit: int | None = None, - uuid_cursor: str | None = None, + cls, + driver: GraphDriver, + group_ids: list[str], + limit: int | None = None, + uuid_cursor: str | None = None, ): cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' @@ -569,7 +572,8 @@ class EntityNode(Node): labels = ':'.join(self.labels + ['Entity']) if driver.aoss_client: - await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss(ENTITY_INDEX_NAME, + [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), @@ -618,12 +622,12 @@ class EntityNode(Node): @classmethod async def get_by_group_ids( - cls, - driver: GraphDriver, - group_ids: list[str], - limit: int | None = None, - uuid_cursor: str | None = None, - with_embeddings: bool = False, + cls, + driver: GraphDriver, + group_ids: list[str], + limit: int | None = None, + uuid_cursor: str | None = None, + with_embeddings: bool = False, ): cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' @@ -763,11 +767,11 @@ class CommunityNode(Node): @classmethod async def get_by_group_ids( - cls, - driver: GraphDriver, - group_ids: list[str], - limit: int | None = None, - uuid_cursor: str | None = None, + cls, + driver: GraphDriver, + group_ids: list[str], + limit: int | None = None, + uuid_cursor: str | None = None, ): cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else '' limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' @@ -871,6 +875,8 @@ async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[En if not nodes: # Handle empty list case return - name_embeddings = await embedder.create_batch([node.name for node in nodes]) + # filter out falsey values from nodes + filtered_nodes = [node for node in nodes if node.name] + name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes]) for node, name_embedding in zip(nodes, name_embeddings, strict=True): node.name_embedding = name_embedding diff --git a/pyproject.toml b/pyproject.toml index 5dacc78c..939eb0ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.21.0pre10" +version = "0.21.0pre11" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index a67aa561..7f362eec 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10, <4" resolution-markers = [ "python_full_version >= '3.14'", @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.21.0rc8" +version = "0.21.0rc11" source = { editable = "." } dependencies = [ { name = "diskcache" },