Fix temporal invalidation unit tests (#23)
* wip * wip * wip * fix: Linter errors * fix formatting * chore: fix ruff * fix: Duplication * chore: Fix unit tests for temporal invalidation * attempt to fix unit tests * fix: format --------- Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
parent
72dfa3c1e3
commit
8a55f48f5e
3 changed files with 107 additions and 13 deletions
|
|
@ -1,9 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import LiteralString
|
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from core.nodes import EpisodicNode
|
from core.nodes import EpisodicNode
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from datetime import datetime, timedelta
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.edges import EntityEdge
|
from core.edges import EntityEdge
|
||||||
from core.nodes import EntityNode
|
from core.nodes import EntityNode, EpisodicNode
|
||||||
from core.utils.maintenance.temporal_operations import (
|
from core.utils.maintenance.temporal_operations import (
|
||||||
prepare_edges_for_invalidation,
|
prepare_edges_for_invalidation,
|
||||||
prepare_invalidation_context,
|
prepare_invalidation_context,
|
||||||
|
|
@ -114,7 +114,6 @@ def test_prepare_edges_for_invalidation_missing_nodes():
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_invalidation_context():
|
def test_prepare_invalidation_context():
|
||||||
# Create test data
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
# Create nodes
|
# Create nodes
|
||||||
|
|
@ -148,15 +147,49 @@ def test_prepare_invalidation_context():
|
||||||
existing_edges = [existing_edge]
|
existing_edges = [existing_edge]
|
||||||
new_edges = [new_edge]
|
new_edges = [new_edge]
|
||||||
|
|
||||||
|
# Create a current episode and previous episodes
|
||||||
|
current_episode = EpisodicNode(
|
||||||
|
name='Current Episode',
|
||||||
|
content='This is the current episode content.',
|
||||||
|
created_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
source='test',
|
||||||
|
source_description='Test episode for unit testing',
|
||||||
|
)
|
||||||
|
previous_episodes = [
|
||||||
|
EpisodicNode(
|
||||||
|
name='Previous Episode 1',
|
||||||
|
content='This is the content of previous episode 1.',
|
||||||
|
created_at=now - timedelta(days=1),
|
||||||
|
valid_at=now - timedelta(days=1),
|
||||||
|
source='test',
|
||||||
|
source_description='Test previous episode 1 for unit testing',
|
||||||
|
),
|
||||||
|
EpisodicNode(
|
||||||
|
name='Previous Episode 2',
|
||||||
|
content='This is the content of previous episode 2.',
|
||||||
|
created_at=now - timedelta(days=2),
|
||||||
|
valid_at=now - timedelta(days=2),
|
||||||
|
source='test',
|
||||||
|
source_description='Test previous episode 2 for unit testing',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
# Call the function
|
# Call the function
|
||||||
result = prepare_invalidation_context(existing_edges, new_edges)
|
result = prepare_invalidation_context(
|
||||||
|
existing_edges, new_edges, current_episode, previous_episodes
|
||||||
|
)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert 'existing_edges' in result
|
assert 'existing_edges' in result
|
||||||
assert 'new_edges' in result
|
assert 'new_edges' in result
|
||||||
|
assert 'current_episode' in result
|
||||||
|
assert 'previous_episodes' in result
|
||||||
assert len(result['existing_edges']) == 1
|
assert len(result['existing_edges']) == 1
|
||||||
assert len(result['new_edges']) == 1
|
assert len(result['new_edges']) == 1
|
||||||
|
assert result['current_episode'] == current_episode.content
|
||||||
|
assert len(result['previous_episodes']) == 2
|
||||||
|
|
||||||
# Check the format of the existing edge
|
# Check the format of the existing edge
|
||||||
existing_edge_str = result['existing_edges'][0]
|
existing_edge_str = result['existing_edges'][0]
|
||||||
|
|
@ -176,12 +209,25 @@ def test_prepare_invalidation_context():
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_invalidation_context_empty_input():
|
def test_prepare_invalidation_context_empty_input():
|
||||||
result = prepare_invalidation_context([], [])
|
now = datetime.now()
|
||||||
|
current_episode = EpisodicNode(
|
||||||
|
name='Current Episode',
|
||||||
|
content='Empty episode',
|
||||||
|
created_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
source='test',
|
||||||
|
source_description='Test empty episode for unit testing',
|
||||||
|
)
|
||||||
|
result = prepare_invalidation_context([], [], current_episode, [])
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert 'existing_edges' in result
|
assert 'existing_edges' in result
|
||||||
assert 'new_edges' in result
|
assert 'new_edges' in result
|
||||||
|
assert 'current_episode' in result
|
||||||
|
assert 'previous_episodes' in result
|
||||||
assert len(result['existing_edges']) == 0
|
assert len(result['existing_edges']) == 0
|
||||||
assert len(result['new_edges']) == 0
|
assert len(result['new_edges']) == 0
|
||||||
|
assert result['current_episode'] == current_episode.content
|
||||||
|
assert len(result['previous_episodes']) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_invalidation_context_sorting():
|
def test_prepare_invalidation_context_sorting():
|
||||||
|
|
@ -215,13 +261,36 @@ def test_prepare_invalidation_context_sorting():
|
||||||
# Prepare test input
|
# Prepare test input
|
||||||
existing_edges = [edge_with_nodes1, edge_with_nodes2]
|
existing_edges = [edge_with_nodes1, edge_with_nodes2]
|
||||||
|
|
||||||
|
# Create a current episode and previous episodes
|
||||||
|
current_episode = EpisodicNode(
|
||||||
|
name='Current Episode',
|
||||||
|
content='This is the current episode content.',
|
||||||
|
created_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
source='test',
|
||||||
|
source_description='Test episode for unit testing',
|
||||||
|
)
|
||||||
|
previous_episodes = [
|
||||||
|
EpisodicNode(
|
||||||
|
name='Previous Episode',
|
||||||
|
content='This is the content of a previous episode.',
|
||||||
|
created_at=now - timedelta(days=1),
|
||||||
|
valid_at=now - timedelta(days=1),
|
||||||
|
source='test',
|
||||||
|
source_description='Test previous episode for unit testing',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
# Call the function
|
# Call the function
|
||||||
result = prepare_invalidation_context(existing_edges, [])
|
result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert len(result['existing_edges']) == 2
|
assert len(result['existing_edges']) == 2
|
||||||
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
|
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
|
||||||
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
|
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
|
||||||
|
assert result['current_episode'] == current_episode.content
|
||||||
|
assert len(result['previous_episodes']) == 1
|
||||||
|
assert result['previous_episodes'][0] == previous_episodes[0].content
|
||||||
|
|
||||||
|
|
||||||
# Run the tests
|
# Run the tests
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
from core.edges import EntityEdge
|
from core.edges import EntityEdge
|
||||||
from core.llm_client import LLMConfig, OpenAIClient
|
from core.llm_client import LLMConfig, OpenAIClient
|
||||||
from core.nodes import EntityNode
|
from core.nodes import EntityNode, EpisodicNode
|
||||||
from core.utils.maintenance.temporal_operations import (
|
from core.utils.maintenance.temporal_operations import (
|
||||||
invalidate_edges,
|
invalidate_edges,
|
||||||
)
|
)
|
||||||
|
|
@ -24,7 +24,6 @@ def setup_llm_client():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Helper function to create test data
|
|
||||||
def create_test_data():
|
def create_test_data():
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
|
|
@ -53,15 +52,39 @@ def create_test_data():
|
||||||
existing_edge = (node1, edge1, node2)
|
existing_edge = (node1, edge1, node2)
|
||||||
new_edge = (node1, edge2, node2)
|
new_edge = (node1, edge2, node2)
|
||||||
|
|
||||||
return existing_edge, new_edge
|
# Create current episode
|
||||||
|
current_episode = EpisodicNode(
|
||||||
|
name='Current Episode',
|
||||||
|
content='Alice now dislikes Bob',
|
||||||
|
created_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
source='test',
|
||||||
|
source_description='Test episode for unit testing',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create previous episodes
|
||||||
|
previous_episodes = [
|
||||||
|
EpisodicNode(
|
||||||
|
name='Previous Episode',
|
||||||
|
content='Alice liked Bob',
|
||||||
|
created_at=now - timedelta(days=1),
|
||||||
|
valid_at=now - timedelta(days=1),
|
||||||
|
source='test',
|
||||||
|
source_description='Test previous episode for unit testing',
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return existing_edge, new_edge, current_episode, previous_episodes
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges():
|
async def test_invalidate_edges():
|
||||||
existing_edge, new_edge = create_test_data()
|
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [new_edge])
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
|
||||||
|
)
|
||||||
|
|
||||||
assert len(invalidated_edges) == 1
|
assert len(invalidated_edges) == 1
|
||||||
assert invalidated_edges[0].uuid == existing_edge[1].uuid
|
assert invalidated_edges[0].uuid == existing_edge[1].uuid
|
||||||
|
|
@ -71,9 +94,11 @@ async def test_invalidate_edges():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges_no_invalidation():
|
async def test_invalidate_edges_no_invalidation():
|
||||||
existing_edge, _ = create_test_data()
|
existing_edge, _, current_episode, previous_episodes = create_test_data()
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [])
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
|
||||||
|
)
|
||||||
|
|
||||||
assert len(invalidated_edges) == 0
|
assert len(invalidated_edges) == 0
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue