feat: Add optional callback to control node summary generation (#959)

Add NodeSummaryFilter callback parameter to extract_attributes_from_nodes
and extract_attributes_from_node functions, allowing consumers to
selectively skip summary regeneration for specific nodes.

This enables downstream applications to implement custom logic for
throttling or filtering which nodes should have summaries regenerated,
reducing unnecessary LLM calls and token costs.

Key changes:
- Add NodeSummaryFilter type alias: Callable[[EntityNode], Awaitable[bool]]
- Update extract_attributes_from_nodes with optional should_summarize_node parameter
- Update extract_attributes_from_node with conditional summary generation logic
- Add 5 comprehensive test cases covering callback functionality
- Maintain full backwards compatibility (default None = all summaries generated)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Daniel Chalef 2025-10-01 16:17:48 -07:00 committed by GitHub
parent 4a9bcd5b10
commit 644aa2b967
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 201 additions and 7 deletions

View file

@ -15,6 +15,7 @@ limitations under the License.
"""
import logging
from collections.abc import Awaitable, Callable
from time import time
from typing import Any
@ -55,6 +56,8 @@ from graphiti_core.utils.maintenance.edge_operations import (
logger = logging.getLogger(__name__)
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
async def extract_nodes_reflexion(
llm_client: LLMClient,
@ -402,6 +405,7 @@ async def extract_attributes_from_nodes(
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, type[BaseModel]] | None = None,
should_summarize_node: NodeSummaryFilter | None = None,
) -> list[EntityNode]:
llm_client = clients.llm_client
embedder = clients.embedder
@ -418,6 +422,7 @@ async def extract_attributes_from_nodes(
else None
),
clients.ensure_ascii,
should_summarize_node,
)
for node in nodes
]
@ -435,6 +440,7 @@ async def extract_attributes_from_node(
previous_episodes: list[EpisodicNode] | None = None,
entity_type: type[BaseModel] | None = None,
ensure_ascii: bool = False,
should_summarize_node: NodeSummaryFilter | None = None,
) -> EntityNode:
node_context: dict[str, Any] = {
'name': node.name,
@ -477,16 +483,22 @@ async def extract_attributes_from_node(
else {}
)
summary_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_summary(summary_context),
response_model=EntitySummary,
model_size=ModelSize.small,
)
# Determine if summary should be generated
generate_summary = True
if should_summarize_node is not None:
generate_summary = await should_summarize_node(node)
# Conditionally generate summary
if generate_summary:
summary_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_summary(summary_context),
response_model=EntitySummary,
model_size=ModelSize.small,
)
node.summary = summary_response.get('summary', '')
if has_entity_attributes and entity_type is not None:
entity_type(**llm_response)
node.summary = summary_response.get('summary', '')
node_attributes = {key: value for key, value in llm_response.items()}
node.attributes.update(node_attributes)

View file

@ -27,6 +27,8 @@ from graphiti_core.utils.maintenance.dedup_helpers import (
from graphiti_core.utils.maintenance.node_operations import (
_collect_candidate_nodes,
_resolve_with_llm,
extract_attributes_from_node,
extract_attributes_from_nodes,
resolve_extracted_nodes,
)
@ -477,3 +479,183 @@ async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monk
assert state.resolved_nodes[0] == extracted
assert state.uuid_map[extracted.uuid] == extracted.uuid
assert state.duplicate_pairs == []
@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
"""Test that summary is generated when no callback is provided (default behavior)."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'Generated summary', 'attributes': {}}
)
node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
episode = _make_episode()
result = await extract_attributes_from_node(
llm_client,
node,
episode=episode,
previous_episodes=[],
entity_type=None,
ensure_ascii=False,
should_summarize_node=None, # No callback provided
)
# Summary should be generated
assert result.summary == 'Generated summary'
# LLM should have been called for summary
assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
"""Test that summary is NOT regenerated when callback returns False."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'This should not be used', 'attributes': {}}
)
node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
episode = _make_episode()
# Callback that always returns False (skip summary generation)
async def skip_summary_filter(node: EntityNode) -> bool:
return False
result = await extract_attributes_from_node(
llm_client,
node,
episode=episode,
previous_episodes=[],
entity_type=None,
ensure_ascii=False,
should_summarize_node=skip_summary_filter,
)
# Summary should remain unchanged
assert result.summary == 'Old summary'
# LLM should NOT have been called for summary
assert llm_client.generate_response.call_count == 0
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
"""Test that summary is regenerated when callback returns True."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'New generated summary', 'attributes': {}}
)
node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
episode = _make_episode()
# Callback that always returns True (generate summary)
async def generate_summary_filter(node: EntityNode) -> bool:
return True
result = await extract_attributes_from_node(
llm_client,
node,
episode=episode,
previous_episodes=[],
entity_type=None,
ensure_ascii=False,
should_summarize_node=generate_summary_filter,
)
# Summary should be updated
assert result.summary == 'New generated summary'
# LLM should have been called for summary
assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
"""Test callback that selectively skips summaries based on node properties."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'Generated summary', 'attributes': {}}
)
user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
topic_node = EntityNode(
name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
)
episode = _make_episode()
# Callback that skips User nodes but generates for others
async def selective_filter(node: EntityNode) -> bool:
return 'User' not in node.labels
result_user = await extract_attributes_from_node(
llm_client,
user_node,
episode=episode,
previous_episodes=[],
entity_type=None,
ensure_ascii=False,
should_summarize_node=selective_filter,
)
result_topic = await extract_attributes_from_node(
llm_client,
topic_node,
episode=episode,
previous_episodes=[],
entity_type=None,
ensure_ascii=False,
should_summarize_node=selective_filter,
)
# User summary should remain unchanged
assert result_user.summary == 'Old'
# Topic summary should be generated
assert result_topic.summary == 'Generated summary'
# LLM should have been called only once (for topic)
assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
"""Test that callback is properly passed through extract_attributes_from_nodes."""
clients, _ = _make_clients()
clients.llm_client.generate_response = AsyncMock(
return_value={'summary': 'New summary', 'attributes': {}}
)
clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')
episode = _make_episode()
call_tracker = []
# Callback that tracks which nodes it's called with
async def tracking_filter(node: EntityNode) -> bool:
call_tracker.append(node.name)
return 'User' not in node.labels
results = await extract_attributes_from_nodes(
clients,
[node1, node2],
episode=episode,
previous_episodes=[],
entity_types=None,
should_summarize_node=tracking_filter,
)
# Callback should have been called for both nodes
assert len(call_tracker) == 2
assert 'Node1' in call_tracker
assert 'Node2' in call_tracker
# Node1 (User) should keep old summary, Node2 (Topic) should get new summary
node1_result = next(n for n in results if n.name == 'Node1')
node2_result = next(n for n in results if n.name == 'Node2')
assert node1_result.summary == 'Old1'
assert node2_result.summary == 'New summary'