Fix temporal invalidation unit tests (#23)

* wip

* wip

* wip

* fix: Linter errors

* fix formatting

* chore: fix ruff

* fix: Duplication

* chore: Fix unit tests for temporal invalidation

* attempt to fix unit tests

* fix: format

---------

Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
Pavlo Paliychuk 2024-08-22 19:02:20 -04:00 committed by GitHub
parent 72dfa3c1e3
commit 8a55f48f5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 107 additions and 13 deletions

View file

@ -1,9 +1,9 @@
import asyncio import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import LiteralString
from neo4j import AsyncDriver from neo4j import AsyncDriver
from typing_extensions import LiteralString
from core.nodes import EpisodicNode from core.nodes import EpisodicNode

View file

@ -3,7 +3,7 @@ from datetime import datetime, timedelta
import pytest import pytest
from core.edges import EntityEdge from core.edges import EntityEdge
from core.nodes import EntityNode from core.nodes import EntityNode, EpisodicNode
from core.utils.maintenance.temporal_operations import ( from core.utils.maintenance.temporal_operations import (
prepare_edges_for_invalidation, prepare_edges_for_invalidation,
prepare_invalidation_context, prepare_invalidation_context,
@ -114,7 +114,6 @@ def test_prepare_edges_for_invalidation_missing_nodes():
def test_prepare_invalidation_context(): def test_prepare_invalidation_context():
# Create test data
now = datetime.now() now = datetime.now()
# Create nodes # Create nodes
@ -148,15 +147,49 @@ def test_prepare_invalidation_context():
existing_edges = [existing_edge] existing_edges = [existing_edge]
new_edges = [new_edge] new_edges = [new_edge]
# Create a current episode and previous episodes
current_episode = EpisodicNode(
name='Current Episode',
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
previous_episodes = [
EpisodicNode(
name='Previous Episode 1',
content='This is the content of previous episode 1.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode 1 for unit testing',
),
EpisodicNode(
name='Previous Episode 2',
content='This is the content of previous episode 2.',
created_at=now - timedelta(days=2),
valid_at=now - timedelta(days=2),
source='test',
source_description='Test previous episode 2 for unit testing',
),
]
# Call the function # Call the function
result = prepare_invalidation_context(existing_edges, new_edges) result = prepare_invalidation_context(
existing_edges, new_edges, current_episode, previous_episodes
)
# Assert the result # Assert the result
assert isinstance(result, dict) assert isinstance(result, dict)
assert 'existing_edges' in result assert 'existing_edges' in result
assert 'new_edges' in result assert 'new_edges' in result
assert 'current_episode' in result
assert 'previous_episodes' in result
assert len(result['existing_edges']) == 1 assert len(result['existing_edges']) == 1
assert len(result['new_edges']) == 1 assert len(result['new_edges']) == 1
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 2
# Check the format of the existing edge # Check the format of the existing edge
existing_edge_str = result['existing_edges'][0] existing_edge_str = result['existing_edges'][0]
@ -176,12 +209,25 @@ def test_prepare_invalidation_context():
def test_prepare_invalidation_context_empty_input(): def test_prepare_invalidation_context_empty_input():
result = prepare_invalidation_context([], []) now = datetime.now()
current_episode = EpisodicNode(
name='Current Episode',
content='Empty episode',
created_at=now,
valid_at=now,
source='test',
source_description='Test empty episode for unit testing',
)
result = prepare_invalidation_context([], [], current_episode, [])
assert isinstance(result, dict) assert isinstance(result, dict)
assert 'existing_edges' in result assert 'existing_edges' in result
assert 'new_edges' in result assert 'new_edges' in result
assert 'current_episode' in result
assert 'previous_episodes' in result
assert len(result['existing_edges']) == 0 assert len(result['existing_edges']) == 0
assert len(result['new_edges']) == 0 assert len(result['new_edges']) == 0
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 0
def test_prepare_invalidation_context_sorting(): def test_prepare_invalidation_context_sorting():
@ -215,13 +261,36 @@ def test_prepare_invalidation_context_sorting():
# Prepare test input # Prepare test input
existing_edges = [edge_with_nodes1, edge_with_nodes2] existing_edges = [edge_with_nodes1, edge_with_nodes2]
# Create a current episode and previous episodes
current_episode = EpisodicNode(
name='Current Episode',
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
previous_episodes = [
EpisodicNode(
name='Previous Episode',
content='This is the content of a previous episode.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode for unit testing',
),
]
# Call the function # Call the function
result = prepare_invalidation_context(existing_edges, []) result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)
# Assert the result # Assert the result
assert len(result['existing_edges']) == 2 assert len(result['existing_edges']) == 2
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 1
assert result['previous_episodes'][0] == previous_episodes[0].content
# Run the tests # Run the tests

View file

@ -6,7 +6,7 @@ from dotenv import load_dotenv
from core.edges import EntityEdge from core.edges import EntityEdge
from core.llm_client import LLMConfig, OpenAIClient from core.llm_client import LLMConfig, OpenAIClient
from core.nodes import EntityNode from core.nodes import EntityNode, EpisodicNode
from core.utils.maintenance.temporal_operations import ( from core.utils.maintenance.temporal_operations import (
invalidate_edges, invalidate_edges,
) )
@ -24,7 +24,6 @@ def setup_llm_client():
) )
# Helper function to create test data
def create_test_data(): def create_test_data():
now = datetime.now() now = datetime.now()
@ -53,15 +52,39 @@ def create_test_data():
existing_edge = (node1, edge1, node2) existing_edge = (node1, edge1, node2)
new_edge = (node1, edge2, node2) new_edge = (node1, edge2, node2)
return existing_edge, new_edge # Create current episode
current_episode = EpisodicNode(
name='Current Episode',
content='Alice now dislikes Bob',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
# Create previous episodes
previous_episodes = [
EpisodicNode(
name='Previous Episode',
content='Alice liked Bob',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode for unit testing',
)
]
return existing_edge, new_edge, current_episode, previous_episodes
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.integration @pytest.mark.integration
async def test_invalidate_edges(): async def test_invalidate_edges():
existing_edge, new_edge = create_test_data() existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [new_edge]) invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
)
assert len(invalidated_edges) == 1 assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge[1].uuid assert invalidated_edges[0].uuid == existing_edge[1].uuid
@ -71,9 +94,11 @@ async def test_invalidate_edges():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.integration @pytest.mark.integration
async def test_invalidate_edges_no_invalidation(): async def test_invalidate_edges_no_invalidation():
existing_edge, _ = create_test_data() existing_edge, _, current_episode, previous_episodes = create_test_data()
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], []) invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
)
assert len(invalidated_edges) == 0 assert len(invalidated_edges) == 0