From e15c872900e652a04122126e36784c56711848ab Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Mon, 7 Oct 2024 11:45:31 -0400 Subject: [PATCH] Fix edge invalidation (#174) * update edge operations * add new tests --- graphiti_core/prompts/extract_edge_dates.py | 17 +- graphiti_core/prompts/invalidate_edges.py | 2 +- graphiti_core/utils/maintenance/__init__.py | 2 - .../utils/maintenance/edge_operations.py | 10 +- .../utils/maintenance/temporal_operations.py | 128 +----- .../utils/maintenance/test_edge_operations.py | 242 ++++++++++++ .../maintenance/test_temporal_operations.py | 368 ------------------ .../test_temporal_operations_int.py | 264 ++++++------- 8 files changed, 380 insertions(+), 653 deletions(-) create mode 100644 tests/utils/maintenance/test_edge_operations.py delete mode 100644 tests/utils/maintenance/test_temporal_operations.py diff --git a/graphiti_core/prompts/extract_edge_dates.py b/graphiti_core/prompts/extract_edge_dates.py index 4d6ab851..167911e5 100644 --- a/graphiti_core/prompts/extract_edge_dates.py +++ b/graphiti_core/prompts/extract_edge_dates.py @@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]: role='user', content=f""" Edge: - Edge Name: {context['edge_name']} Fact: {context['edge_fact']} Current Episode: {context['current_episode']} @@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]: Guidelines: 1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes. 2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates. - 3. If no temporal information is found that establishes or changes the relationship, leave the fields as null. - 4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship. - 5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp. - 6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date. - 7. If only a year is mentioned, use January 1st of that year at 00:00:00. + 3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date + 4. If no temporal information is found that establishes or changes the relationship, leave the fields as null. + 5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship. + 6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp. + 7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date. + 8. If only a year is mentioned, use January 1st of that year at 00:00:00. 9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned). Respond with a JSON object: {{ - "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null", - "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null", - "explanation": "Brief explanation of why these dates were chosen or why they were set to null" + "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null", + "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null", }} """, ), diff --git a/graphiti_core/prompts/invalidate_edges.py b/graphiti_core/prompts/invalidate_edges.py index 5724390b..d2246f89 100644 --- a/graphiti_core/prompts/invalidate_edges.py +++ b/graphiti_core/prompts/invalidate_edges.py @@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]: Message( role='user', content=f""" - Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge. + Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge. Existing Edges: {context['existing_edges']} diff --git a/graphiti_core/utils/maintenance/__init__.py b/graphiti_core/utils/maintenance/__init__.py index 553a203c..65fca2c2 100644 --- a/graphiti_core/utils/maintenance/__init__.py +++ b/graphiti_core/utils/maintenance/__init__.py @@ -4,7 +4,6 @@ from .graph_data_operations import ( retrieve_episodes, ) from .node_operations import extract_nodes -from .temporal_operations import invalidate_edges __all__ = [ 'extract_edges', @@ -12,5 +11,4 @@ __all__ = [ 'extract_nodes', 'clear_data', 'retrieve_episodes', - 'invalidate_edges', ] diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index e2953cb1..3ffe38e4 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -122,12 +122,6 @@ async def extract_edges( return edges -def create_edge_identifier( - source_node: EntityNode, edge: EntityEdge, target_node: EntityNode -) -> str: - return f'{source_node.name}-{edge.name}-{target_node.name}' - - async def dedupe_extracted_edges( llm_client: LLMClient, extracted_edges: list[EntityEdge], @@ -251,11 +245,11 @@ async def resolve_extracted_edge( if ( edge.invalid_at is not None and resolved_edge.valid_at is not None - and edge.invalid_at < resolved_edge.valid_at + and edge.invalid_at <= resolved_edge.valid_at ) or ( edge.valid_at is not None and resolved_edge.invalid_at is not None - and resolved_edge.invalid_at < edge.valid_at + and resolved_edge.invalid_at <= edge.valid_at ): continue # New edge invalidates edge diff --git a/graphiti_core/utils/maintenance/temporal_operations.py b/graphiti_core/utils/maintenance/temporal_operations.py index 6eaa22b4..fc84b191 100644 --- a/graphiti_core/utils/maintenance/temporal_operations.py +++ b/graphiti_core/utils/maintenance/temporal_operations.py @@ -21,129 +21,11 @@ from typing import List from graphiti_core.edges import EntityEdge from graphiti_core.llm_client import LLMClient -from graphiti_core.nodes import EntityNode, EpisodicNode +from graphiti_core.nodes import EpisodicNode from graphiti_core.prompts import prompt_library logger = logging.getLogger(__name__) -NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode] - - -def extract_node_and_edge_triplets( - edges: list[EntityEdge], nodes: list[EntityNode] -) -> list[NodeEdgeNodeTriplet]: - return [extract_node_edge_node_triplet(edge, nodes) for edge in edges] - - -def extract_node_edge_node_triplet( - edge: EntityEdge, nodes: list[EntityNode] -) -> NodeEdgeNodeTriplet: - source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) - target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) - if not source_node or not target_node: - raise ValueError(f'Source or target node not found for edge {edge.uuid}') - return (source_node, edge, target_node) - - -def prepare_edges_for_invalidation( - existing_edges: list[EntityEdge], - new_edges: list[EntityEdge], - nodes: list[EntityNode], -) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]: - existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = [] - new_edges_with_nodes: list[NodeEdgeNodeTriplet] = [] - - for edge_list, result_list in [ - (existing_edges, existing_edges_pending_invalidation), - (new_edges, new_edges_with_nodes), - ]: - for edge in edge_list: - source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) - target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) - - if source_node and target_node: - result_list.append((source_node, edge, target_node)) - - return existing_edges_pending_invalidation, new_edges_with_nodes - - -async def invalidate_edges( - llm_client: LLMClient, - existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], - new_edges: list[NodeEdgeNodeTriplet], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], -) -> list[EntityEdge]: - invalidated_edges = [] # TODO: this is not yet used? - - context = prepare_invalidation_context( - existing_edges_pending_invalidation, - new_edges, - current_episode, - previous_episodes, - ) - llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context)) - - edges_to_invalidate = llm_response.get('invalidated_edges', []) - invalidated_edges = process_edge_invalidation_llm_response( - edges_to_invalidate, existing_edges_pending_invalidation - ) - - return invalidated_edges - - -def extract_date_strings_from_edge(edge: EntityEdge) -> str: - start = edge.valid_at - end = edge.invalid_at - date_string = f'Start Date: {start.isoformat()}' if start else '' - if end: - date_string += f' (End Date: {end.isoformat()})' - return date_string - - -def prepare_invalidation_context( - existing_edges: list[NodeEdgeNodeTriplet], - new_edges: list[NodeEdgeNodeTriplet], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], -) -> dict: - return { - 'existing_edges': [ - f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}' - for source_node, edge, target_node in sorted( - existing_edges, key=lambda x: (x[1].created_at), reverse=True - ) - ], - 'new_edges': [ - f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}' - for source_node, edge, target_node in sorted( - new_edges, key=lambda x: (x[1].created_at), reverse=True - ) - ], - 'current_episode': current_episode.content, - 'previous_episodes': [episode.content for episode in previous_episodes], - } - - -def process_edge_invalidation_llm_response( - edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet] -) -> List[EntityEdge]: - invalidated_edges = [] - for edge_to_invalidate in edges_to_invalidate: - edge_uuid = edge_to_invalidate['edge_uuid'] - edge_to_update = next( - (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid), - None, - ) - if edge_to_update: - edge_to_update.expired_at = datetime.now() - edge_to_update.fact = edge_to_invalidate['fact'] - invalidated_edges.append(edge_to_update) - logger.info( - f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}" - ) - return invalidated_edges - async def extract_edge_dates( llm_client: LLMClient, @@ -152,7 +34,6 @@ async def extract_edge_dates( previous_episodes: List[EpisodicNode], ) -> tuple[datetime | None, datetime | None]: context = { - 'edge_name': edge.name, 'edge_fact': edge.fact, 'current_episode': current_episode.content, 'previous_episodes': [ep.content for ep in previous_episodes], @@ -162,25 +43,22 @@ async def extract_edge_dates( valid_at = llm_response.get('valid_at') invalid_at = llm_response.get('invalid_at') - explanation = llm_response.get('explanation', '') valid_at_datetime = None invalid_at_datetime = None - if valid_at and valid_at != '': + if valid_at: try: valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00')) except ValueError as e: logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}') - if invalid_at and invalid_at != '': + if invalid_at: try: invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00')) except ValueError as e: logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}') - logger.info(f'Edge date extraction explanation: {explanation}') - return valid_at_datetime, invalid_at_datetime diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/utils/maintenance/test_edge_operations.py new file mode 100644 index 00000000..67fcfc5e --- /dev/null +++ b/tests/utils/maintenance/test_edge_operations.py @@ -0,0 +1,242 @@ +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pytest import MonkeyPatch + +from graphiti_core.edges import EntityEdge +from graphiti_core.nodes import EpisodicNode +from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge + + +@pytest.fixture +def mock_llm_client(): + return MagicMock() + + +@pytest.fixture +def mock_extracted_edge(): + return EntityEdge( + source_node_uuid='source_uuid', + target_node_uuid='target_uuid', + name='test_edge', + group_id='group_1', + fact='Test fact', + episodes=['episode_1'], + created_at=datetime.now(), + valid_at=None, + invalid_at=None, + ) + + +@pytest.fixture +def mock_related_edges(): + return [ + EntityEdge( + source_node_uuid='source_uuid_2', + target_node_uuid='target_uuid_2', + name='related_edge', + group_id='group_1', + fact='Related fact', + episodes=['episode_2'], + created_at=datetime.now() - timedelta(days=1), + valid_at=datetime.now() - timedelta(days=1), + invalid_at=None, + ) + ] + + +@pytest.fixture +def mock_existing_edges(): + return [ + EntityEdge( + source_node_uuid='source_uuid_3', + target_node_uuid='target_uuid_3', + name='existing_edge', + group_id='group_1', + fact='Existing fact', + episodes=['episode_3'], + created_at=datetime.now() - timedelta(days=2), + valid_at=datetime.now() - timedelta(days=2), + invalid_at=None, + ) + ] + + +@pytest.fixture +def mock_current_episode(): + return EpisodicNode( + uuid='episode_1', + content='Current episode content', + valid_at=datetime.now(), + name='Current Episode', + group_id='group_1', + source='message', + source_description='Test source description', + ) + + +@pytest.fixture +def mock_previous_episodes(): + return [ + EpisodicNode( + uuid='episode_2', + content='Previous episode content', + valid_at=datetime.now() - timedelta(days=1), + name='Previous Episode', + group_id='group_1', + source='message', + source_description='Test source description', + ) + ] + + +@pytest.mark.asyncio +async def test_resolve_extracted_edge_no_changes( + mock_llm_client, + mock_extracted_edge, + mock_related_edges, + mock_existing_edges, + mock_current_episode, + mock_previous_episodes, + monkeypatch: MonkeyPatch, +): + # Mock the function calls + dedupe_mock = AsyncMock(return_value=mock_extracted_edge) + extract_dates_mock = AsyncMock(return_value=(None, None)) + get_contradictions_mock = AsyncMock(return_value=[]) + + # Patch the function calls + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', + get_contradictions_mock, + ) + + resolved_edge, invalidated_edges = await resolve_extracted_edge( + mock_llm_client, + mock_extracted_edge, + mock_related_edges, + mock_existing_edges, + mock_current_episode, + mock_previous_episodes, + ) + + assert resolved_edge.uuid == mock_extracted_edge.uuid + assert invalidated_edges == [] + dedupe_mock.assert_called_once() + extract_dates_mock.assert_called_once() + get_contradictions_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_resolve_extracted_edge_with_dates( + mock_llm_client, + mock_extracted_edge, + mock_related_edges, + mock_existing_edges, + mock_current_episode, + mock_previous_episodes, + monkeypatch: MonkeyPatch, +): + valid_at = datetime.now() - timedelta(days=1) + invalid_at = datetime.now() + timedelta(days=1) + + # Mock the function calls + dedupe_mock = AsyncMock(return_value=mock_extracted_edge) + extract_dates_mock = AsyncMock(return_value=(valid_at, invalid_at)) + get_contradictions_mock = AsyncMock(return_value=[]) + + # Patch the function calls + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', + get_contradictions_mock, + ) + + resolved_edge, invalidated_edges = await resolve_extracted_edge( + mock_llm_client, + mock_extracted_edge, + mock_related_edges, + mock_existing_edges, + mock_current_episode, + mock_previous_episodes, + ) + + assert resolved_edge.valid_at == valid_at + assert resolved_edge.invalid_at == invalid_at + assert resolved_edge.expired_at is not None + assert invalidated_edges == [] + + +@pytest.mark.asyncio +async def test_resolve_extracted_edge_with_invalidation( + mock_llm_client, + mock_extracted_edge, + mock_related_edges, + mock_existing_edges, + mock_current_episode, + mock_previous_episodes, + monkeypatch: MonkeyPatch, +): + valid_at = datetime.now() - timedelta(days=1) + mock_extracted_edge.valid_at = valid_at + + invalidation_candidate = EntityEdge( + source_node_uuid='source_uuid_4', + target_node_uuid='target_uuid_4', + name='invalidation_candidate', + group_id='group_1', + fact='Invalidation candidate fact', + episodes=['episode_4'], + created_at=datetime.now(), + valid_at=datetime.now() - timedelta(days=2), + invalid_at=None, + ) + + # Mock the function calls + dedupe_mock = AsyncMock(return_value=mock_extracted_edge) + extract_dates_mock = AsyncMock(return_value=(None, None)) + get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate]) + + # Patch the function calls + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', + get_contradictions_mock, + ) + + resolved_edge, invalidated_edges = await resolve_extracted_edge( + mock_llm_client, + mock_extracted_edge, + mock_related_edges, + mock_existing_edges, + mock_current_episode, + mock_previous_episodes, + ) + + assert resolved_edge.uuid == mock_extracted_edge.uuid + assert len(invalidated_edges) == 1 + assert invalidated_edges[0].uuid == invalidation_candidate.uuid + assert invalidated_edges[0].invalid_at == valid_at + assert invalidated_edges[0].expired_at is not None + + +# Run the tests +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py deleted file mode 100644 index 0e86bd89..00000000 --- a/tests/utils/maintenance/test_temporal_operations.py +++ /dev/null @@ -1,368 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import unittest -from datetime import datetime, timedelta - -import pytest - -from graphiti_core.edges import EntityEdge -from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode -from graphiti_core.utils.maintenance.temporal_operations import ( - extract_date_strings_from_edge, - prepare_edges_for_invalidation, - prepare_invalidation_context, -) - - -# Helper function to create test data -def create_test_data(): - now = datetime.now() - - # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1') - - # Create edges - existing_edge1 = EntityEdge( - uuid='e1', - source_node_uuid='1', - target_node_uuid='2', - name='KNOWS', - fact='Node1 knows Node2', - created_at=now, - group_id='1', - ) - existing_edge2 = EntityEdge( - uuid='e2', - source_node_uuid='2', - target_node_uuid='3', - name='LIKES', - fact='Node2 likes Node3', - created_at=now, - group_id='1', - ) - new_edge1 = EntityEdge( - uuid='e3', - source_node_uuid='1', - target_node_uuid='3', - name='WORKS_WITH', - fact='Node1 works with Node3', - created_at=now, - group_id='1', - ) - new_edge2 = EntityEdge( - uuid='e4', - source_node_uuid='1', - target_node_uuid='2', - name='DISLIKES', - fact='Node1 dislikes Node2', - created_at=now, - group_id='1', - ) - - return { - 'nodes': [node1, node2, node3], - 'existing_edges': [existing_edge1, existing_edge2], - 'new_edges': [new_edge1, new_edge2], - } - - -def test_prepare_edges_for_invalidation_basic(): - test_data = create_test_data() - - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - test_data['existing_edges'], test_data['new_edges'], test_data['nodes'] - ) - - assert len(existing_edges_pending_invalidation) == 2 - assert len(new_edges_with_nodes) == 2 - - # Check if the edges are correctly associated with nodes - for edge_with_nodes in existing_edges_pending_invalidation + new_edges_with_nodes: - assert isinstance(edge_with_nodes[0], EntityNode) - assert isinstance(edge_with_nodes[1], EntityEdge) - assert isinstance(edge_with_nodes[2], EntityNode) - - -def test_prepare_edges_for_invalidation_no_existing_edges(): - test_data = create_test_data() - - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - [], test_data['new_edges'], test_data['nodes'] - ) - - assert len(existing_edges_pending_invalidation) == 0 - assert len(new_edges_with_nodes) == 2 - - -def test_prepare_edges_for_invalidation_no_new_edges(): - test_data = create_test_data() - - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - test_data['existing_edges'], [], test_data['nodes'] - ) - - assert len(existing_edges_pending_invalidation) == 2 - assert len(new_edges_with_nodes) == 0 - - -def test_prepare_edges_for_invalidation_missing_nodes(): - test_data = create_test_data() - - # Remove one node to simulate a missing node scenario - nodes = test_data['nodes'][:-1] - - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - test_data['existing_edges'], test_data['new_edges'], nodes - ) - - assert len(existing_edges_pending_invalidation) == 1 - assert len(new_edges_with_nodes) == 1 - - -def test_prepare_invalidation_context(): - now = datetime.now() - - # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1') - - # Create edges - edge1 = EntityEdge( - uuid='e1', - source_node_uuid='1', - target_node_uuid='2', - name='KNOWS', - fact='Node1 knows Node2', - created_at=now, - group_id='1', - ) - edge2 = EntityEdge( - uuid='e2', - source_node_uuid='2', - target_node_uuid='3', - name='LIKES', - fact='Node2 likes Node3', - created_at=now, - group_id='1', - ) - - # Create NodeEdgeNodeTriplet objects - existing_edge = (node1, edge1, node2) - new_edge = (node2, edge2, node3) - - # Prepare test input - existing_edges = [existing_edge] - new_edges = [new_edge] - - # Create a current episode and previous episodes - current_episode = EpisodicNode( - name='Current Episode', - content='This is the current episode content.', - created_at=now, - valid_at=now, - source=EpisodeType.message, - source_description='Test episode for unit testing', - group_id='1', - ) - previous_episodes = [ - EpisodicNode( - name='Previous Episode 1', - content='This is the content of previous episode 1.', - created_at=now - timedelta(days=1), - valid_at=now - timedelta(days=1), - source=EpisodeType.message, - source_description='Test previous episode 1 for unit testing', - group_id='1', - ), - EpisodicNode( - name='Previous Episode 2', - content='This is the content of previous episode 2.', - created_at=now - timedelta(days=2), - valid_at=now - timedelta(days=2), - source=EpisodeType.message, - source_description='Test previous episode 2 for unit testing', - group_id='1', - ), - ] - - # Call the function - result = prepare_invalidation_context( - existing_edges, new_edges, current_episode, previous_episodes - ) - - # Assert the result - assert isinstance(result, dict) - assert 'existing_edges' in result - assert 'new_edges' in result - assert 'current_episode' in result - assert 'previous_episodes' in result - assert len(result['existing_edges']) == 1 - assert len(result['new_edges']) == 1 - assert result['current_episode'] == current_episode.content - assert len(result['previous_episodes']) == 2 - - # Check the format of the existing edge - existing_edge_str = result['existing_edges'][0] - assert edge1.uuid in existing_edge_str - assert node1.name in existing_edge_str - assert edge1.name in existing_edge_str - assert node2.name in existing_edge_str - assert edge1.fact in existing_edge_str - - # Check the format of the new edge - new_edge_str = result['new_edges'][0] - assert edge2.uuid in new_edge_str - assert node2.name in new_edge_str - assert edge2.name in new_edge_str - assert node3.name in new_edge_str - assert edge2.fact in new_edge_str - - -def test_prepare_invalidation_context_empty_input(): - now = datetime.now() - current_episode = EpisodicNode( - name='Current Episode', - content='Empty episode', - created_at=now, - valid_at=now, - source=EpisodeType.message, - source_description='Test empty episode for unit testing', - group_id='1', - ) - result = prepare_invalidation_context([], [], current_episode, []) - assert isinstance(result, dict) - assert 'existing_edges' in result - assert 'new_edges' in result - assert 'current_episode' in result - assert 'previous_episodes' in result - assert len(result['existing_edges']) == 0 - assert len(result['new_edges']) == 0 - assert result['current_episode'] == current_episode.content - assert len(result['previous_episodes']) == 0 - - -def test_prepare_invalidation_context_sorting(): - now = datetime.now() - - # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') - - # Create edges with different timestamps - edge1 = EntityEdge( - uuid='e1', - source_node_uuid='1', - target_node_uuid='2', - name='KNOWS', - fact='Node1 knows Node2', - created_at=now, - group_id='1', - ) - edge2 = EntityEdge( - uuid='e2', - source_node_uuid='2', - target_node_uuid='1', - name='LIKES', - fact='Node2 likes Node1', - created_at=now + timedelta(hours=1), - group_id='1', - ) - - edge_with_nodes1 = (node1, edge1, node2) - edge_with_nodes2 = (node2, edge2, node1) - - # Prepare test input - existing_edges = [edge_with_nodes1, edge_with_nodes2] - - # Create a current episode and previous episodes - current_episode = EpisodicNode( - name='Current Episode', - content='This is the current episode content.', - created_at=now, - valid_at=now, - source=EpisodeType.message, - source_description='Test episode for unit testing', - group_id='1', - ) - previous_episodes = [ - EpisodicNode( - name='Previous Episode', - content='This is the content of a previous episode.', - created_at=now - timedelta(days=1), - valid_at=now - timedelta(days=1), - source=EpisodeType.message, - source_description='Test previous episode for unit testing', - group_id='1', - ), - ] - - # Call the function - result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes) - - # Assert the result - assert len(result['existing_edges']) == 2 - assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first - assert edge1.uuid in result['existing_edges'][1] # The older edge should be second - assert result['current_episode'] == current_episode.content - assert len(result['previous_episodes']) == 1 - assert result['previous_episodes'][0] == previous_episodes[0].content - - -class TestExtractDateStringsFromEdge(unittest.TestCase): - def generate_entity_edge(self, valid_at, invalid_at): - return EntityEdge( - source_node_uuid='1', - target_node_uuid='2', - name='KNOWS', - fact='Node1 knows Node2', - created_at=datetime.now(), - valid_at=valid_at, - invalid_at=invalid_at, - group_id='1', - ) - - def test_both_dates_present(self): - edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), datetime(2024, 1, 2, 12, 0)) - result = extract_date_strings_from_edge(edge) - expected = 'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)' - self.assertEqual(result, expected) - - def test_only_valid_at_present(self): - edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), None) - result = extract_date_strings_from_edge(edge) - expected = 'Start Date: 2024-01-01T12:00:00' - self.assertEqual(result, expected) - - def test_only_invalid_at_present(self): - edge = self.generate_entity_edge(None, datetime(2024, 1, 2, 12, 0)) - result = extract_date_strings_from_edge(edge) - expected = ' (End Date: 2024-01-02T12:00:00)' - self.assertEqual(result, expected) - - def test_no_dates_present(self): - edge = self.generate_entity_edge(None, None) - result = extract_date_strings_from_edge(edge) - expected = '' - self.assertEqual(result, expected) - - -# Run the tests -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index b08689fb..98ae5d0c 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -19,12 +19,14 @@ from datetime import datetime, timedelta import pytest from dotenv import load_dotenv +from pytz import UTC from graphiti_core.edges import EntityEdge from graphiti_core.llm_client import LLMConfig, OpenAIClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.utils.maintenance.temporal_operations import ( - invalidate_edges, + extract_edge_dates, + get_edge_contradictions, ) load_dotenv() @@ -43,31 +45,26 @@ def setup_llm_client(): def create_test_data(): now = datetime.now() - # Create nodes - node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now) - # Create edges - edge1 = EntityEdge( + existing_edge = EntityEdge( uuid='e1', source_node_uuid='1', target_node_uuid='2', name='LIKES', fact='Alice likes Bob', created_at=now - timedelta(days=1), + group_id='1', ) - edge2 = EntityEdge( + new_edge = EntityEdge( uuid='e2', source_node_uuid='1', target_node_uuid='2', name='DISLIKES', fact='Alice dislikes Bob', created_at=now, + group_id='1', ) - existing_edge = (node1, edge1, node2) - new_edge = (node1, edge2, node2) - # Create current episode current_episode = EpisodicNode( name='Current Episode', @@ -97,46 +94,40 @@ def create_test_data(): @pytest.mark.asyncio @pytest.mark.integration -async def test_invalidate_edges(): +async def test_get_edge_contradictions(): existing_edge, new_edge, current_episode, previous_episodes = create_test_data() - invalidated_edges = await invalidate_edges( - setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes - ) + invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [existing_edge]) assert len(invalidated_edges) == 1 - assert invalidated_edges[0].uuid == existing_edge[1].uuid - assert invalidated_edges[0].expired_at is not None + assert invalidated_edges[0].uuid == existing_edge.uuid @pytest.mark.asyncio @pytest.mark.integration -async def test_invalidate_edges_no_invalidation(): - existing_edge, _, current_episode, previous_episodes = create_test_data() +async def test_get_edge_contradictions_no_contradictions(): + _, new_edge, current_episode, previous_episodes = create_test_data() - invalidated_edges = await invalidate_edges( - setup_llm_client(), [existing_edge], [], current_episode, previous_episodes - ) + invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, []) assert len(invalidated_edges) == 0 @pytest.mark.asyncio @pytest.mark.integration -async def test_invalidate_edges_multiple_existing(): - existing_edge1, new_edge = create_test_data() - existing_edge2, _ = create_test_data() - existing_edge2[1].uuid = 'e3' - existing_edge2[1].name = 'KNOWS' - existing_edge2[1].fact = 'Alice knows Bob' +async def test_get_edge_contradictions_multiple_existing(): + existing_edge1, new_edge, _, _ = create_test_data() + existing_edge2, _, _, _ = create_test_data() + existing_edge2.uuid = 'e3' + existing_edge2.name = 'KNOWS' + existing_edge2.fact = 'Alice knows Bob' - invalidated_edges = await invalidate_edges( - setup_llm_client(), [existing_edge1, existing_edge2], [new_edge] + invalidated_edges = await get_edge_contradictions( + setup_llm_client(), new_edge, [existing_edge1, existing_edge2] ) assert len(invalidated_edges) == 1 - assert invalidated_edges[0].uuid == existing_edge1[1].uuid - assert invalidated_edges[0].expired_at is not None + assert invalidated_edges[0].uuid == existing_edge1.uuid # Helper function to create more complex test data @@ -152,7 +143,7 @@ def create_complex_test_data(): ) # Create edges - edge1 = EntityEdge( + existing_edge1 = EntityEdge( uuid='e1', source_node_uuid='1', target_node_uuid='2', @@ -161,7 +152,7 @@ def create_complex_test_data(): group_id='1', created_at=now - timedelta(days=5), ) - edge2 = EntityEdge( + existing_edge2 = EntityEdge( uuid='e2', source_node_uuid='1', target_node_uuid='3', @@ -170,7 +161,7 @@ def create_complex_test_data(): group_id='1', created_at=now - timedelta(days=3), ) - edge3 = EntityEdge( + existing_edge3 = EntityEdge( uuid='e3', source_node_uuid='2', target_node_uuid='4', @@ -180,10 +171,6 @@ def create_complex_test_data(): created_at=now - timedelta(days=2), ) - existing_edge1 = (node1, edge1, node2) - existing_edge2 = (node1, edge2, node3) - existing_edge3 = (node2, edge3, node4) - return [existing_edge1, existing_edge2, existing_edge3], [ node1, node2, @@ -198,118 +185,61 @@ async def test_invalidate_edges_complex(): existing_edges, nodes = create_complex_test_data() # Create a new edge that contradicts an existing one - new_edge = ( - nodes[0], - EntityEdge( - uuid='e4', - source_node_uuid='1', - target_node_uuid='2', - name='DISLIKES', - fact='Alice dislikes Bob', - group_id='1', - created_at=datetime.now(), - ), - nodes[1], + new_edge = EntityEdge( + uuid='e4', + source_node_uuid='1', + target_node_uuid='2', + name='DISLIKES', + fact='Alice dislikes Bob', + group_id='1', + created_at=datetime.now(), ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) assert len(invalidated_edges) == 1 assert invalidated_edges[0].uuid == 'e1' - assert invalidated_edges[0].expired_at is not None @pytest.mark.asyncio @pytest.mark.integration -async def test_invalidate_edges_temporal_update(): +async def test_get_edge_contradictions_temporal_update(): existing_edges, nodes = create_complex_test_data() # Create a new edge that updates an existing one with new information - new_edge = ( - nodes[1], - EntityEdge( - uuid='e5', - source_node_uuid='2', - target_node_uuid='4', - name='LEFT_JOB', - fact='Bob left his job at Company XYZ', - group_id='1', - created_at=datetime.now(), - ), - nodes[3], + new_edge = EntityEdge( + uuid='e5', + source_node_uuid='2', + target_node_uuid='4', + name='LEFT_JOB', + fact='Bob no longer works at at Company XYZ', + group_id='1', + created_at=datetime.now(), ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) assert len(invalidated_edges) == 1 assert invalidated_edges[0].uuid == 'e3' - assert invalidated_edges[0].expired_at is not None @pytest.mark.asyncio @pytest.mark.integration -async def test_invalidate_edges_multiple_invalidations(): - existing_edges, nodes = create_complex_test_data() - - # Create new edges that invalidate multiple existing edges - new_edge1 = ( - nodes[0], - EntityEdge( - uuid='e6', - source_node_uuid='1', - target_node_uuid='2', - name='ENEMIES_WITH', - fact='Alice and Bob are now enemies', - group_id='1', - created_at=datetime.now(), - ), - nodes[1], - ) - new_edge2 = ( - nodes[0], - EntityEdge( - uuid='e7', - source_node_uuid='1', - target_node_uuid='3', - name='ENDED_FRIENDSHIP', - fact='Alice ended her friendship with Charlie', - group_id='1', - created_at=datetime.now(), - ), - nodes[2], - ) - - invalidated_edges = await invalidate_edges( - setup_llm_client(), existing_edges, [new_edge1, new_edge2] - ) - - assert len(invalidated_edges) == 2 - assert set(edge.uuid for edge in invalidated_edges) == {'e1', 'e2'} - for edge in invalidated_edges: - assert edge.expired_at is not None - - -@pytest.mark.asyncio -@pytest.mark.integration -async def test_invalidate_edges_no_effect(): +async def test_get_edge_contradictions_no_effect(): existing_edges, nodes = create_complex_test_data() # Create a new edge that doesn't invalidate any existing edges - new_edge = ( - nodes[2], - EntityEdge( - uuid='e8', - source_node_uuid='3', - target_node_uuid='4', - name='APPLIED_TO', - fact='Charlie applied to Company XYZ', - group_id='1', - created_at=datetime.now(), - ), - nodes[3], + new_edge = EntityEdge( + uuid='e8', + source_node_uuid='3', + target_node_uuid='4', + name='APPLIED_TO', + fact='Charlie applied to Company XYZ', + group_id='1', + created_at=datetime.now(), ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) assert len(invalidated_edges) == 0 @@ -320,28 +250,82 @@ async def test_invalidate_edges_partial_update(): existing_edges, nodes = create_complex_test_data() # Create a new edge that partially updates an existing one - new_edge = ( - nodes[1], - EntityEdge( - uuid='e9', - source_node_uuid='2', - target_node_uuid='4', - name='CHANGED_POSITION', - fact='Bob changed his position at Company XYZ', - group_id='1', - created_at=datetime.now(), - ), - nodes[3], + new_edge = EntityEdge( + uuid='e9', + source_node_uuid='2', + target_node_uuid='4', + name='CHANGED_POSITION', + fact='Bob changed his position at Company XYZ', + group_id='1', + created_at=datetime.now(), ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated +def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]: + now = datetime.now(UTC) + + previous_episodes = [ + EpisodicNode( + name='Previous Episode 1', + content='Bob: I work at XYZ company', + created_at=now - timedelta(days=2), + valid_at=now - timedelta(days=2), + source=EpisodeType.message, + source_description='Test previous episode for unit testing', + group_id='1', + ), + EpisodicNode( + name='Previous Episode 2', + content="Alice: That's really cool!", + created_at=now - timedelta(days=1), + valid_at=now - timedelta(days=1), + source=EpisodeType.message, + source_description='Test previous episode for unit testing', + group_id='1', + ), + ] + + episode = EpisodicNode( + name='Previous Episode', + content='Bob: It was cool, but I no longer work at company XYZ', + created_at=now, + valid_at=now, + source=EpisodeType.message, + source_description='Test previous episode for unit testing', + group_id='1', + ) + + return episode, previous_episodes + + @pytest.mark.asyncio @pytest.mark.integration -async def test_invalidate_edges_empty_inputs(): - invalidated_edges = await invalidate_edges(setup_llm_client(), [], []) +async def test_extract_edge_dates(): + episode, previous_episodes = create_data_for_temporal_extraction() - assert len(invalidated_edges) == 0 + # Create a new edge that partially updates an existing one + new_edge = EntityEdge( + uuid='e9', + source_node_uuid='2', + target_node_uuid='4', + name='LEFT_JOB', + fact='Bob no longer works at Company XYZ', + group_id='1', + created_at=datetime.now(UTC), + ) + + valid_at, invalid_at = await extract_edge_dates( + setup_llm_client(), new_edge, episode, previous_episodes + ) + + assert valid_at == episode.valid_at + assert invalid_at is None + + +# Run the tests +if __name__ == '__main__': + pytest.main([__file__])