From 19bddb55280e72e4e5ab37172f8dcd652d51fd29 Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Tue, 29 Jul 2025 17:54:09 -0400
Subject: [PATCH] validate pydantic objects (#783)
* validate pydantic objects
* remove unused imports
* fix lint errors
---
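This patch replaces untyped dict access on LLM responses (llm_response.get(...)
with silent defaults) with explicit pydantic validation: each raw response is
parsed into its response model once, and all downstream access is typed
attribute access. Malformed output now raises a ValidationError instead of
being coerced into empty strings and -1 indices. Index lists returned by the
model (duplicate_facts, contradicted_facts) are additionally bounds-checked,
and the unused dedupe_edge_list / dedupe_node_list helpers are removed.

A minimal sketch of the pattern, using simplified stand-ins for the real
response models (field names follow the diff below; the actual classes live in
graphiti_core.prompts.extract_edges):

    from pydantic import BaseModel, ValidationError

    class Edge(BaseModel):
        relation_type: str
        fact: str
        source_entity_id: int
        target_entity_id: int

    class ExtractedEdges(BaseModel):
        edges: list[Edge]

    llm_response = {
        'edges': [
            {
                'relation_type': 'WORKS_AT',
                'fact': 'Alice works at Acme',
                'source_entity_id': 0,
                'target_entity_id': 1,
            }
        ]
    }

    # Before: missing or misspelled keys silently became '' / -1
    facts = [e.get('fact', '') for e in llm_response.get('edges', [])]

    # After: validate once, then rely on typed attributes
    try:
        edges = ExtractedEdges(**llm_response).edges
    except ValidationError as exc:  # surfaces schema drift immediately
        raise ValueError(f'malformed LLM response: {exc}') from exc
    facts = [edge.fact for edge in edges]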
graphiti_core/prompts/extract_edges.py | 8 +--
graphiti_core/prompts/extract_nodes.py | 22 +++---
.../utils/maintenance/edge_operations.py | 69 +++++--------------
.../utils/maintenance/node_operations.py | 66 +++---------------
4 files changed, 46 insertions(+), 119 deletions(-)
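One call is validation-only: in extract_attributes_from_node the model instance
is constructed and immediately discarded (entity_attributes_model(**llm_response)),
while the raw dict still populates node.summary and node.attributes. A sketch of
that step, assuming a dynamically built model roughly like the one graphiti
constructs at runtime (the field set here is illustrative, not the real schema):

    from pydantic import create_model

    # stand-in for entity_attributes_model; the real field set comes from
    # the configured entity type's attributes plus a summary field
    entity_attributes_model = create_model('EntityAttributes', summary=(str, ...))

    llm_response = {'summary': 'A short summary of the entity.'}
    entity_attributes_model(**llm_response)  # raises ValidationError if malformed

    # the validated raw dict still drives attribute assignment, as in the diff
    summary = llm_response.get('summary', '')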
diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py
index 50e039a4..0d002fec 100644
--- a/graphiti_core/prompts/extract_edges.py
+++ b/graphiti_core/prompts/extract_edges.py
@@ -68,6 +68,10 @@ def edge(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
+<FACT TYPES>
+{context['edge_types']}
+</FACT TYPES>
+
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
@@ -84,10 +88,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
-<FACT TYPES>
-{context['edge_types']}
-</FACT TYPES>
-
# TASK
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
Only extract facts that:
diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py
index 6848bb7b..59d07c88 100644
--- a/graphiti_core/prompts/extract_nodes.py
+++ b/graphiti_core/prompts/extract_nodes.py
@@ -75,6 +75,10 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
user_prompt = f"""
+<ENTITY TYPES>
+{context['entity_types']}
+</ENTITY TYPES>
+
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
@@ -83,10 +87,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
{context['episode_content']}
-<ENTITY TYPES>
-{context['entity_types']}
-</ENTITY TYPES>
-
Instructions:
You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
@@ -124,15 +124,16 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
Your primary task is to extract and classify relevant entities from JSON files"""
user_prompt = f"""
+<ENTITY TYPES>
+{context['entity_types']}
+</ENTITY TYPES>
+
<SOURCE DESCRIPTION>:
{context['source_description']}
{context['episode_content']}
-<ENTITY TYPES>
-{context['entity_types']}
-</ENTITY TYPES>
{context['custom_prompt']}
@@ -155,13 +156,14 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
user_prompt = f"""
-<TEXT>
-{context['episode_content']}
-</TEXT>
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>
+
+<TEXT>
+{context['episode_content']}
+</TEXT>
Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
Indicate the classified entity type by providing its entity_type_id.
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index ad9267bf..1455653a 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -34,7 +34,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
+from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
@@ -161,9 +161,9 @@ async def extract_edges(
response_model=ExtractedEdges,
max_tokens=extract_edges_max_tokens,
)
- edges_data = llm_response.get('edges', [])
+ edges_data = ExtractedEdges(**llm_response).edges
- context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
+ context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]
reflexion_iterations += 1
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -193,20 +193,20 @@ async def extract_edges(
edges = []
for edge_data in edges_data:
# Validate Edge Date information
- valid_at = edge_data.get('valid_at', None)
- invalid_at = edge_data.get('invalid_at', None)
+ valid_at = edge_data.valid_at
+ invalid_at = edge_data.invalid_at
valid_at_datetime = None
invalid_at_datetime = None
- source_node_idx = edge_data.get('source_entity_id', -1)
- target_node_idx = edge_data.get('target_entity_id', -1)
+ source_node_idx = edge_data.source_entity_id
+ target_node_idx = edge_data.target_entity_id
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
logger.warning(
- f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
+ f'WARNING: source or target node not filled {edge_data.relation_type}. source_node_idx: {source_node_idx}, target_node_idx: {target_node_idx}'
)
continue
source_node_uuid = nodes[source_node_idx].uuid
- target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
+ target_node_uuid = nodes[target_node_idx].uuid
if valid_at:
try:
@@ -226,9 +226,9 @@ async def extract_edges(
edge = EntityEdge(
source_node_uuid=source_node_uuid,
target_node_uuid=target_node_uuid,
- name=edge_data.get('relation_type', ''),
+ name=edge_data.relation_type,
group_id=group_id,
- fact=edge_data.get('fact', ''),
+ fact=edge_data.fact,
episodes=[episode.uuid],
created_at=utc_now(),
valid_at=valid_at_datetime,
@@ -422,10 +422,10 @@ async def resolve_extracted_edge(
response_model=EdgeDuplicate,
model_size=ModelSize.small,
)
+ response_object = EdgeDuplicate(**llm_response)
+ duplicate_facts = response_object.duplicate_facts
- duplicate_fact_ids: list[int] = list(
- filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
- )
+ duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
resolved_edge = extracted_edge
for duplicate_fact_id in duplicate_fact_ids:
@@ -435,11 +435,13 @@ async def resolve_extracted_edge(
if duplicate_fact_ids and episode is not None:
resolved_edge.episodes.append(episode.uuid)
- contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+ contradicted_facts: list[int] = response_object.contradicted_facts
- invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+ invalidation_candidates: list[EntityEdge] = [
+ existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
+ ]
- fact_type: str = str(llm_response.get('fact_type'))
+ fact_type: str = response_object.fact_type
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
resolved_edge.name = fact_type
@@ -494,39 +496,6 @@ async def resolve_extracted_edge(
return resolved_edge, invalidated_edges, duplicate_edges
-async def dedupe_edge_list(
- llm_client: LLMClient,
- edges: list[EntityEdge],
-) -> list[EntityEdge]:
- start = time()
-
- # Create edge map
- edge_map = {}
- for edge in edges:
- edge_map[edge.uuid] = edge
-
- # Prepare context for LLM
- context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
-
- llm_response = await llm_client.generate_response(
- prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
- )
- unique_edges_data = llm_response.get('unique_facts', [])
-
- end = time()
- logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
-
- # Get full edge data
- unique_edges = []
- for edge_data in unique_edges_data:
- uuid = edge_data['uuid']
- edge = edge_map[uuid]
- edge.fact = edge_data['fact']
- unique_edges.append(edge)
-
- return unique_edges
-
-
async def filter_existing_duplicate_of_edges(
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 1588d042..9c3256bf 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeResolutions
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
ExtractedEntities,
ExtractedEntity,
@@ -125,10 +125,9 @@ async def extract_nodes(
prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
)
- extracted_entities: list[ExtractedEntity] = [
- ExtractedEntity(**entity_types_context)
- for entity_types_context in llm_response.get('extracted_entities', [])
- ]
+ response_object = ExtractedEntities(**llm_response)
+
+ extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
reflexion_iterations += 1
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -254,14 +253,14 @@ async def resolve_extracted_nodes(
response_model=NodeResolutions,
)
- node_resolutions: list = llm_response.get('entity_resolutions', [])
+ node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
resolved_nodes: list[EntityNode] = []
uuid_map: dict[str, str] = {}
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
for resolution in node_resolutions:
- resolution_id: int = resolution.get('id', -1)
- duplicate_idx: int = resolution.get('duplicate_idx', -1)
+ resolution_id: int = resolution.id
+ duplicate_idx: int = resolution.duplicate_idx
extracted_node = extracted_nodes[resolution_id]
@@ -276,7 +275,7 @@ async def resolve_extracted_nodes(
resolved_nodes.append(resolved_node)
uuid_map[extracted_node.uuid] = resolved_node.uuid
- duplicates: list[int] = resolution.get('duplicates', [])
+ duplicates: list[int] = resolution.duplicates
if duplicate_idx not in duplicates and duplicate_idx > -1:
duplicates.append(duplicate_idx)
for idx in duplicates:
@@ -369,7 +368,9 @@ async def extract_attributes_from_node(
model_size=ModelSize.small,
)
- node.summary = llm_response.get('summary', node.summary)
+ entity_attributes_model(**llm_response)
+
+ node.summary = llm_response.get('summary', '')
node_attributes = {key: value for key, value in llm_response.items()}
with suppress(KeyError):
@@ -378,48 +379,3 @@ async def extract_attributes_from_node(
node.attributes.update(node_attributes)
return node
-
-
-async def dedupe_node_list(
- llm_client: LLMClient,
- nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
- start = time()
-
- # build node map
- node_map = {}
- for node in nodes:
- node_map[node.uuid] = node
-
- # Prepare context for LLM
- nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
-
- context = {
- 'nodes': nodes_context,
- }
-
- llm_response = await llm_client.generate_response(
- prompt_library.dedupe_nodes.node_list(context)
- )
-
- nodes_data = llm_response.get('nodes', [])
-
- end = time()
- logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
-
- # Get full node data
- unique_nodes = []
- uuid_map: dict[str, str] = {}
- for node_data in nodes_data:
- node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
- if node_instance is None:
- logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
- continue
- node_instance.summary = node_data['summary']
- unique_nodes.append(node_instance)
-
- for uuid in node_data['uuids'][1:]:
- uuid_value = node_map[node_data['uuids'][0]].uuid
- uuid_map[uuid] = uuid_value
-
- return unique_nodes, uuid_map