Fix temporal invalidation unit tests (#23)
* wip * wip * wip * fix: Linter errors * fix formatting * chore: fix ruff * fix: Duplication * chore: Fix unit tests for temporal invalidation * attempt to fix unit tests * fix: format --------- Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
parent
72dfa3c1e3
commit
8a55f48f5e
3 changed files with 107 additions and 13 deletions
|
|
@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import LiteralString
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from core.nodes import EpisodicNode
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from datetime import datetime, timedelta
|
|||
import pytest
|
||||
|
||||
from core.edges import EntityEdge
|
||||
from core.nodes import EntityNode
|
||||
from core.nodes import EntityNode, EpisodicNode
|
||||
from core.utils.maintenance.temporal_operations import (
|
||||
prepare_edges_for_invalidation,
|
||||
prepare_invalidation_context,
|
||||
|
|
@ -114,7 +114,6 @@ def test_prepare_edges_for_invalidation_missing_nodes():
|
|||
|
||||
|
||||
def test_prepare_invalidation_context():
|
||||
# Create test data
|
||||
now = datetime.now()
|
||||
|
||||
# Create nodes
|
||||
|
|
@ -148,15 +147,49 @@ def test_prepare_invalidation_context():
|
|||
existing_edges = [existing_edge]
|
||||
new_edges = [new_edge]
|
||||
|
||||
# Create a current episode and previous episodes
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
content='This is the current episode content.',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source_description='Test episode for unit testing',
|
||||
)
|
||||
previous_episodes = [
|
||||
EpisodicNode(
|
||||
name='Previous Episode 1',
|
||||
content='This is the content of previous episode 1.',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source='test',
|
||||
source_description='Test previous episode 1 for unit testing',
|
||||
),
|
||||
EpisodicNode(
|
||||
name='Previous Episode 2',
|
||||
content='This is the content of previous episode 2.',
|
||||
created_at=now - timedelta(days=2),
|
||||
valid_at=now - timedelta(days=2),
|
||||
source='test',
|
||||
source_description='Test previous episode 2 for unit testing',
|
||||
),
|
||||
]
|
||||
|
||||
# Call the function
|
||||
result = prepare_invalidation_context(existing_edges, new_edges)
|
||||
result = prepare_invalidation_context(
|
||||
existing_edges, new_edges, current_episode, previous_episodes
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, dict)
|
||||
assert 'existing_edges' in result
|
||||
assert 'new_edges' in result
|
||||
assert 'current_episode' in result
|
||||
assert 'previous_episodes' in result
|
||||
assert len(result['existing_edges']) == 1
|
||||
assert len(result['new_edges']) == 1
|
||||
assert result['current_episode'] == current_episode.content
|
||||
assert len(result['previous_episodes']) == 2
|
||||
|
||||
# Check the format of the existing edge
|
||||
existing_edge_str = result['existing_edges'][0]
|
||||
|
|
@ -176,12 +209,25 @@ def test_prepare_invalidation_context():
|
|||
|
||||
|
||||
def test_prepare_invalidation_context_empty_input():
|
||||
result = prepare_invalidation_context([], [])
|
||||
now = datetime.now()
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
content='Empty episode',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source_description='Test empty episode for unit testing',
|
||||
)
|
||||
result = prepare_invalidation_context([], [], current_episode, [])
|
||||
assert isinstance(result, dict)
|
||||
assert 'existing_edges' in result
|
||||
assert 'new_edges' in result
|
||||
assert 'current_episode' in result
|
||||
assert 'previous_episodes' in result
|
||||
assert len(result['existing_edges']) == 0
|
||||
assert len(result['new_edges']) == 0
|
||||
assert result['current_episode'] == current_episode.content
|
||||
assert len(result['previous_episodes']) == 0
|
||||
|
||||
|
||||
def test_prepare_invalidation_context_sorting():
|
||||
|
|
@ -215,13 +261,36 @@ def test_prepare_invalidation_context_sorting():
|
|||
# Prepare test input
|
||||
existing_edges = [edge_with_nodes1, edge_with_nodes2]
|
||||
|
||||
# Create a current episode and previous episodes
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
content='This is the current episode content.',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source_description='Test episode for unit testing',
|
||||
)
|
||||
previous_episodes = [
|
||||
EpisodicNode(
|
||||
name='Previous Episode',
|
||||
content='This is the content of a previous episode.',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source='test',
|
||||
source_description='Test previous episode for unit testing',
|
||||
),
|
||||
]
|
||||
|
||||
# Call the function
|
||||
result = prepare_invalidation_context(existing_edges, [])
|
||||
result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)
|
||||
|
||||
# Assert the result
|
||||
assert len(result['existing_edges']) == 2
|
||||
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
|
||||
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
|
||||
assert result['current_episode'] == current_episode.content
|
||||
assert len(result['previous_episodes']) == 1
|
||||
assert result['previous_episodes'][0] == previous_episodes[0].content
|
||||
|
||||
|
||||
# Run the tests
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from dotenv import load_dotenv
|
|||
|
||||
from core.edges import EntityEdge
|
||||
from core.llm_client import LLMConfig, OpenAIClient
|
||||
from core.nodes import EntityNode
|
||||
from core.nodes import EntityNode, EpisodicNode
|
||||
from core.utils.maintenance.temporal_operations import (
|
||||
invalidate_edges,
|
||||
)
|
||||
|
|
@ -24,7 +24,6 @@ def setup_llm_client():
|
|||
)
|
||||
|
||||
|
||||
# Helper function to create test data
|
||||
def create_test_data():
|
||||
now = datetime.now()
|
||||
|
||||
|
|
@ -53,15 +52,39 @@ def create_test_data():
|
|||
existing_edge = (node1, edge1, node2)
|
||||
new_edge = (node1, edge2, node2)
|
||||
|
||||
return existing_edge, new_edge
|
||||
# Create current episode
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
content='Alice now dislikes Bob',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source_description='Test episode for unit testing',
|
||||
)
|
||||
|
||||
# Create previous episodes
|
||||
previous_episodes = [
|
||||
EpisodicNode(
|
||||
name='Previous Episode',
|
||||
content='Alice liked Bob',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source='test',
|
||||
source_description='Test previous episode for unit testing',
|
||||
)
|
||||
]
|
||||
|
||||
return existing_edge, new_edge, current_episode, previous_episodes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges():
|
||||
existing_edge, new_edge = create_test_data()
|
||||
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
|
||||
|
||||
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [new_edge])
|
||||
invalidated_edges = await invalidate_edges(
|
||||
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
|
||||
)
|
||||
|
||||
assert len(invalidated_edges) == 1
|
||||
assert invalidated_edges[0].uuid == existing_edge[1].uuid
|
||||
|
|
@ -71,9 +94,11 @@ async def test_invalidate_edges():
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges_no_invalidation():
|
||||
existing_edge, _ = create_test_data()
|
||||
existing_edge, _, current_episode, previous_episodes = create_test_data()
|
||||
|
||||
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [])
|
||||
invalidated_edges = await invalidate_edges(
|
||||
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
|
||||
)
|
||||
|
||||
assert len(invalidated_edges) == 0
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue