feat: Add optional callback to control node summary generation (#959)
Add a `NodeSummaryFilter` callback parameter to `extract_attributes_from_nodes` and `extract_attributes_from_node`, allowing consumers to selectively skip summary regeneration for specific nodes. This enables downstream applications to implement custom logic for throttling or filtering which nodes should have summaries regenerated, reducing unnecessary LLM calls and token costs.

Key changes:
- Add `NodeSummaryFilter` type alias: `Callable[[EntityNode], Awaitable[bool]]`
- Update `extract_attributes_from_nodes` with optional `should_summarize_node` parameter
- Update `extract_attributes_from_node` with conditional summary generation logic
- Add 5 comprehensive test cases covering callback functionality
- Maintain full backwards compatibility (default `None` = all summaries generated)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
4a9bcd5b10
commit
644aa2b967
2 changed files with 201 additions and 7 deletions
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -55,6 +56,8 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
|
||||
|
||||
|
||||
async def extract_nodes_reflexion(
|
||||
llm_client: LLMClient,
|
||||
|
|
@ -402,6 +405,7 @@ async def extract_attributes_from_nodes(
|
|||
episode: EpisodicNode | None = None,
|
||||
previous_episodes: list[EpisodicNode] | None = None,
|
||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||
should_summarize_node: NodeSummaryFilter | None = None,
|
||||
) -> list[EntityNode]:
|
||||
llm_client = clients.llm_client
|
||||
embedder = clients.embedder
|
||||
|
|
@ -418,6 +422,7 @@ async def extract_attributes_from_nodes(
|
|||
else None
|
||||
),
|
||||
clients.ensure_ascii,
|
||||
should_summarize_node,
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
|
|
@ -435,6 +440,7 @@ async def extract_attributes_from_node(
|
|||
previous_episodes: list[EpisodicNode] | None = None,
|
||||
entity_type: type[BaseModel] | None = None,
|
||||
ensure_ascii: bool = False,
|
||||
should_summarize_node: NodeSummaryFilter | None = None,
|
||||
) -> EntityNode:
|
||||
node_context: dict[str, Any] = {
|
||||
'name': node.name,
|
||||
|
|
@ -477,16 +483,22 @@ async def extract_attributes_from_node(
|
|||
else {}
|
||||
)
|
||||
|
||||
summary_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||
response_model=EntitySummary,
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
# Determine if summary should be generated
|
||||
generate_summary = True
|
||||
if should_summarize_node is not None:
|
||||
generate_summary = await should_summarize_node(node)
|
||||
|
||||
# Conditionally generate summary
|
||||
if generate_summary:
|
||||
summary_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||
response_model=EntitySummary,
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
node.summary = summary_response.get('summary', '')
|
||||
|
||||
if has_entity_attributes and entity_type is not None:
|
||||
entity_type(**llm_response)
|
||||
|
||||
node.summary = summary_response.get('summary', '')
|
||||
node_attributes = {key: value for key, value in llm_response.items()}
|
||||
|
||||
node.attributes.update(node_attributes)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ from graphiti_core.utils.maintenance.dedup_helpers import (
|
|||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
_collect_candidate_nodes,
|
||||
_resolve_with_llm,
|
||||
extract_attributes_from_node,
|
||||
extract_attributes_from_nodes,
|
||||
resolve_extracted_nodes,
|
||||
)
|
||||
|
||||
|
|
@ -477,3 +479,183 @@ async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monk
|
|||
assert state.resolved_nodes[0] == extracted
|
||||
assert state.uuid_map[extracted.uuid] == extracted.uuid
|
||||
assert state.duplicate_pairs == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
    """Default behavior: with no filter callback, the summary is always regenerated."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    entity = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')

    result = await extract_attributes_from_node(
        mock_llm,
        entity,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=None,  # no filter supplied
    )

    # The LLM was consulted exactly once and its summary was applied to the node.
    assert result.summary == 'Generated summary'
    assert mock_llm.generate_response.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
    """A filter returning False must leave the existing summary untouched."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'This should not be used', 'attributes': {}}
    )

    entity = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')

    async def never_summarize(_node: EntityNode) -> bool:
        # Veto summary regeneration for every node.
        return False

    result = await extract_attributes_from_node(
        mock_llm,
        entity,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=never_summarize,
    )

    # The old summary is preserved and the LLM was never invoked.
    assert result.summary == 'Old summary'
    assert mock_llm.generate_response.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
    """A filter returning True must trigger a fresh summary from the LLM."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'New generated summary', 'attributes': {}}
    )

    entity = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')

    async def always_summarize(_node: EntityNode) -> bool:
        # Approve summary regeneration for every node.
        return True

    result = await extract_attributes_from_node(
        mock_llm,
        entity,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=always_summarize,
    )

    # The summary was replaced and exactly one LLM call was made.
    assert result.summary == 'New generated summary'
    assert mock_llm.generate_response.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
    """A filter can decide per node: skip 'User' nodes, summarize everything else."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
    topic_node = EntityNode(
        name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
    )
    ep = _make_episode()

    async def summarize_unless_user(candidate: EntityNode) -> bool:
        # Skip regeneration only for nodes carrying the 'User' label.
        return 'User' not in candidate.labels

    results = {}
    for target in (user_node, topic_node):
        results[target.name] = await extract_attributes_from_node(
            mock_llm,
            target,
            episode=ep,
            previous_episodes=[],
            entity_type=None,
            ensure_ascii=False,
            should_summarize_node=summarize_unless_user,
        )

    # The user node kept its old summary; only the topic node cost an LLM call.
    assert results['User'].summary == 'Old'
    assert results['Topic'].summary == 'Generated summary'
    assert mock_llm.generate_response.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
    """The batch entry point forwards the filter callback to every node it processes."""
    clients, _ = _make_clients()
    clients.llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New summary', 'attributes': {}}
    )
    clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
    clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
    node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')

    seen_names: list[str] = []

    async def record_and_filter(candidate: EntityNode) -> bool:
        # Remember every node the callback was invoked for, then filter 'User' nodes.
        seen_names.append(candidate.name)
        return 'User' not in candidate.labels

    results = await extract_attributes_from_nodes(
        clients,
        [node1, node2],
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
        should_summarize_node=record_and_filter,
    )

    # Both nodes were offered to the callback, each exactly once.
    assert sorted(seen_names) == ['Node1', 'Node2']

    by_name = {n.name: n for n in results}
    # The 'User' node kept its old summary; the 'Topic' node was re-summarized.
    assert by_name['Node1'].summary == 'Old1'
    assert by_name['Node2'].summary == 'New summary'
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue