validate pydantic objects (#783)
* validate pydantic objects
* unused imports
* linter
parent 78731316ce
commit 19bddb5528
4 changed files with 46 additions and 119 deletions
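The change follows one pattern throughout: instead of reading fields out of the raw LLM response dict with `.get()` and fallback defaults, the response is validated into the corresponding Pydantic response model (`ExtractedEdges`, `EdgeDuplicate`, `ExtractedEntities`, `NodeResolutions`) and fields are then accessed as typed attributes. A minimal sketch of the before/after pattern, using simplified stand-in models rather than the library's actual class definitions:

```python
from pydantic import BaseModel


# Simplified stand-ins for the response models; the field sets here are
# illustrative, not the library's actual definitions.
class ExtractedEdge(BaseModel):
    relation_type: str
    source_entity_id: int
    target_entity_id: int
    fact: str
    valid_at: str | None = None
    invalid_at: str | None = None


class ExtractedEdges(BaseModel):
    edges: list[ExtractedEdge]


llm_response = {
    'edges': [
        {
            'relation_type': 'WORKS_AT',
            'source_entity_id': 0,
            'target_entity_id': 1,
            'fact': 'Alice works at Acme',
        }
    ]
}

# Before: untyped dict access; missing or misspelled keys silently fall back to defaults.
facts_before = [edge.get('fact', '') for edge in llm_response.get('edges', [])]

# After: validate once, then work with typed objects; a malformed response raises
# pydantic.ValidationError instead of propagating empty defaults downstream.
validated = ExtractedEdges(**llm_response)
facts_after = [edge.fact for edge in validated.edges]
```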
@@ -68,6 +68,10 @@ def edge(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
+<FACT TYPES>
+{context['edge_types']}
+</FACT TYPES>
+
 <PREVIOUS_MESSAGES>
 {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
 </PREVIOUS_MESSAGES>
@@ -84,10 +88,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
 {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
 </REFERENCE_TIME>

-<FACT TYPES>
-{context['edge_types']}
-</FACT TYPES>
-
 # TASK
 Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
 Only extract facts that:
@@ -75,6 +75,10 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
 Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""

     user_prompt = f"""
+<ENTITY TYPES>
+{context['entity_types']}
+</ENTITY TYPES>
+
 <PREVIOUS MESSAGES>
 {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
 </PREVIOUS MESSAGES>
@@ -83,10 +87,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
 {context['episode_content']}
 </CURRENT MESSAGE>

-<ENTITY TYPES>
-{context['entity_types']}
-</ENTITY TYPES>
-
 Instructions:

 You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
@@ -124,15 +124,16 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
 Your primary task is to extract and classify relevant entities from JSON files"""

     user_prompt = f"""
+<ENTITY TYPES>
+{context['entity_types']}
+</ENTITY TYPES>
+
 <SOURCE DESCRIPTION>:
 {context['source_description']}
 </SOURCE DESCRIPTION>
 <JSON>
 {context['episode_content']}
 </JSON>
-<ENTITY TYPES>
-{context['entity_types']}
-</ENTITY TYPES>

 {context['custom_prompt']}

@@ -155,13 +156,14 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
 Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""

     user_prompt = f"""
-<TEXT>
-{context['episode_content']}
-</TEXT>
 <ENTITY TYPES>
 {context['entity_types']}
 </ENTITY TYPES>

+<TEXT>
+{context['episode_content']}
+</TEXT>
+
 Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
 For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
 Indicate the classified entity type by providing its entity_type_id.
@@ -34,7 +34,7 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
+from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
 from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
@@ -161,9 +161,9 @@ async def extract_edges(
         response_model=ExtractedEdges,
         max_tokens=extract_edges_max_tokens,
     )
-    edges_data = llm_response.get('edges', [])
+    edges_data = ExtractedEdges(**llm_response).edges

-    context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
+    context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]

     reflexion_iterations += 1
     if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -193,20 +193,20 @@ async def extract_edges(
     edges = []
     for edge_data in edges_data:
         # Validate Edge Date information
-        valid_at = edge_data.get('valid_at', None)
-        invalid_at = edge_data.get('invalid_at', None)
+        valid_at = edge_data.valid_at
+        invalid_at = edge_data.invalid_at
         valid_at_datetime = None
         invalid_at_datetime = None

-        source_node_idx = edge_data.get('source_entity_id', -1)
-        target_node_idx = edge_data.get('target_entity_id', -1)
+        source_node_idx = edge_data.source_entity_id
+        target_node_idx = edge_data.target_entity_id
         if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
             logger.warning(
-                f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
+                f'WARNING: source or target node not filled {edge_data.relation_type}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
             )
             continue
         source_node_uuid = nodes[source_node_idx].uuid
-        target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
+        target_node_uuid = nodes[edge_data.target_entity_id].uuid

         if valid_at:
             try:
@@ -226,9 +226,9 @@ async def extract_edges(
         edge = EntityEdge(
             source_node_uuid=source_node_uuid,
             target_node_uuid=target_node_uuid,
-            name=edge_data.get('relation_type', ''),
+            name=edge_data.relation_type,
             group_id=group_id,
-            fact=edge_data.get('fact', ''),
+            fact=edge_data.fact,
             episodes=[episode.uuid],
             created_at=utc_now(),
             valid_at=valid_at_datetime,
@@ -422,10 +422,10 @@ async def resolve_extracted_edge(
         response_model=EdgeDuplicate,
         model_size=ModelSize.small,
     )
+    response_object = EdgeDuplicate(**llm_response)
+    duplicate_facts = response_object.duplicate_facts

-    duplicate_fact_ids: list[int] = list(
-        filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
-    )
+    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]

     resolved_edge = extracted_edge
     for duplicate_fact_id in duplicate_fact_ids:
@@ -435,11 +435,13 @@ async def resolve_extracted_edge(
     if duplicate_fact_ids and episode is not None:
         resolved_edge.episodes.append(episode.uuid)

-    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+    contradicted_facts: list[int] = response_object.contradicted_facts

-    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+    invalidation_candidates: list[EntityEdge] = [
+        existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
+    ]

-    fact_type: str = str(llm_response.get('fact_type'))
+    fact_type: str = response_object.fact_type
     if fact_type.upper() != 'DEFAULT' and edge_types is not None:
         resolved_edge.name = fact_type

@@ -494,39 +496,6 @@ async def resolve_extracted_edge(
     return resolved_edge, invalidated_edges, duplicate_edges


-async def dedupe_edge_list(
-    llm_client: LLMClient,
-    edges: list[EntityEdge],
-) -> list[EntityEdge]:
-    start = time()
-
-    # Create edge map
-    edge_map = {}
-    for edge in edges:
-        edge_map[edge.uuid] = edge
-
-    # Prepare context for LLM
-    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
-    )
-    unique_edges_data = llm_response.get('unique_facts', [])
-
-    end = time()
-    logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
-
-    # Get full edge data
-    unique_edges = []
-    for edge_data in unique_edges_data:
-        uuid = edge_data['uuid']
-        edge = edge_map[uuid]
-        edge.fact = edge_data['fact']
-        unique_edges.append(edge)
-
-    return unique_edges
-
-
 async def filter_existing_duplicate_of_edges(
     driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
 ) -> list[tuple[EntityNode, EntityNode]]:
@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeResolutions
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
     ExtractedEntities,
     ExtractedEntity,
@@ -125,10 +125,9 @@ async def extract_nodes(
         prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
     )

-    extracted_entities: list[ExtractedEntity] = [
-        ExtractedEntity(**entity_types_context)
-        for entity_types_context in llm_response.get('extracted_entities', [])
-    ]
+    response_object = ExtractedEntities(**llm_response)
+
+    extracted_entities: list[ExtractedEntity] = response_object.extracted_entities

     reflexion_iterations += 1
     if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -254,14 +253,14 @@ async def resolve_extracted_nodes(
         response_model=NodeResolutions,
     )

-    node_resolutions: list = llm_response.get('entity_resolutions', [])
+    node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions

     resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
     node_duplicates: list[tuple[EntityNode, EntityNode]] = []
     for resolution in node_resolutions:
-        resolution_id: int = resolution.get('id', -1)
-        duplicate_idx: int = resolution.get('duplicate_idx', -1)
+        resolution_id: int = resolution.id
+        duplicate_idx: int = resolution.duplicate_idx

         extracted_node = extracted_nodes[resolution_id]

@@ -276,7 +275,7 @@ async def resolve_extracted_nodes(
         resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid

-        duplicates: list[int] = resolution.get('duplicates', [])
+        duplicates: list[int] = resolution.duplicates
         if duplicate_idx not in duplicates and duplicate_idx > -1:
             duplicates.append(duplicate_idx)
         for idx in duplicates:
@@ -369,7 +368,9 @@ async def extract_attributes_from_node(
         model_size=ModelSize.small,
     )

-    node.summary = llm_response.get('summary', node.summary)
+    entity_attributes_model(**llm_response)
+
+    node.summary = llm_response.get('summary', '')
     node_attributes = {key: value for key, value in llm_response.items()}

     with suppress(KeyError):
@@ -378,48 +379,3 @@ async def extract_attributes_from_node(
         node.attributes.update(node_attributes)

     return node
-
-
-async def dedupe_node_list(
-    llm_client: LLMClient,
-    nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
-
-    # build node map
-    node_map = {}
-    for node in nodes:
-        node_map[node.uuid] = node
-
-    # Prepare context for LLM
-    nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
-
-    context = {
-        'nodes': nodes_context,
-    }
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.node_list(context)
-    )
-
-    nodes_data = llm_response.get('nodes', [])
-
-    end = time()
-    logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
-
-    # Get full node data
-    unique_nodes = []
-    uuid_map: dict[str, str] = {}
-    for node_data in nodes_data:
-        node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
-        if node_instance is None:
-            logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
-            continue
-        node_instance.summary = node_data['summary']
-        unique_nodes.append(node_instance)
-
-        for uuid in node_data['uuids'][1:]:
-            uuid_value = node_map[node_data['uuids'][0]].uuid
-            uuid_map[uuid] = uuid_value
-
-    return unique_nodes, uuid_map