- Remove unnecessary stubs for opensearchpy module. - Format return values in llm_client.generate_response for consistency. - Enhance readability by ensuring proper indentation and structure in test cases. This refactor improves the clarity and maintainability of the test suite for edge operations.
436 lines
12 KiB
Python
436 lines
12 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from graphiti_core.edges import EntityEdge
|
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
|
from graphiti_core.search.search_config import SearchResults
|
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
DEFAULT_EDGE_NAME,
|
|
resolve_extracted_edge,
|
|
resolve_extracted_edges,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def mock_llm_client():
    """Provide a stand-in LLM client whose ``generate_response`` is an ``AsyncMock``.

    Tests set ``generate_response.return_value`` to script the reply, or assert
    on the mock to check whether the client was called at all.
    """
    return MagicMock(generate_response=AsyncMock())
|
|
|
|
|
|
@pytest.fixture
def mock_extracted_edge():
    """Provide an ``EntityEdge`` named 'test_edge' with no validity window set.

    The edge links 'source_uuid' to 'target_uuid', belongs to 'group_1', and is
    stamped with the current UTC time as its creation time.
    """
    edge_fields = {
        'source_node_uuid': 'source_uuid',
        'target_node_uuid': 'target_uuid',
        'name': 'test_edge',
        'group_id': 'group_1',
        'fact': 'Test fact',
        'episodes': ['episode_1'],
        'created_at': datetime.now(timezone.utc),
        'valid_at': None,
        'invalid_at': None,
    }
    return EntityEdge(**edge_fields)
|
|
|
|
|
|
@pytest.fixture
def mock_related_edges():
    """Provide a one-element list holding a 'related_edge' dated one day in the past.

    Both ``created_at`` and ``valid_at`` are set to roughly 24 hours before the
    fixture runs; ``invalid_at`` is left unset.
    """
    one_day = timedelta(days=1)
    related = EntityEdge(
        source_node_uuid='source_uuid_2',
        target_node_uuid='target_uuid_2',
        name='related_edge',
        group_id='group_1',
        fact='Related fact',
        episodes=['episode_2'],
        created_at=datetime.now(timezone.utc) - one_day,
        valid_at=datetime.now(timezone.utc) - one_day,
        invalid_at=None,
    )
    return [related]
|
|
|
|
|
|
@pytest.fixture
def mock_existing_edges():
    """Provide a one-element list holding an 'existing_edge' dated two days in the past.

    Both ``created_at`` and ``valid_at`` are set to roughly 48 hours before the
    fixture runs; ``invalid_at`` is left unset.
    """
    two_days = timedelta(days=2)
    existing = EntityEdge(
        source_node_uuid='source_uuid_3',
        target_node_uuid='target_uuid_3',
        name='existing_edge',
        group_id='group_1',
        fact='Existing fact',
        episodes=['episode_3'],
        created_at=datetime.now(timezone.utc) - two_days,
        valid_at=datetime.now(timezone.utc) - two_days,
        invalid_at=None,
    )
    return [existing]
|
|
|
|
|
|
@pytest.fixture
def mock_current_episode():
    """Provide the 'episode_1' ``EpisodicNode`` that tests treat as the episode being ingested.

    The episode is a 'message' source in 'group_1' whose ``valid_at`` is the
    current UTC time.
    """
    current = EpisodicNode(
        uuid='episode_1',
        name='Current Episode',
        group_id='group_1',
        source='message',
        source_description='Test source description',
        content='Current episode content',
        valid_at=datetime.now(timezone.utc),
    )
    return current
|
|
|
|
|
|
@pytest.fixture
def mock_previous_episodes():
    """Provide a one-element list holding the 'episode_2' ``EpisodicNode`` from one day ago.

    The episode is a 'message' source in 'group_1' whose ``valid_at`` is roughly
    24 hours before the fixture runs.
    """
    previous = EpisodicNode(
        uuid='episode_2',
        name='Previous Episode',
        group_id='group_1',
        source='message',
        source_description='Test source description',
        content='Previous episode content',
        valid_at=datetime.now(timezone.utc) - timedelta(days=1),
    )
    return [previous]
|
|
|
|
|
|
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    # pytest.main() returns the exit code rather than exiting the process.
    # Propagate it so a failing run yields a non-zero exit status instead of
    # silently finishing with 0.
    raise SystemExit(pytest.main([__file__]))
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_extracted_edge_exact_fact_short_circuit(
    mock_llm_client,
    mock_existing_edges,
    mock_current_episode,
):
    """A related edge whose fact differs only in case and padding is reused without an LLM call.

    The extracted fact is 'Related fact' and the related edge's fact is
    ' related FACT ' between the same two nodes. The test expects the related
    edge object itself to be returned, the current episode to be recorded on it
    exactly once, no duplicates or invalidations, and an untouched LLM client.
    """
    # Same endpoints and same fact text as the extracted edge, modulo case/whitespace.
    prior_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='related_edge',
        group_id='group_1',
        fact=' related FACT ',
        episodes=['episode_2'],
        created_at=datetime.now(timezone.utc) - timedelta(days=1),
        valid_at=None,
        invalid_at=None,
    )
    related_edges = [prior_edge]

    extracted = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='test_edge',
        group_id='group_1',
        fact='Related fact',
        episodes=['episode_1'],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        extracted,
        related_edges,
        mock_existing_edges,
        mock_current_episode,
        edge_type_candidates=None,
        ensure_ascii=True,
    )
    resolved_edge, duplicate_edges, invalidated = outcome

    # The pre-existing edge object is handed back, not a copy of the extracted one.
    assert resolved_edge is prior_edge
    assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1
    assert duplicate_edges == []
    assert invalidated == []
    mock_llm_client.generate_response.assert_not_called()
|
|
|
|
|
|
class OccurredAtEdge(BaseModel):
    """Edge model stub for OCCURRED_AT.

    Declares no fields. The tests in this file register it under the
    'OCCURRED_AT' key as the only custom edge type.
    """
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
    """A custom edge name that is not mapped for the node labels is reset to the default name.

    'OCCURRED_AT' is a registered custom type, but the edge type map only allows
    it for ('Event', 'Entity') while the edge here connects a 'Document' node to
    a 'Topic' node. The resolved edge is expected to carry DEFAULT_EDGE_NAME.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def sequential_gather(*aws, max_coroutines=None):
        # Await each awaitable in order instead of scheduling them concurrently.
        results = []
        for aw in aws:
            results.append(await aw)
        return results

    # Replace embedding, graph lookup, gathering and search with in-memory stand-ins.
    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', sequential_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))

    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }
    llm_client = MagicMock(generate_response=AsyncMock(return_value=llm_reply))

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=llm_client,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        ensure_ascii=True,
    )

    source_node = EntityNode(
        uuid='source_uuid',
        name='Document Node',
        group_id='group_1',
        labels=['Document'],
    )
    target_node = EntityNode(
        uuid='target_uuid',
        name='Topic Node',
        group_id='group_1',
        labels=['Topic'],
    )
    nodes = [source_node, target_node]

    extracted_edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # The custom type exists, but only for a label pair these nodes do not have.
    edge_types = {'OCCURRED_AT': OccurredAtEdge}
    edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}

    outcome = await resolve_extracted_edges(
        clients,
        [extracted_edge],
        episode,
        nodes,
        edge_types,
        edge_type_map,
    )
    resolved_edges, invalidated_edges = outcome

    assert resolved_edges[0].name == DEFAULT_EDGE_NAME
    assert invalidated_edges == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
    """An edge name that is not one of the registered custom types is left as extracted.

    Only 'OCCURRED_AT' is registered as a custom edge type. The extracted edge
    is named 'INTERACTED_WITH', and the test expects that name to survive
    resolution unchanged.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def sequential_gather(*aws, max_coroutines=None):
        # Await each awaitable in order instead of scheduling them concurrently.
        results = []
        for aw in aws:
            results.append(await aw)
        return results

    # Replace embedding, graph lookup, gathering and search with in-memory stand-ins.
    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', sequential_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))

    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }
    llm_client = MagicMock(generate_response=AsyncMock(return_value=llm_reply))

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=llm_client,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        ensure_ascii=True,
    )

    source_node = EntityNode(
        uuid='source_uuid',
        name='User Node',
        group_id='group_1',
        labels=['User'],
    )
    target_node = EntityNode(
        uuid='target_uuid',
        name='Topic Node',
        group_id='group_1',
        labels=['Topic'],
    )
    nodes = [source_node, target_node]

    extracted_edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='INTERACTED_WITH',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # 'INTERACTED_WITH' appears nowhere in the registered custom types.
    edge_types = {'OCCURRED_AT': OccurredAtEdge}
    edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}

    outcome = await resolve_extracted_edges(
        clients,
        [extracted_edge],
        episode,
        nodes,
        edge_types,
        edge_type_map,
    )
    resolved_edges, invalidated_edges = outcome

    assert resolved_edges[0].name == 'INTERACTED_WITH'
    assert invalidated_edges == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
    """A custom fact type returned by the LLM is rejected when it is not among the candidates.

    The mocked LLM answers with fact type 'OCCURRED_AT', which is listed in
    ``custom_edge_type_names`` but absent from the empty ``edge_type_candidates``.
    The resolved edge is expected to end up with DEFAULT_EDGE_NAME.
    """
    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'OCCURRED_AT',
    }
    mock_llm_client.generate_response.return_value = llm_reply

    extracted_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # A related edge with a different fact and different endpoints.
    related_edge = EntityEdge(
        source_node_uuid='alt_source',
        target_node_uuid='alt_target',
        name='OTHER',
        group_id='group_1',
        fact='Different fact',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        [related_edge],
        [],
        episode,
        edge_type_candidates={},
        custom_edge_type_names={'OCCURRED_AT'},
        ensure_ascii=True,
    )
    resolved_edge, duplicates, invalidated = outcome

    assert resolved_edge.name == DEFAULT_EDGE_NAME
    assert duplicates == []
    assert invalidated == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
    """A fact type from the LLM that is not a registered custom type is adopted as the edge name.

    The mocked LLM answers with 'INTERACTED_WITH' while the only custom type is
    'OCCURRED_AT'. The resolved edge is expected to take the name
    'INTERACTED_WITH' and to have empty attributes.
    """
    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'INTERACTED_WITH',
    }
    mock_llm_client.generate_response.return_value = llm_reply

    extracted_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='DEFAULT',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # Same endpoints as the extracted edge but a different fact text.
    related_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='DEFAULT',
        group_id='group_1',
        fact='User mentioned a topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        [related_edge],
        [],
        episode,
        edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
        custom_edge_type_names={'OCCURRED_AT'},
        ensure_ascii=True,
    )
    resolved_edge, duplicates, invalidated = outcome

    assert resolved_edge.name == 'INTERACTED_WITH'
    assert resolved_edge.attributes == {}
    assert duplicates == []
    assert invalidated == []
|