diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py index 9ee91a64..368ddb5f 100644 --- a/core/utils/maintenance/graph_data_operations.py +++ b/core/utils/maintenance/graph_data_operations.py @@ -1,9 +1,9 @@ import asyncio import logging from datetime import datetime, timezone -from typing import LiteralString from neo4j import AsyncDriver +from typing_extensions import LiteralString from core.nodes import EpisodicNode diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py index 9fbcf2af..7a3b94cc 100644 --- a/tests/utils/maintenance/test_temporal_operations.py +++ b/tests/utils/maintenance/test_temporal_operations.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta import pytest from core.edges import EntityEdge -from core.nodes import EntityNode +from core.nodes import EntityNode, EpisodicNode from core.utils.maintenance.temporal_operations import ( prepare_edges_for_invalidation, prepare_invalidation_context, @@ -114,7 +114,6 @@ def test_prepare_edges_for_invalidation_missing_nodes(): def test_prepare_invalidation_context(): - # Create test data now = datetime.now() # Create nodes @@ -148,15 +147,49 @@ def test_prepare_invalidation_context(): existing_edges = [existing_edge] new_edges = [new_edge] + # Create a current episode and previous episodes + current_episode = EpisodicNode( + name='Current Episode', + content='This is the current episode content.', + created_at=now, + valid_at=now, + source='test', + source_description='Test episode for unit testing', + ) + previous_episodes = [ + EpisodicNode( + name='Previous Episode 1', + content='This is the content of previous episode 1.', + created_at=now - timedelta(days=1), + valid_at=now - timedelta(days=1), + source='test', + source_description='Test previous episode 1 for unit testing', + ), + EpisodicNode( + name='Previous Episode 2', + content='This is the content of previous episode 2.', + created_at=now - timedelta(days=2), + valid_at=now - timedelta(days=2), + source='test', + source_description='Test previous episode 2 for unit testing', + ), + ] + # Call the function - result = prepare_invalidation_context(existing_edges, new_edges) + result = prepare_invalidation_context( + existing_edges, new_edges, current_episode, previous_episodes + ) # Assert the result assert isinstance(result, dict) assert 'existing_edges' in result assert 'new_edges' in result + assert 'current_episode' in result + assert 'previous_episodes' in result assert len(result['existing_edges']) == 1 assert len(result['new_edges']) == 1 + assert result['current_episode'] == current_episode.content + assert len(result['previous_episodes']) == 2 # Check the format of the existing edge existing_edge_str = result['existing_edges'][0] @@ -176,12 +209,25 @@ def test_prepare_invalidation_context(): def test_prepare_invalidation_context_empty_input(): - result = prepare_invalidation_context([], []) + now = datetime.now() + current_episode = EpisodicNode( + name='Current Episode', + content='Empty episode', + created_at=now, + valid_at=now, + source='test', + source_description='Test empty episode for unit testing', + ) + result = prepare_invalidation_context([], [], current_episode, []) assert isinstance(result, dict) assert 'existing_edges' in result assert 'new_edges' in result + assert 'current_episode' in result + assert 'previous_episodes' in result assert len(result['existing_edges']) == 0 assert len(result['new_edges']) == 0 + assert result['current_episode'] == current_episode.content + assert len(result['previous_episodes']) == 0 def test_prepare_invalidation_context_sorting(): @@ -215,13 +261,36 @@ def test_prepare_invalidation_context_sorting(): # Prepare test input existing_edges = [edge_with_nodes1, edge_with_nodes2] + # Create a current episode and previous episodes + current_episode = EpisodicNode( + name='Current Episode', + content='This is the current episode content.', + created_at=now, + valid_at=now, + source='test', + source_description='Test episode for unit testing', + ) + previous_episodes = [ + EpisodicNode( + name='Previous Episode', + content='This is the content of a previous episode.', + created_at=now - timedelta(days=1), + valid_at=now - timedelta(days=1), + source='test', + source_description='Test previous episode for unit testing', + ), + ] + # Call the function - result = prepare_invalidation_context(existing_edges, []) + result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes) # Assert the result assert len(result['existing_edges']) == 2 assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first assert edge1.uuid in result['existing_edges'][1] # The older edge should be second + assert result['current_episode'] == current_episode.content + assert len(result['previous_episodes']) == 1 + assert result['previous_episodes'][0] == previous_episodes[0].content # Run the tests diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index 6ea46d4d..37baf0cc 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from core.edges import EntityEdge from core.llm_client import LLMConfig, OpenAIClient -from core.nodes import EntityNode +from core.nodes import EntityNode, EpisodicNode from core.utils.maintenance.temporal_operations import ( invalidate_edges, ) @@ -24,7 +24,6 @@ def setup_llm_client(): ) -# Helper function to create test data def create_test_data(): now = datetime.now() @@ -53,15 +52,39 @@ def create_test_data(): existing_edge = (node1, edge1, node2) new_edge = (node1, edge2, node2) - return existing_edge, new_edge + # Create current episode + current_episode = EpisodicNode( + name='Current Episode', + content='Alice now dislikes Bob', + created_at=now, + valid_at=now, + source='test', + source_description='Test episode for unit testing', + ) + + # Create previous episodes + previous_episodes = [ + EpisodicNode( + name='Previous Episode', + content='Alice liked Bob', + created_at=now - timedelta(days=1), + valid_at=now - timedelta(days=1), + source='test', + source_description='Test previous episode for unit testing', + ) + ] + + return existing_edge, new_edge, current_episode, previous_episodes @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges(): - existing_edge, new_edge = create_test_data() + existing_edge, new_edge, current_episode, previous_episodes = create_test_data() - invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [new_edge]) + invalidated_edges = await invalidate_edges( + setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes + ) assert len(invalidated_edges) == 1 assert invalidated_edges[0].uuid == existing_edge[1].uuid @@ -71,9 +94,11 @@ async def test_invalidate_edges(): @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges_no_invalidation(): - existing_edge, _ = create_test_data() + existing_edge, _, current_episode, previous_episodes = create_test_data() - invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], []) + invalidated_edges = await invalidate_edges( + setup_llm_client(), [existing_edge], [], current_episode, previous_episodes + ) assert len(invalidated_edges) == 0