fix: Improve edge extraction entity ID validation (#968)
* fix: Improve edge extraction entity ID validation
Fixes invalid entity ID references in edge extraction that caused warnings like:
"WARNING: source or target node not filled WILL_FIND. source_node_uuid: 23 and target_node_uuid: 3"
Changes:
- Format ENTITIES list as proper JSON in prompt for better LLM parsing
- Clarify field descriptions to reference entity id from ENTITIES list
- Add explicit entity ID validation as #1 extraction rule with examples
- Improve error logging (remove PII, add entity count and valid ID range)
These changes follow patterns from extract_nodes.py and dedupe_nodes.py where
entity referencing works reliably.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
* wip
* fix: Align fact field naming and add description
- Change extraction rule to reference 'fact' instead of 'fact_text'
- Add descriptive text for fact field in Edge model
* fix: Remove ensure_ascii parameter from to_prompt_json call
Align with other to_prompt_json calls that don't use ensure_ascii
* fix: Use validated target_node_idx variable consistently
Line 190 was using the raw edge_data.target_entity_id instead of the
validated target_node_idx variable, creating an inconsistency with line 189
* fix: Improve edge extraction validation checks
- Add explicit check for empty nodes list
- Use more explicit 0 <= idx comparison instead of -1 < idx
- Prevents a nonsensical error message (e.g. "valid range: 0--1") when no entities are provided
* chore: Restore uv.lock from main branch
Previously deleted in commit 7e4464b, now restored to match main branch state
* Update uv.lock
---------
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
4a307dbf10
commit
590282524a
3 changed files with 24 additions and 15 deletions
|
|
@ -24,9 +24,16 @@ from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
|
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
|
||||||
source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
|
source_entity_id: int = Field(
|
||||||
target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
|
..., description='The id of the source entity from the ENTITIES list'
|
||||||
fact: str = Field(..., description='')
|
)
|
||||||
|
target_entity_id: int = Field(
|
||||||
|
..., description='The id of the target entity from the ENTITIES list'
|
||||||
|
)
|
||||||
|
fact: str = Field(
|
||||||
|
...,
|
||||||
|
description='A natural language description of the relationship between the entities, paraphrased from the source text',
|
||||||
|
)
|
||||||
valid_at: str | None = Field(
|
valid_at: str | None = Field(
|
||||||
None,
|
None,
|
||||||
description='The date and time when the relationship described by the edge fact became true or was established. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
|
description='The date and time when the relationship described by the edge fact became true or was established. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
|
||||||
|
|
@ -81,7 +88,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
</CURRENT_MESSAGE>
|
</CURRENT_MESSAGE>
|
||||||
|
|
||||||
<ENTITIES>
|
<ENTITIES>
|
||||||
{context['nodes']}
|
{to_prompt_json(context['nodes'], indent=2)}
|
||||||
</ENTITIES>
|
</ENTITIES>
|
||||||
|
|
||||||
<REFERENCE_TIME>
|
<REFERENCE_TIME>
|
||||||
|
|
@ -107,11 +114,12 @@ You may use information from the PREVIOUS MESSAGES only to disambiguate referenc
|
||||||
|
|
||||||
# EXTRACTION RULES
|
# EXTRACTION RULES
|
||||||
|
|
||||||
1. Only emit facts where both the subject and object match IDs in ENTITIES.
|
1. **Entity ID Validation**: `source_entity_id` and `target_entity_id` must use only the `id` values from the ENTITIES list provided above.
|
||||||
|
- **CRITICAL**: Using IDs not in the list will cause the edge to be rejected
|
||||||
2. Each fact must involve two **distinct** entities.
|
2. Each fact must involve two **distinct** entities.
|
||||||
3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
|
3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
|
||||||
4. Do not emit duplicate or semantically redundant facts.
|
4. Do not emit duplicate or semantically redundant facts.
|
||||||
5. The `fact_text` should closely paraphrase the original source sentence(s). Do not verbatim quote the original text.
|
5. The `fact` should closely paraphrase the original source sentence(s). Do not verbatim quote the original text.
|
||||||
6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
|
6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
|
||||||
7. Do **not** hallucinate or infer temporal bounds from unrelated events.
|
7. Do **not** hallucinate or infer temporal bounds from unrelated events.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -179,13 +179,20 @@ async def extract_edges(
|
||||||
|
|
||||||
source_node_idx = edge_data.source_entity_id
|
source_node_idx = edge_data.source_entity_id
|
||||||
target_node_idx = edge_data.target_entity_id
|
target_node_idx = edge_data.target_entity_id
|
||||||
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
logger.warning('No entities provided for edge extraction')
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f'WARNING: source or target node not filled {edge_data.relation_type}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
|
f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
|
||||||
|
f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
|
||||||
|
f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
source_node_uuid = nodes[source_node_idx].uuid
|
source_node_uuid = nodes[source_node_idx].uuid
|
||||||
target_node_uuid = nodes[edge_data.target_entity_id].uuid
|
target_node_uuid = nodes[target_node_idx].uuid
|
||||||
|
|
||||||
if valid_at:
|
if valid_at:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
6
uv.lock
generated
6
uv.lock
generated
|
|
@ -803,7 +803,6 @@ anthropic = [
|
||||||
]
|
]
|
||||||
dev = [
|
dev = [
|
||||||
{ name = "anthropic" },
|
{ name = "anthropic" },
|
||||||
{ name = "boto3" },
|
|
||||||
{ name = "diskcache-stubs" },
|
{ name = "diskcache-stubs" },
|
||||||
{ name = "falkordb" },
|
{ name = "falkordb" },
|
||||||
{ name = "google-genai" },
|
{ name = "google-genai" },
|
||||||
|
|
@ -812,11 +811,9 @@ dev = [
|
||||||
{ name = "jupyterlab" },
|
{ name = "jupyterlab" },
|
||||||
{ name = "kuzu" },
|
{ name = "kuzu" },
|
||||||
{ name = "langchain-anthropic" },
|
{ name = "langchain-anthropic" },
|
||||||
{ name = "langchain-aws" },
|
|
||||||
{ name = "langchain-openai" },
|
{ name = "langchain-openai" },
|
||||||
{ name = "langgraph" },
|
{ name = "langgraph" },
|
||||||
{ name = "langsmith" },
|
{ name = "langsmith" },
|
||||||
{ name = "opensearch-py" },
|
|
||||||
{ name = "pyright" },
|
{ name = "pyright" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
|
|
@ -858,7 +855,6 @@ voyageai = [
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
||||||
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
|
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
|
||||||
{ name = "boto3", marker = "extra == 'dev'", specifier = ">=1.39.16" },
|
|
||||||
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
|
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
|
||||||
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
|
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
|
||||||
{ name = "diskcache", specifier = ">=5.6.3" },
|
{ name = "diskcache", specifier = ">=5.6.3" },
|
||||||
|
|
@ -874,7 +870,6 @@ requires-dist = [
|
||||||
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
|
{ name = "kuzu", marker = "extra == 'dev'", specifier = ">=0.11.2" },
|
||||||
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
|
{ name = "kuzu", marker = "extra == 'kuzu'", specifier = ">=0.11.2" },
|
||||||
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
|
{ name = "langchain-anthropic", marker = "extra == 'dev'", specifier = ">=0.2.4" },
|
||||||
{ name = "langchain-aws", marker = "extra == 'dev'", specifier = ">=0.2.29" },
|
|
||||||
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
|
{ name = "langchain-aws", marker = "extra == 'neptune'", specifier = ">=0.2.29" },
|
||||||
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
|
{ name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.2.6" },
|
||||||
{ name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
|
{ name = "langgraph", marker = "extra == 'dev'", specifier = ">=0.2.15" },
|
||||||
|
|
@ -882,7 +877,6 @@ requires-dist = [
|
||||||
{ name = "neo4j", specifier = ">=5.26.0" },
|
{ name = "neo4j", specifier = ">=5.26.0" },
|
||||||
{ name = "numpy", specifier = ">=1.0.0" },
|
{ name = "numpy", specifier = ">=1.0.0" },
|
||||||
{ name = "openai", specifier = ">=1.91.0" },
|
{ name = "openai", specifier = ">=1.91.0" },
|
||||||
{ name = "opensearch-py", marker = "extra == 'dev'", specifier = ">=3.0.0" },
|
|
||||||
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
|
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
|
||||||
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
|
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
|
||||||
{ name = "posthog", specifier = ">=3.0.0" },
|
{ name = "posthog", specifier = ">=3.0.0" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue