Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Brandt Weary 2025-10-02 04:59:07 -07:00
commit d6b7f8cc83
3 changed files with 202 additions and 8 deletions

View file

@ -15,6 +15,7 @@ limitations under the License.
"""
import logging
from collections.abc import Awaitable, Callable
from time import time
from typing import Any
@ -55,6 +56,8 @@ from graphiti_core.utils.maintenance.edge_operations import (
logger = logging.getLogger(__name__)
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
async def extract_nodes_reflexion(
llm_client: LLMClient,
@ -402,6 +405,7 @@ async def extract_attributes_from_nodes(
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, type[BaseModel]] | None = None,
should_summarize_node: NodeSummaryFilter | None = None,
) -> list[EntityNode]:
llm_client = clients.llm_client
embedder = clients.embedder
@ -418,6 +422,7 @@ async def extract_attributes_from_nodes(
else None
),
clients.ensure_ascii,
should_summarize_node,
)
for node in nodes
]
@ -435,6 +440,7 @@ async def extract_attributes_from_node(
previous_episodes: list[EpisodicNode] | None = None,
entity_type: type[BaseModel] | None = None,
ensure_ascii: bool = False,
should_summarize_node: NodeSummaryFilter | None = None,
) -> EntityNode:
node_context: dict[str, Any] = {
'name': node.name,
@ -477,16 +483,22 @@ async def extract_attributes_from_node(
else {}
)
summary_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_summary(summary_context),
response_model=EntitySummary,
model_size=ModelSize.small,
)
# Determine if summary should be generated
generate_summary = True
if should_summarize_node is not None:
generate_summary = await should_summarize_node(node)
# Conditionally generate summary
if generate_summary:
summary_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_summary(summary_context),
response_model=EntitySummary,
model_size=ModelSize.small,
)
node.summary = summary_response.get('summary', '')
if has_entity_attributes and entity_type is not None:
entity_type(**llm_response)
node.summary = summary_response.get('summary', '')
node_attributes = {key: value for key, value in llm_response.items()}
node.attributes.update(node_attributes)

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.21.0pre9"
version = "0.21.0pre10"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

View file

@ -27,6 +27,8 @@ from graphiti_core.utils.maintenance.dedup_helpers import (
from graphiti_core.utils.maintenance.node_operations import (
_collect_candidate_nodes,
_resolve_with_llm,
extract_attributes_from_node,
extract_attributes_from_nodes,
resolve_extracted_nodes,
)
@ -477,3 +479,183 @@ async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monk
assert state.resolved_nodes[0] == extracted
assert state.uuid_map[extracted.uuid] == extracted.uuid
assert state.duplicate_pairs == []
@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
    """Default behavior: with no filter callback supplied, the summary is regenerated."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    entity = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')

    updated = await extract_attributes_from_node(
        mock_llm,
        entity,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=None,  # no callback provided
    )

    # The LLM-produced summary replaces the stale one, and exactly one LLM call was made.
    assert updated.summary == 'Generated summary'
    assert mock_llm.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
    """A callback returning False suppresses summary regeneration entirely."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'This should not be used', 'attributes': {}}
    )

    entity = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')

    async def never_summarize(node: EntityNode) -> bool:
        # Always skip summary generation.
        return False

    updated = await extract_attributes_from_node(
        mock_llm,
        entity,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=never_summarize,
    )

    # Old summary is retained and the LLM was never invoked.
    assert updated.summary == 'Old summary'
    assert mock_llm.generate_response.call_count == 0
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
    """A callback returning True causes the summary to be regenerated via the LLM."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'New generated summary', 'attributes': {}}
    )

    entity = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')

    async def always_summarize(node: EntityNode) -> bool:
        # Always allow summary generation.
        return True

    updated = await extract_attributes_from_node(
        mock_llm,
        entity,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=always_summarize,
    )

    # Summary was refreshed from the LLM response, with exactly one call.
    assert updated.summary == 'New generated summary'
    assert mock_llm.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
    """The filter callback can decide per node: skip User-labeled nodes, summarize the rest."""
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
    topic_node = EntityNode(
        name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
    )
    episode = _make_episode()

    async def skip_users(node: EntityNode) -> bool:
        # Summarize everything except User-labeled nodes.
        return 'User' not in node.labels

    updated_user = await extract_attributes_from_node(
        mock_llm,
        user_node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=skip_users,
    )
    updated_topic = await extract_attributes_from_node(
        mock_llm,
        topic_node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=skip_users,
    )

    # User node kept its old summary; topic node got a fresh one.
    assert updated_user.summary == 'Old'
    assert updated_topic.summary == 'Generated summary'
    # Only the topic node triggered an LLM call.
    assert mock_llm.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
    """The filter callback is forwarded by extract_attributes_from_nodes to every node."""
    clients, _ = _make_clients()
    clients.llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New summary', 'attributes': {}}
    )
    clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
    clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    first = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
    second = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')
    episode = _make_episode()

    seen_names: list[str] = []

    async def record_and_filter(node: EntityNode) -> bool:
        # Record every node the callback is invoked with; skip User-labeled nodes.
        seen_names.append(node.name)
        return 'User' not in node.labels

    results = await extract_attributes_from_nodes(
        clients,
        [first, second],
        episode=episode,
        previous_episodes=[],
        entity_types=None,
        should_summarize_node=record_and_filter,
    )

    # The callback ran once per node.
    assert len(seen_names) == 2
    assert 'Node1' in seen_names
    assert 'Node2' in seen_names

    # Node1 (User) keeps its old summary; Node2 (Topic) gets the regenerated one.
    by_name = {n.name: n for n in results}
    assert by_name['Node1'].summary == 'Old1'
    assert by_name['Node2'].summary == 'New summary'