Fix temporal invalidation unit tests (#23)

* wip

* wip

* wip

* fix: Linter errors

* fix formatting

* chore: fix ruff

* fix: Duplication

* chore: Fix unit tests for temporal invalidation

* attempt to fix unit tests

* fix: format

---------

Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
Pavlo Paliychuk 2024-08-22 19:02:20 -04:00 committed by GitHub
parent 72dfa3c1e3
commit 8a55f48f5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 107 additions and 13 deletions

View file

@ -1,9 +1,9 @@
import asyncio
import logging
from datetime import datetime, timezone
from typing import LiteralString
from neo4j import AsyncDriver
from typing_extensions import LiteralString
from core.nodes import EpisodicNode

View file

@ -3,7 +3,7 @@ from datetime import datetime, timedelta
import pytest
from core.edges import EntityEdge
from core.nodes import EntityNode
from core.nodes import EntityNode, EpisodicNode
from core.utils.maintenance.temporal_operations import (
prepare_edges_for_invalidation,
prepare_invalidation_context,
@ -114,7 +114,6 @@ def test_prepare_edges_for_invalidation_missing_nodes():
def test_prepare_invalidation_context():
# Create test data
now = datetime.now()
# Create nodes
@ -148,15 +147,49 @@ def test_prepare_invalidation_context():
existing_edges = [existing_edge]
new_edges = [new_edge]
# Create a current episode and previous episodes
current_episode = EpisodicNode(
name='Current Episode',
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
previous_episodes = [
EpisodicNode(
name='Previous Episode 1',
content='This is the content of previous episode 1.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode 1 for unit testing',
),
EpisodicNode(
name='Previous Episode 2',
content='This is the content of previous episode 2.',
created_at=now - timedelta(days=2),
valid_at=now - timedelta(days=2),
source='test',
source_description='Test previous episode 2 for unit testing',
),
]
# Call the function
result = prepare_invalidation_context(existing_edges, new_edges)
result = prepare_invalidation_context(
existing_edges, new_edges, current_episode, previous_episodes
)
# Assert the result
assert isinstance(result, dict)
assert 'existing_edges' in result
assert 'new_edges' in result
assert 'current_episode' in result
assert 'previous_episodes' in result
assert len(result['existing_edges']) == 1
assert len(result['new_edges']) == 1
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 2
# Check the format of the existing edge
existing_edge_str = result['existing_edges'][0]
@ -176,12 +209,25 @@ def test_prepare_invalidation_context():
def test_prepare_invalidation_context_empty_input():
result = prepare_invalidation_context([], [])
now = datetime.now()
current_episode = EpisodicNode(
name='Current Episode',
content='Empty episode',
created_at=now,
valid_at=now,
source='test',
source_description='Test empty episode for unit testing',
)
result = prepare_invalidation_context([], [], current_episode, [])
assert isinstance(result, dict)
assert 'existing_edges' in result
assert 'new_edges' in result
assert 'current_episode' in result
assert 'previous_episodes' in result
assert len(result['existing_edges']) == 0
assert len(result['new_edges']) == 0
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 0
def test_prepare_invalidation_context_sorting():
@ -215,13 +261,36 @@ def test_prepare_invalidation_context_sorting():
# Prepare test input
existing_edges = [edge_with_nodes1, edge_with_nodes2]
# Create a current episode and previous episodes
current_episode = EpisodicNode(
name='Current Episode',
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
previous_episodes = [
EpisodicNode(
name='Previous Episode',
content='This is the content of a previous episode.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode for unit testing',
),
]
# Call the function
result = prepare_invalidation_context(existing_edges, [])
result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)
# Assert the result
assert len(result['existing_edges']) == 2
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 1
assert result['previous_episodes'][0] == previous_episodes[0].content
# Run the tests

View file

@ -6,7 +6,7 @@ from dotenv import load_dotenv
from core.edges import EntityEdge
from core.llm_client import LLMConfig, OpenAIClient
from core.nodes import EntityNode
from core.nodes import EntityNode, EpisodicNode
from core.utils.maintenance.temporal_operations import (
invalidate_edges,
)
@ -24,7 +24,6 @@ def setup_llm_client():
)
# Helper function to create test data
def create_test_data():
now = datetime.now()
@ -53,15 +52,39 @@ def create_test_data():
existing_edge = (node1, edge1, node2)
new_edge = (node1, edge2, node2)
return existing_edge, new_edge
# Create current episode
current_episode = EpisodicNode(
name='Current Episode',
content='Alice now dislikes Bob',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
# Create previous episodes
previous_episodes = [
EpisodicNode(
name='Previous Episode',
content='Alice liked Bob',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode for unit testing',
)
]
return existing_edge, new_edge, current_episode, previous_episodes
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges():
existing_edge, new_edge = create_test_data()
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [new_edge])
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge[1].uuid
@ -71,9 +94,11 @@ async def test_invalidate_edges():
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_invalidation():
existing_edge, _ = create_test_data()
existing_edge, _, current_episode, previous_episodes = create_test_data()
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [])
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
)
assert len(invalidated_edges) == 0