"""Tests for ``resolve_extracted_edge`` in graphiti_core edge maintenance."""

from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock

import pytest

from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodicNode
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge


@pytest.fixture
def mock_llm_client():
    """Return a bare MagicMock standing in for the LLM client."""
    return MagicMock()


@pytest.fixture
def mock_extracted_edge():
    """Return a freshly extracted edge with no validity window set."""
    return EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='test_edge',
        group_id='group_1',
        fact='Test fact',
        episodes=['episode_1'],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )


@pytest.fixture
def mock_related_edges():
    """Return one related edge created and valid since one day ago."""
    return [
        EntityEdge(
            source_node_uuid='source_uuid_2',
            target_node_uuid='target_uuid_2',
            name='related_edge',
            group_id='group_1',
            fact='Related fact',
            episodes=['episode_2'],
            created_at=datetime.now(timezone.utc) - timedelta(days=1),
            valid_at=datetime.now(timezone.utc) - timedelta(days=1),
            invalid_at=None,
        )
    ]


@pytest.fixture
def mock_existing_edges():
    """Return one existing edge created and valid since two days ago."""
    return [
        EntityEdge(
            source_node_uuid='source_uuid_3',
            target_node_uuid='target_uuid_3',
            name='existing_edge',
            group_id='group_1',
            fact='Existing fact',
            episodes=['episode_3'],
            created_at=datetime.now(timezone.utc) - timedelta(days=2),
            valid_at=datetime.now(timezone.utc) - timedelta(days=2),
            invalid_at=None,
        )
    ]


@pytest.fixture
def mock_current_episode():
    """Return the episode being processed (uuid ``episode_1``)."""
    return EpisodicNode(
        uuid='episode_1',
        content='Current episode content',
        valid_at=datetime.now(timezone.utc),
        name='Current Episode',
        group_id='group_1',
        source='message',
        source_description='Test source description',
    )


@pytest.fixture
def mock_previous_episodes():
    """Return a single earlier episode (uuid ``episode_2``) from one day ago."""
    return [
        EpisodicNode(
            uuid='episode_2',
            content='Previous episode content',
            valid_at=datetime.now(timezone.utc) - timedelta(days=1),
            name='Previous Episode',
            group_id='group_1',
            source='message',
            source_description='Test source description',
        )
    ]


@pytest.mark.asyncio
async def test_resolve_extracted_edge_exact_fact_short_circuit(
    mock_llm_client,
    mock_existing_edges,
    mock_current_episode,
):
    """An exact (normalized) fact match must resolve without calling the LLM.

    The extracted edge's fact ``'Related fact'`` differs from the related
    edge's ``' related FACT '`` only in surrounding whitespace and case, and
    both edges connect the same source/target nodes. The test expects:

    * the existing related edge object itself to be returned,
    * the current episode uuid to be recorded on it exactly once,
    * no duplicates and no invalidated edges, and
    * ``generate_response`` on the LLM client never to be invoked.
    """
    extracted = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='test_edge',
        group_id='group_1',
        fact='Related fact',
        episodes=['episode_1'],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    related_edges = [
        EntityEdge(
            source_node_uuid='source_uuid',
            target_node_uuid='target_uuid',
            name='related_edge',
            group_id='group_1',
            # Same fact as `extracted`, but with padding and different case.
            fact=' related FACT ',
            episodes=['episode_2'],
            created_at=datetime.now(timezone.utc) - timedelta(days=1),
            valid_at=None,
            invalid_at=None,
        )
    ]

    resolved_edge, duplicate_edges, invalidated = await resolve_extracted_edge(
        mock_llm_client,
        extracted,
        related_edges,
        mock_existing_edges,
        mock_current_episode,
        edge_types=None,
        ensure_ascii=True,
    )

    # Identity check: the pre-existing edge is reused, not a copy.
    assert resolved_edge is related_edges[0]
    assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1
    assert duplicate_edges == []
    assert invalidated == []
    mock_llm_client.generate_response.assert_not_called()


# Run the tests when this module is executed directly. Kept as the last
# top-level statement so every fixture and test is defined beforehand.
if __name__ == '__main__':
    pytest.main([__file__])