Merge remote-tracking branch 'upstream/main'
This commit is contained in:
commit
d6b7f8cc83
3 changed files with 202 additions and 8 deletions
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -55,6 +56,8 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Async predicate over an EntityNode. extract_attributes_from_node awaits it
# with the node being processed and only asks the LLM for a new summary when
# it resolves to True; passing no filter (None) keeps summaries always generated.
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
|
||||
|
||||
|
||||
async def extract_nodes_reflexion(
|
||||
llm_client: LLMClient,
|
||||
|
|
@ -402,6 +405,7 @@ async def extract_attributes_from_nodes(
|
|||
episode: EpisodicNode | None = None,
|
||||
previous_episodes: list[EpisodicNode] | None = None,
|
||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||
should_summarize_node: NodeSummaryFilter | None = None,
|
||||
) -> list[EntityNode]:
|
||||
llm_client = clients.llm_client
|
||||
embedder = clients.embedder
|
||||
|
|
@ -418,6 +422,7 @@ async def extract_attributes_from_nodes(
|
|||
else None
|
||||
),
|
||||
clients.ensure_ascii,
|
||||
should_summarize_node,
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
|
|
@ -435,6 +440,7 @@ async def extract_attributes_from_node(
|
|||
previous_episodes: list[EpisodicNode] | None = None,
|
||||
entity_type: type[BaseModel] | None = None,
|
||||
ensure_ascii: bool = False,
|
||||
should_summarize_node: NodeSummaryFilter | None = None,
|
||||
) -> EntityNode:
|
||||
node_context: dict[str, Any] = {
|
||||
'name': node.name,
|
||||
|
|
@ -477,16 +483,22 @@ async def extract_attributes_from_node(
|
|||
else {}
|
||||
)
|
||||
|
||||
summary_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||
response_model=EntitySummary,
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
# Determine if summary should be generated
|
||||
generate_summary = True
|
||||
if should_summarize_node is not None:
|
||||
generate_summary = await should_summarize_node(node)
|
||||
|
||||
# Conditionally generate summary
|
||||
if generate_summary:
|
||||
summary_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||
response_model=EntitySummary,
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
node.summary = summary_response.get('summary', '')
|
||||
|
||||
if has_entity_attributes and entity_type is not None:
|
||||
entity_type(**llm_response)
|
||||
|
||||
node.summary = summary_response.get('summary', '')
|
||||
node_attributes = {key: value for key, value in llm_response.items()}
|
||||
|
||||
node.attributes.update(node_attributes)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.21.0pre9"
|
||||
version = "0.21.0pre10"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ from graphiti_core.utils.maintenance.dedup_helpers import (
|
|||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
_collect_candidate_nodes,
|
||||
_resolve_with_llm,
|
||||
extract_attributes_from_node,
|
||||
extract_attributes_from_nodes,
|
||||
resolve_extracted_nodes,
|
||||
)
|
||||
|
||||
|
|
@ -477,3 +479,183 @@ async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monk
|
|||
assert state.resolved_nodes[0] == extracted
|
||||
assert state.uuid_map[extracted.uuid] == extracted.uuid
|
||||
assert state.duplicate_pairs == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
    """With no filter callback supplied, the summary is regenerated by default."""
    # Arrange: an LLM stub that always answers with a fixed summary payload.
    canned_reply = {'summary': 'Generated summary', 'attributes': {}}
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(return_value=canned_reply)

    target = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )
    source_episode = _make_episode()

    # Act: explicitly pass None as the callback to exercise the default path.
    updated = await extract_attributes_from_node(
        mock_llm,
        target,
        episode=source_episode,
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=None,
    )

    # Assert: the summary was replaced and exactly one LLM call produced it.
    assert updated.summary == 'Generated summary'
    assert mock_llm.generate_response.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
    """A callback resolving to False leaves the existing summary untouched."""

    async def never_summarize(_candidate: EntityNode) -> bool:
        # Unconditionally veto summary generation.
        return False

    # Arrange: the stubbed reply must never surface in the result.
    canned_reply = {'summary': 'This should not be used', 'attributes': {}}
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(return_value=canned_reply)

    target = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )
    source_episode = _make_episode()

    # Act
    updated = await extract_attributes_from_node(
        mock_llm,
        target,
        episode=source_episode,
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=never_summarize,
    )

    # Assert: summary preserved and the LLM was not consulted at all.
    assert updated.summary == 'Old summary'
    assert mock_llm.generate_response.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
    """A callback resolving to True causes the summary to be regenerated."""

    async def always_summarize(_candidate: EntityNode) -> bool:
        # Unconditionally approve summary generation.
        return True

    # Arrange
    canned_reply = {'summary': 'New generated summary', 'attributes': {}}
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(return_value=canned_reply)

    target = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )
    source_episode = _make_episode()

    # Act
    updated = await extract_attributes_from_node(
        mock_llm,
        target,
        episode=source_episode,
        previous_episodes=[],
        entity_type=None,
        ensure_ascii=False,
        should_summarize_node=always_summarize,
    )

    # Assert: summary replaced via exactly one LLM call.
    assert updated.summary == 'New generated summary'
    assert mock_llm.generate_response.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
    """A label-based callback skips some nodes and summarizes the rest."""

    async def skip_user_nodes(candidate: EntityNode) -> bool:
        # Only nodes that are not labelled 'User' get a fresh summary.
        return 'User' not in candidate.labels

    # Arrange: one shared LLM stub so calls can be counted across both nodes.
    mock_llm = MagicMock()
    mock_llm.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
    topic_node = EntityNode(
        name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
    )
    source_episode = _make_episode()

    async def run_extraction(target: EntityNode) -> EntityNode:
        # Same invocation for every node; only the target differs.
        return await extract_attributes_from_node(
            mock_llm,
            target,
            episode=source_episode,
            previous_episodes=[],
            entity_type=None,
            ensure_ascii=False,
            should_summarize_node=skip_user_nodes,
        )

    # Act: process the user node first, then the topic node.
    result_user = await run_extraction(user_node)
    result_topic = await run_extraction(topic_node)

    # Assert: user summary kept, topic summary regenerated, one LLM call total.
    assert result_user.summary == 'Old'
    assert result_topic.summary == 'Generated summary'
    assert mock_llm.generate_response.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
    """The batch entry point forwards the callback to every per-node extraction."""
    # Arrange: stub out the LLM and both embedder entry points.
    clients, _ = _make_clients()
    clients.llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New summary', 'attributes': {}}
    )
    clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
    clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    user_node = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
    topic_node = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')
    source_episode = _make_episode()

    seen_names = []

    async def recording_filter(candidate: EntityNode) -> bool:
        # Record every node the callback is asked about, then skip 'User' nodes.
        seen_names.append(candidate.name)
        return 'User' not in candidate.labels

    # Act
    results = await extract_attributes_from_nodes(
        clients,
        [user_node, topic_node],
        episode=source_episode,
        previous_episodes=[],
        entity_types=None,
        should_summarize_node=recording_filter,
    )

    # Assert: the callback was consulted once per node.
    assert len(seen_names) == 2
    assert 'Node1' in seen_names
    assert 'Node2' in seen_names

    def first_named(wanted: str) -> EntityNode:
        # First result carrying the requested name.
        return next(n for n in results if n.name == wanted)

    # Node1 (User) keeps its old summary; Node2 (Topic) receives the new one.
    assert first_named('Node1').summary == 'Old1'
    assert first_named('Node2').summary == 'New summary'
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue