validate pydantic objects (#783)
* validate pydantic objects
* unused imports
* linter
This commit is contained in:
parent
78731316ce
commit
19bddb5528
4 changed files with 46 additions and 119 deletions
|
|
@ -68,6 +68,10 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
Message(
|
Message(
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
|
<FACT TYPES>
|
||||||
|
{context['edge_types']}
|
||||||
|
</FACT TYPES>
|
||||||
|
|
||||||
<PREVIOUS_MESSAGES>
|
<PREVIOUS_MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
||||||
</PREVIOUS_MESSAGES>
|
</PREVIOUS_MESSAGES>
|
||||||
|
|
@ -84,10 +88,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
||||||
</REFERENCE_TIME>
|
</REFERENCE_TIME>
|
||||||
|
|
||||||
<FACT TYPES>
|
|
||||||
{context['edge_types']}
|
|
||||||
</FACT TYPES>
|
|
||||||
|
|
||||||
# TASK
|
# TASK
|
||||||
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
||||||
Only extract facts that:
|
Only extract facts that:
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,10 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
|
||||||
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
|
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
|
<ENTITY TYPES>
|
||||||
|
{context['entity_types']}
|
||||||
|
</ENTITY TYPES>
|
||||||
|
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
||||||
</PREVIOUS MESSAGES>
|
</PREVIOUS MESSAGES>
|
||||||
|
|
@ -83,10 +87,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
</CURRENT MESSAGE>
|
</CURRENT MESSAGE>
|
||||||
|
|
||||||
<ENTITY TYPES>
|
|
||||||
{context['entity_types']}
|
|
||||||
</ENTITY TYPES>
|
|
||||||
|
|
||||||
Instructions:
|
Instructions:
|
||||||
|
|
||||||
You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
|
You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
|
||||||
|
|
@ -124,15 +124,16 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
|
||||||
Your primary task is to extract and classify relevant entities from JSON files"""
|
Your primary task is to extract and classify relevant entities from JSON files"""
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
|
<ENTITY TYPES>
|
||||||
|
{context['entity_types']}
|
||||||
|
</ENTITY TYPES>
|
||||||
|
|
||||||
<SOURCE DESCRIPTION>:
|
<SOURCE DESCRIPTION>:
|
||||||
{context['source_description']}
|
{context['source_description']}
|
||||||
</SOURCE DESCRIPTION>
|
</SOURCE DESCRIPTION>
|
||||||
<JSON>
|
<JSON>
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
</JSON>
|
</JSON>
|
||||||
<ENTITY TYPES>
|
|
||||||
{context['entity_types']}
|
|
||||||
</ENTITY TYPES>
|
|
||||||
|
|
||||||
{context['custom_prompt']}
|
{context['custom_prompt']}
|
||||||
|
|
||||||
|
|
@ -155,13 +156,14 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
|
||||||
Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
|
Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
<TEXT>
|
|
||||||
{context['episode_content']}
|
|
||||||
</TEXT>
|
|
||||||
<ENTITY TYPES>
|
<ENTITY TYPES>
|
||||||
{context['entity_types']}
|
{context['entity_types']}
|
||||||
</ENTITY TYPES>
|
</ENTITY TYPES>
|
||||||
|
|
||||||
|
<TEXT>
|
||||||
|
{context['episode_content']}
|
||||||
|
</TEXT>
|
||||||
|
|
||||||
Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
|
Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
|
||||||
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
|
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
|
||||||
Indicate the classified entity type by providing its entity_type_id.
|
Indicate the classified entity type by providing its entity_type_id.
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.llm_client.config import ModelSize
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
||||||
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
||||||
|
|
@ -161,9 +161,9 @@ async def extract_edges(
|
||||||
response_model=ExtractedEdges,
|
response_model=ExtractedEdges,
|
||||||
max_tokens=extract_edges_max_tokens,
|
max_tokens=extract_edges_max_tokens,
|
||||||
)
|
)
|
||||||
edges_data = llm_response.get('edges', [])
|
edges_data = ExtractedEdges(**llm_response).edges
|
||||||
|
|
||||||
context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
|
context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]
|
||||||
|
|
||||||
reflexion_iterations += 1
|
reflexion_iterations += 1
|
||||||
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
||||||
|
|
@ -193,20 +193,20 @@ async def extract_edges(
|
||||||
edges = []
|
edges = []
|
||||||
for edge_data in edges_data:
|
for edge_data in edges_data:
|
||||||
# Validate Edge Date information
|
# Validate Edge Date information
|
||||||
valid_at = edge_data.get('valid_at', None)
|
valid_at = edge_data.valid_at
|
||||||
invalid_at = edge_data.get('invalid_at', None)
|
invalid_at = edge_data.invalid_at
|
||||||
valid_at_datetime = None
|
valid_at_datetime = None
|
||||||
invalid_at_datetime = None
|
invalid_at_datetime = None
|
||||||
|
|
||||||
source_node_idx = edge_data.get('source_entity_id', -1)
|
source_node_idx = edge_data.source_entity_id
|
||||||
target_node_idx = edge_data.get('target_entity_id', -1)
|
target_node_idx = edge_data.target_entity_id
|
||||||
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
|
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
|
f'WARNING: source or target node not filled {edge_data.relation_type}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
source_node_uuid = nodes[source_node_idx].uuid
|
source_node_uuid = nodes[source_node_idx].uuid
|
||||||
target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
|
target_node_uuid = nodes[edge_data.target_entity_id].uuid
|
||||||
|
|
||||||
if valid_at:
|
if valid_at:
|
||||||
try:
|
try:
|
||||||
|
|
@ -226,9 +226,9 @@ async def extract_edges(
|
||||||
edge = EntityEdge(
|
edge = EntityEdge(
|
||||||
source_node_uuid=source_node_uuid,
|
source_node_uuid=source_node_uuid,
|
||||||
target_node_uuid=target_node_uuid,
|
target_node_uuid=target_node_uuid,
|
||||||
name=edge_data.get('relation_type', ''),
|
name=edge_data.relation_type,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
fact=edge_data.get('fact', ''),
|
fact=edge_data.fact,
|
||||||
episodes=[episode.uuid],
|
episodes=[episode.uuid],
|
||||||
created_at=utc_now(),
|
created_at=utc_now(),
|
||||||
valid_at=valid_at_datetime,
|
valid_at=valid_at_datetime,
|
||||||
|
|
@ -422,10 +422,10 @@ async def resolve_extracted_edge(
|
||||||
response_model=EdgeDuplicate,
|
response_model=EdgeDuplicate,
|
||||||
model_size=ModelSize.small,
|
model_size=ModelSize.small,
|
||||||
)
|
)
|
||||||
|
response_object = EdgeDuplicate(**llm_response)
|
||||||
|
duplicate_facts = response_object.duplicate_facts
|
||||||
|
|
||||||
duplicate_fact_ids: list[int] = list(
|
duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
|
||||||
filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved_edge = extracted_edge
|
resolved_edge = extracted_edge
|
||||||
for duplicate_fact_id in duplicate_fact_ids:
|
for duplicate_fact_id in duplicate_fact_ids:
|
||||||
|
|
@ -435,11 +435,13 @@ async def resolve_extracted_edge(
|
||||||
if duplicate_fact_ids and episode is not None:
|
if duplicate_fact_ids and episode is not None:
|
||||||
resolved_edge.episodes.append(episode.uuid)
|
resolved_edge.episodes.append(episode.uuid)
|
||||||
|
|
||||||
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
|
contradicted_facts: list[int] = response_object.contradicted_facts
|
||||||
|
|
||||||
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
|
invalidation_candidates: list[EntityEdge] = [
|
||||||
|
existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
|
||||||
|
]
|
||||||
|
|
||||||
fact_type: str = str(llm_response.get('fact_type'))
|
fact_type: str = response_object.fact_type
|
||||||
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
|
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
|
||||||
resolved_edge.name = fact_type
|
resolved_edge.name = fact_type
|
||||||
|
|
||||||
|
|
@ -494,39 +496,6 @@ async def resolve_extracted_edge(
|
||||||
return resolved_edge, invalidated_edges, duplicate_edges
|
return resolved_edge, invalidated_edges, duplicate_edges
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_edge_list(
|
|
||||||
llm_client: LLMClient,
|
|
||||||
edges: list[EntityEdge],
|
|
||||||
) -> list[EntityEdge]:
|
|
||||||
start = time()
|
|
||||||
|
|
||||||
# Create edge map
|
|
||||||
edge_map = {}
|
|
||||||
for edge in edges:
|
|
||||||
edge_map[edge.uuid] = edge
|
|
||||||
|
|
||||||
# Prepare context for LLM
|
|
||||||
context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
|
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
|
||||||
prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
|
|
||||||
)
|
|
||||||
unique_edges_data = llm_response.get('unique_facts', [])
|
|
||||||
|
|
||||||
end = time()
|
|
||||||
logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
|
||||||
|
|
||||||
# Get full edge data
|
|
||||||
unique_edges = []
|
|
||||||
for edge_data in unique_edges_data:
|
|
||||||
uuid = edge_data['uuid']
|
|
||||||
edge = edge_map[uuid]
|
|
||||||
edge.fact = edge_data['fact']
|
|
||||||
unique_edges.append(edge)
|
|
||||||
|
|
||||||
return unique_edges
|
|
||||||
|
|
||||||
|
|
||||||
async def filter_existing_duplicate_of_edges(
|
async def filter_existing_duplicate_of_edges(
|
||||||
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
|
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
|
||||||
) -> list[tuple[EntityNode, EntityNode]]:
|
) -> list[tuple[EntityNode, EntityNode]]:
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.llm_client.config import ModelSize
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeResolutions
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
|
||||||
from graphiti_core.prompts.extract_nodes import (
|
from graphiti_core.prompts.extract_nodes import (
|
||||||
ExtractedEntities,
|
ExtractedEntities,
|
||||||
ExtractedEntity,
|
ExtractedEntity,
|
||||||
|
|
@ -125,10 +125,9 @@ async def extract_nodes(
|
||||||
prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
|
prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
|
||||||
)
|
)
|
||||||
|
|
||||||
extracted_entities: list[ExtractedEntity] = [
|
response_object = ExtractedEntities(**llm_response)
|
||||||
ExtractedEntity(**entity_types_context)
|
|
||||||
for entity_types_context in llm_response.get('extracted_entities', [])
|
extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
|
||||||
]
|
|
||||||
|
|
||||||
reflexion_iterations += 1
|
reflexion_iterations += 1
|
||||||
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
||||||
|
|
@ -254,14 +253,14 @@ async def resolve_extracted_nodes(
|
||||||
response_model=NodeResolutions,
|
response_model=NodeResolutions,
|
||||||
)
|
)
|
||||||
|
|
||||||
node_resolutions: list = llm_response.get('entity_resolutions', [])
|
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
|
||||||
|
|
||||||
resolved_nodes: list[EntityNode] = []
|
resolved_nodes: list[EntityNode] = []
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
|
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
|
||||||
for resolution in node_resolutions:
|
for resolution in node_resolutions:
|
||||||
resolution_id: int = resolution.get('id', -1)
|
resolution_id: int = resolution.id
|
||||||
duplicate_idx: int = resolution.get('duplicate_idx', -1)
|
duplicate_idx: int = resolution.duplicate_idx
|
||||||
|
|
||||||
extracted_node = extracted_nodes[resolution_id]
|
extracted_node = extracted_nodes[resolution_id]
|
||||||
|
|
||||||
|
|
@ -276,7 +275,7 @@ async def resolve_extracted_nodes(
|
||||||
resolved_nodes.append(resolved_node)
|
resolved_nodes.append(resolved_node)
|
||||||
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||||
|
|
||||||
duplicates: list[int] = resolution.get('duplicates', [])
|
duplicates: list[int] = resolution.duplicates
|
||||||
if duplicate_idx not in duplicates and duplicate_idx > -1:
|
if duplicate_idx not in duplicates and duplicate_idx > -1:
|
||||||
duplicates.append(duplicate_idx)
|
duplicates.append(duplicate_idx)
|
||||||
for idx in duplicates:
|
for idx in duplicates:
|
||||||
|
|
@ -369,7 +368,9 @@ async def extract_attributes_from_node(
|
||||||
model_size=ModelSize.small,
|
model_size=ModelSize.small,
|
||||||
)
|
)
|
||||||
|
|
||||||
node.summary = llm_response.get('summary', node.summary)
|
entity_attributes_model(**llm_response)
|
||||||
|
|
||||||
|
node.summary = llm_response.get('summary', '')
|
||||||
node_attributes = {key: value for key, value in llm_response.items()}
|
node_attributes = {key: value for key, value in llm_response.items()}
|
||||||
|
|
||||||
with suppress(KeyError):
|
with suppress(KeyError):
|
||||||
|
|
@ -378,48 +379,3 @@ async def extract_attributes_from_node(
|
||||||
node.attributes.update(node_attributes)
|
node.attributes.update(node_attributes)
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_node_list(
|
|
||||||
llm_client: LLMClient,
|
|
||||||
nodes: list[EntityNode],
|
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
|
||||||
start = time()
|
|
||||||
|
|
||||||
# build node map
|
|
||||||
node_map = {}
|
|
||||||
for node in nodes:
|
|
||||||
node_map[node.uuid] = node
|
|
||||||
|
|
||||||
# Prepare context for LLM
|
|
||||||
nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
|
|
||||||
|
|
||||||
context = {
|
|
||||||
'nodes': nodes_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
|
||||||
prompt_library.dedupe_nodes.node_list(context)
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes_data = llm_response.get('nodes', [])
|
|
||||||
|
|
||||||
end = time()
|
|
||||||
logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
|
|
||||||
|
|
||||||
# Get full node data
|
|
||||||
unique_nodes = []
|
|
||||||
uuid_map: dict[str, str] = {}
|
|
||||||
for node_data in nodes_data:
|
|
||||||
node_instance: EntityNode | None = node_map.get(node_data['uuids'][0])
|
|
||||||
if node_instance is None:
|
|
||||||
logger.warning(f'Node {node_data["uuids"][0]} not found in node map')
|
|
||||||
continue
|
|
||||||
node_instance.summary = node_data['summary']
|
|
||||||
unique_nodes.append(node_instance)
|
|
||||||
|
|
||||||
for uuid in node_data['uuids'][1:]:
|
|
||||||
uuid_value = node_map[node_data['uuids'][0]].uuid
|
|
||||||
uuid_map[uuid] = uuid_value
|
|
||||||
|
|
||||||
return unique_nodes, uuid_map
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue