diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 2d87d31b..24cdc583 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -15,6 +15,7 @@ limitations under the License. """ import logging +from collections.abc import Awaitable, Callable from time import time from typing import Any @@ -55,6 +56,8 @@ from graphiti_core.utils.maintenance.edge_operations import ( logger = logging.getLogger(__name__) +NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]] + async def extract_nodes_reflexion( llm_client: LLMClient, @@ -402,6 +405,7 @@ async def extract_attributes_from_nodes( episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, entity_types: dict[str, type[BaseModel]] | None = None, + should_summarize_node: NodeSummaryFilter | None = None, ) -> list[EntityNode]: llm_client = clients.llm_client embedder = clients.embedder @@ -418,6 +422,7 @@ async def extract_attributes_from_nodes( else None ), clients.ensure_ascii, + should_summarize_node, ) for node in nodes ] @@ -435,6 +440,7 @@ async def extract_attributes_from_node( previous_episodes: list[EpisodicNode] | None = None, entity_type: type[BaseModel] | None = None, ensure_ascii: bool = False, + should_summarize_node: NodeSummaryFilter | None = None, ) -> EntityNode: node_context: dict[str, Any] = { 'name': node.name, @@ -477,16 +483,22 @@ async def extract_attributes_from_node( else {} ) - summary_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_summary(summary_context), - response_model=EntitySummary, - model_size=ModelSize.small, - ) + # Determine if summary should be generated + generate_summary = True + if should_summarize_node is not None: + generate_summary = await should_summarize_node(node) + + # Conditionally generate summary + if generate_summary: + summary_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_summary(summary_context), + response_model=EntitySummary, + model_size=ModelSize.small, + ) + node.summary = summary_response.get('summary', '') if has_entity_attributes and entity_type is not None: entity_type(**llm_response) - - node.summary = summary_response.get('summary', '') node_attributes = {key: value for key, value in llm_response.items()} node.attributes.update(node_attributes) diff --git a/tests/utils/maintenance/test_node_operations.py b/tests/utils/maintenance/test_node_operations.py index 0bbae6b3..c144e1d2 100644 --- a/tests/utils/maintenance/test_node_operations.py +++ b/tests/utils/maintenance/test_node_operations.py @@ -27,6 +27,8 @@ from graphiti_core.utils.maintenance.dedup_helpers import ( from graphiti_core.utils.maintenance.node_operations import ( _collect_candidate_nodes, _resolve_with_llm, + extract_attributes_from_node, + extract_attributes_from_nodes, resolve_extracted_nodes, ) @@ -477,3 +479,183 @@ async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monk assert state.resolved_nodes[0] == extracted assert state.uuid_map[extracted.uuid] == extracted.uuid assert state.duplicate_pairs == [] + + +@pytest.mark.asyncio +async def test_extract_attributes_without_callback_generates_summary(): + """Test that summary is generated when no callback is provided (default behavior).""" + llm_client = MagicMock() + llm_client.generate_response = AsyncMock( + return_value={'summary': 'Generated summary', 'attributes': {}} + ) + + node = 
EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary') + episode = _make_episode() + + result = await extract_attributes_from_node( + llm_client, + node, + episode=episode, + previous_episodes=[], + entity_type=None, + ensure_ascii=False, + should_summarize_node=None, # No callback provided + ) + + # Summary should be generated + assert result.summary == 'Generated summary' + # LLM should have been called for summary + assert llm_client.generate_response.call_count == 1 + + +@pytest.mark.asyncio +async def test_extract_attributes_with_callback_skip_summary(): + """Test that summary is NOT regenerated when callback returns False.""" + llm_client = MagicMock() + llm_client.generate_response = AsyncMock( + return_value={'summary': 'This should not be used', 'attributes': {}} + ) + + node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary') + episode = _make_episode() + + # Callback that always returns False (skip summary generation) + async def skip_summary_filter(node: EntityNode) -> bool: + return False + + result = await extract_attributes_from_node( + llm_client, + node, + episode=episode, + previous_episodes=[], + entity_type=None, + ensure_ascii=False, + should_summarize_node=skip_summary_filter, + ) + + # Summary should remain unchanged + assert result.summary == 'Old summary' + # LLM should NOT have been called for summary + assert llm_client.generate_response.call_count == 0 + + +@pytest.mark.asyncio +async def test_extract_attributes_with_callback_generate_summary(): + """Test that summary is regenerated when callback returns True.""" + llm_client = MagicMock() + llm_client.generate_response = AsyncMock( + return_value={'summary': 'New generated summary', 'attributes': {}} + ) + + node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary') + episode = _make_episode() + + # Callback that always returns True (generate summary) + async def generate_summary_filter(node: EntityNode) -> bool: + return True + + result = await extract_attributes_from_node( + llm_client, + node, + episode=episode, + previous_episodes=[], + entity_type=None, + ensure_ascii=False, + should_summarize_node=generate_summary_filter, + ) + + # Summary should be updated + assert result.summary == 'New generated summary' + # LLM should have been called for summary + assert llm_client.generate_response.call_count == 1 + + +@pytest.mark.asyncio +async def test_extract_attributes_with_selective_callback(): + """Test callback that selectively skips summaries based on node properties.""" + llm_client = MagicMock() + llm_client.generate_response = AsyncMock( + return_value={'summary': 'Generated summary', 'attributes': {}} + ) + + user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old') + topic_node = EntityNode( + name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old' + ) + + episode = _make_episode() + + # Callback that skips User nodes but generates for others + async def selective_filter(node: EntityNode) -> bool: + return 'User' not in node.labels + + result_user = await extract_attributes_from_node( + llm_client, + user_node, + episode=episode, + previous_episodes=[], + entity_type=None, + ensure_ascii=False, + should_summarize_node=selective_filter, + ) + + result_topic = await extract_attributes_from_node( + llm_client, + topic_node, + episode=episode, + previous_episodes=[], + entity_type=None, + ensure_ascii=False, + 
should_summarize_node=selective_filter, + ) + + # User summary should remain unchanged + assert result_user.summary == 'Old' + # Topic summary should be generated + assert result_topic.summary == 'Generated summary' + # LLM should have been called only once (for topic) + assert llm_client.generate_response.call_count == 1 + + +@pytest.mark.asyncio +async def test_extract_attributes_from_nodes_with_callback(): + """Test that callback is properly passed through extract_attributes_from_nodes.""" + clients, _ = _make_clients() + clients.llm_client.generate_response = AsyncMock( + return_value={'summary': 'New summary', 'attributes': {}} + ) + clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3]) + clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + + node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1') + node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2') + + episode = _make_episode() + + call_tracker = [] + + # Callback that tracks which nodes it's called with + async def tracking_filter(node: EntityNode) -> bool: + call_tracker.append(node.name) + return 'User' not in node.labels + + results = await extract_attributes_from_nodes( + clients, + [node1, node2], + episode=episode, + previous_episodes=[], + entity_types=None, + should_summarize_node=tracking_filter, + ) + + # Callback should have been called for both nodes + assert len(call_tracker) == 2 + assert 'Node1' in call_tracker + assert 'Node2' in call_tracker + + # Node1 (User) should keep old summary, Node2 (Topic) should get new summary + node1_result = next(n for n in results if n.name == 'Node1') + node2_result = next(n for n in results if n.name == 'Node2') + + assert node1_result.summary == 'Old1' + assert node2_result.summary == 'New summary'
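
A minimal usage sketch of the `should_summarize_node` hook introduced above. It assumes the caller already has a `GraphitiClients` instance, an `EpisodicNode`, and a list of resolved `EntityNode`s from the existing ingestion flow; `clients`, `episode`, `previous_episodes`, and `enrich_nodes` below are illustrative names, not part of this change. Only the new keyword argument and the `NodeSummaryFilter` signature (`Callable[[EntityNode], Awaitable[bool]]`) come from this diff.

```python
from graphiti_core.nodes import EntityNode
from graphiti_core.utils.maintenance.node_operations import extract_attributes_from_nodes


async def skip_user_summaries(node: EntityNode) -> bool:
    # Matches NodeSummaryFilter: return False to keep the node's existing
    # summary and skip the small-model summary LLM call, True to regenerate it.
    return 'User' not in node.labels


async def enrich_nodes(clients, nodes, episode, previous_episodes):
    # `clients`, `episode`, and `previous_episodes` are assumed to come from the
    # surrounding ingestion pipeline; only the new keyword argument is shown here.
    return await extract_attributes_from_nodes(
        clients,
        nodes,
        episode=episode,
        previous_episodes=previous_episodes,
        entity_types=None,
        should_summarize_node=skip_user_summaries,
    )
```

Because the filter is awaited once per node, it can also consult a cache or datastore (for example, skip regeneration when a node's summary was refreshed recently) without blocking the event loop; omitting the argument, or passing `None`, preserves the previous behavior of always generating a summary.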