validate pydantic objects (#783)

* validate pydantic objects

* unused imports

* linter
Preston Rasmussen 2025-07-29 17:54:09 -04:00 committed by GitHub
parent 78731316ce
commit 19bddb5528
4 changed files with 46 additions and 119 deletions
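
Across the four files the change has one shape: raw `llm_response` dicts, previously read with `.get()` calls and silent defaults, are now validated into the Pydantic response models that were already being passed as `response_model`. A minimal sketch of the before/after, using a simplified stand-in for the real EdgeDuplicate model (the field names and types match what this diff reads off the model; the payload values are made up):

from pydantic import BaseModel, ValidationError

# Simplified stand-in for graphiti_core.prompts.dedupe_edges.EdgeDuplicate.
class EdgeDuplicate(BaseModel):
    duplicate_facts: list[int]
    contradicted_facts: list[int]
    fact_type: str

llm_response = {'duplicate_facts': [0, 2], 'contradicted_facts': [1], 'fact_type': 'DEFAULT'}

# Before: a missing or misspelled key silently becomes the default.
duplicate_facts = llm_response.get('duplicate_facts', [])

# After: the whole payload is validated up front; a malformed response raises.
response_object = EdgeDuplicate(**llm_response)
assert response_object.duplicate_facts == [0, 2]

try:
    EdgeDuplicate(**{'duplicate_facts': 'not-a-list'})
except ValidationError as exc:
    print(exc)  # reports the type mismatch and the missing fields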


@@ -68,6 +68,10 @@ def edge(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
+<FACT TYPES>
+{context['edge_types']}
+</FACT TYPES>
+
 <PREVIOUS_MESSAGES>
 {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
 </PREVIOUS_MESSAGES>
@@ -84,10 +88,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
 {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
 </REFERENCE_TIME>

-<FACT TYPES>
-{context['edge_types']}
-</FACT TYPES>
-
 # TASK
 Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
 Only extract facts that:


@@ -75,6 +75,10 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
     Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""

     user_prompt = f"""
+<ENTITY TYPES>
+{context['entity_types']}
+</ENTITY TYPES>
+
 <PREVIOUS MESSAGES>
 {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
 </PREVIOUS MESSAGES>
@@ -83,10 +87,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
 {context['episode_content']}
 </CURRENT MESSAGE>
-<ENTITY TYPES>
-{context['entity_types']}
-</ENTITY TYPES>
-
 Instructions:
 You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
@@ -124,15 +124,16 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
     Your primary task is to extract and classify relevant entities from JSON files"""

     user_prompt = f"""
+<ENTITY TYPES>
+{context['entity_types']}
+</ENTITY TYPES>
+
 <SOURCE DESCRIPTION>:
 {context['source_description']}
 </SOURCE DESCRIPTION>
 <JSON>
 {context['episode_content']}
 </JSON>
-<ENTITY TYPES>
-{context['entity_types']}
-</ENTITY TYPES>

 {context['custom_prompt']}
@@ -155,13 +156,14 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
     Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""

     user_prompt = f"""
+<TEXT>
+{context['episode_content']}
+</TEXT>
+
 <ENTITY TYPES>
 {context['entity_types']}
 </ENTITY TYPES>
-<TEXT>
-{context['episode_content']}
-</TEXT>

 Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
 For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
 Indicate the classified entity type by providing its entity_type_id.


@@ -34,7 +34,7 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
+from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
 from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
@@ -161,9 +161,9 @@ async def extract_edges(
             response_model=ExtractedEdges,
             max_tokens=extract_edges_max_tokens,
         )
-        edges_data = llm_response.get('edges', [])
+        edges_data = ExtractedEdges(**llm_response).edges

-        context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
+        context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]

         reflexion_iterations += 1
         if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -193,20 +193,20 @@ async def extract_edges(
     edges = []
     for edge_data in edges_data:
         # Validate Edge Date information
-        valid_at = edge_data.get('valid_at', None)
-        invalid_at = edge_data.get('invalid_at', None)
+        valid_at = edge_data.valid_at
+        invalid_at = edge_data.invalid_at
         valid_at_datetime = None
         invalid_at_datetime = None

-        source_node_idx = edge_data.get('source_entity_id', -1)
-        target_node_idx = edge_data.get('target_entity_id', -1)
+        source_node_idx = edge_data.source_entity_id
+        target_node_idx = edge_data.target_entity_id
         if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
             logger.warning(
-                f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
+                f'WARNING: source or target node not filled {edge_data.relation_type}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
             )
             continue
         source_node_uuid = nodes[source_node_idx].uuid
-        target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
+        target_node_uuid = nodes[edge_data.target_entity_id].uuid

         if valid_at:
             try:
@@ -226,9 +226,9 @@ async def extract_edges(
         edge = EntityEdge(
             source_node_uuid=source_node_uuid,
             target_node_uuid=target_node_uuid,
-            name=edge_data.get('relation_type', ''),
+            name=edge_data.relation_type,
             group_id=group_id,
-            fact=edge_data.get('fact', ''),
+            fact=edge_data.fact,
             episodes=[episode.uuid],
             created_at=utc_now(),
             valid_at=valid_at_datetime,
@@ -422,10 +422,10 @@ async def resolve_extracted_edge(
         response_model=EdgeDuplicate,
         model_size=ModelSize.small,
     )
+    response_object = EdgeDuplicate(**llm_response)
+    duplicate_facts = response_object.duplicate_facts

-    duplicate_fact_ids: list[int] = list(
-        filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
-    )
+    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]

     resolved_edge = extracted_edge
     for duplicate_fact_id in duplicate_fact_ids:
@@ -435,11 +435,13 @@ async def resolve_extracted_edge(
     if duplicate_fact_ids and episode is not None:
         resolved_edge.episodes.append(episode.uuid)

-    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+    contradicted_facts: list[int] = response_object.contradicted_facts

-    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+    invalidation_candidates: list[EntityEdge] = [
+        existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
+    ]

-    fact_type: str = str(llm_response.get('fact_type'))
+    fact_type: str = response_object.fact_type
     if fact_type.upper() != 'DEFAULT' and edge_types is not None:
         resolved_edge.name = fact_type
@@ -494,39 +496,6 @@ async def resolve_extracted_edge(
     return resolved_edge, invalidated_edges, duplicate_edges


-async def dedupe_edge_list(
-    llm_client: LLMClient,
-    edges: list[EntityEdge],
-) -> list[EntityEdge]:
-    start = time()
-
-    # Create edge map
-    edge_map = {}
-    for edge in edges:
-        edge_map[edge.uuid] = edge
-
-    # Prepare context for LLM
-    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
-    )
-    unique_edges_data = llm_response.get('unique_facts', [])
-
-    end = time()
-    logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
-
-    # Get full edge data
-    unique_edges = []
-    for edge_data in unique_edges_data:
-        uuid = edge_data['uuid']
-        edge = edge_map[uuid]
-        edge.fact = edge_data['fact']
-        unique_edges.append(edge)
-
-    return unique_edges
-
-
 async def filter_existing_duplicate_of_edges(
     driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
 ) -> list[tuple[EntityNode, EntityNode]]:


@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeResolutions
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
     ExtractedEntities,
     ExtractedEntity,
@@ -125,10 +125,9 @@ async def extract_nodes(
             prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
         )
-        extracted_entities: list[ExtractedEntity] = [
-            ExtractedEntity(**entity_types_context)
-            for entity_types_context in llm_response.get('extracted_entities', [])
-        ]
+        response_object = ExtractedEntities(**llm_response)
+        extracted_entities: list[ExtractedEntity] = response_object.extracted_entities

         reflexion_iterations += 1
         if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -254,14 +253,14 @@ async def resolve_extracted_nodes(
         response_model=NodeResolutions,
     )

-    node_resolutions: list = llm_response.get('entity_resolutions', [])
+    node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions

     resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
     node_duplicates: list[tuple[EntityNode, EntityNode]] = []

     for resolution in node_resolutions:
-        resolution_id: int = resolution.get('id', -1)
-        duplicate_idx: int = resolution.get('duplicate_idx', -1)
+        resolution_id: int = resolution.id
+        duplicate_idx: int = resolution.duplicate_idx

         extracted_node = extracted_nodes[resolution_id]
@@ -276,7 +275,7 @@ async def resolve_extracted_nodes(
         resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid

-        duplicates: list[int] = resolution.get('duplicates', [])
+        duplicates: list[int] = resolution.duplicates
         if duplicate_idx not in duplicates and duplicate_idx > -1:
             duplicates.append(duplicate_idx)
         for idx in duplicates:
@@ -369,7 +368,9 @@ async def extract_attributes_from_node(
         model_size=ModelSize.small,
     )

-    node.summary = llm_response.get('summary', node.summary)
+    entity_attributes_model(**llm_response)
+
+    node.summary = llm_response.get('summary', '')
     node_attributes = {key: value for key, value in llm_response.items()}

     with suppress(KeyError):
@@ -378,48 +379,3 @@ async def extract_attributes_from_node(
     node.attributes.update(node_attributes)

     return node


-async def dedupe_node_list(
-    llm_client: LLMClient,
-    nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
-
-    # build node map
-    node_map = {}
-    for node in nodes:
-        node_map[node.uuid] = node
-
-    # Prepare context for LLM
-    nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
-
-    context = {
-        'nodes': nodes_context,
-    }
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.node_list(context)
-    )
-    nodes_data = llm_response.get('nodes', [])
-
-    end = time()
-    logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
-
-    # Get full node data
-    unique_nodes = []
-    uuid_map: dict[str, str] = {}
-    for node_data in nodes_data:
-        node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
-        if node_instance is None:
-            logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
-            continue
-        node_instance.summary = node_data['summary']
-        unique_nodes.append(node_instance)
-
-        for uuid in node_data['uuids'][1:]:
-            uuid_value = node_map[node_data['uuids'][0]].uuid
-            uuid_map[uuid] = uuid_value
-
-    return unique_nodes, uuid_map
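
The validated models check types but not ranges, so the index guards stay (and resolve_extracted_edge gains one for contradicted_facts). A small illustration of that guard, with made-up values:

# The LLM can return indices that don't point at real candidates.
related_edges = ['edge-a', 'edge-b', 'edge-c']
duplicate_facts = [0, 2, 7, -1]  # 7 and -1 are out of range

# The same comprehension the diff uses to drop out-of-range indices.
duplicate_fact_ids = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
assert duplicate_fact_ids == [0, 2]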