diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index bceaaacb..88d2a472 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -644,10 +644,11 @@ def get_community_edge_from_record(record: Any): async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]): - if len(edges) == 0: - return # filter out falsey values from edges filtered_edges = [edge for edge in edges if edge.fact] + + if len(filtered_edges) == 0: + return fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges]) for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True): edge.fact_embedding = fact_embedding diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index bd1c1a92..4105c88e 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -868,11 +868,12 @@ def get_community_node_from_record(record: Any) -> CommunityNode: async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]): - if not nodes: # Handle empty list case - return - # filter out falsey values from nodes filtered_nodes = [node for node in nodes if node.name] + + if not filtered_nodes: + return + name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes]) for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True): node.name_embedding = name_embedding