Fix edge invalidation (#174)

* update edge operations

* add new tests
This commit is contained in:
Preston Rasmussen 2024-10-07 11:45:31 -04:00 committed by GitHub
parent 377225eec5
commit e15c872900
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 380 additions and 653 deletions

View file

@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
role='user',
content=f"""
Edge:
Edge Name: {context['edge_name']}
Fact: {context['edge_fact']}
Current Episode: {context['current_episode']}
@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
Guidelines:
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
3. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
7. If only a year is mentioned, use January 1st of that year at 00:00:00.
3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
8. If only a year is mentioned, use January 1st of that year at 00:00:00.
9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
Respond with a JSON object:
{{
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
"explanation": "Brief explanation of why these dates were chosen or why they were set to null"
"valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
"invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
}}
""",
),

View file

@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
Existing Edges:
{context['existing_edges']}

View file

@ -4,7 +4,6 @@ from .graph_data_operations import (
retrieve_episodes,
)
from .node_operations import extract_nodes
from .temporal_operations import invalidate_edges
__all__ = [
'extract_edges',
@ -12,5 +11,4 @@ __all__ = [
'extract_nodes',
'clear_data',
'retrieve_episodes',
'invalidate_edges',
]

View file

@ -122,12 +122,6 @@ async def extract_edges(
return edges
def create_edge_identifier(
    source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> str:
    """Build a readable identifier for an edge and its two endpoints.

    Returns:
        The source node name, edge name and target node name joined by '-'.
    """
    name_parts = (source_node.name, edge.name, target_node.name)
    return '-'.join(f'{part}' for part in name_parts)
async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
@ -251,11 +245,11 @@ async def resolve_extracted_edge(
if (
edge.invalid_at is not None
and resolved_edge.valid_at is not None
and edge.invalid_at < resolved_edge.valid_at
and edge.invalid_at <= resolved_edge.valid_at
) or (
edge.valid_at is not None
and resolved_edge.invalid_at is not None
and resolved_edge.invalid_at < edge.valid_at
and resolved_edge.invalid_at <= edge.valid_at
):
continue
# New edge invalidates edge

View file

@ -21,129 +21,11 @@ from typing import List
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import EpisodicNode
from graphiti_core.prompts import prompt_library
logger = logging.getLogger(__name__)
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
def extract_node_and_edge_triplets(
    edges: list[EntityEdge], nodes: list[EntityNode]
) -> list[NodeEdgeNodeTriplet]:
    """Resolve every edge into a (source node, edge, target node) triplet.

    Each edge is delegated to ``extract_node_edge_node_triplet``, so any
    exception raised there propagates to the caller.

    Returns:
        One triplet per edge, in the same order as ``edges``.
    """
    triplets: list[NodeEdgeNodeTriplet] = []
    for edge in edges:
        triplets.append(extract_node_edge_node_triplet(edge, nodes))
    return triplets
def extract_node_edge_node_triplet(
    edge: EntityEdge, nodes: list[EntityNode]
) -> NodeEdgeNodeTriplet:
    """Pair an edge with its source and target nodes.

    Args:
        edge: Edge whose ``source_node_uuid`` / ``target_node_uuid`` are looked up.
        nodes: Candidate nodes; the first node with a matching ``uuid`` wins.

    Returns:
        The tuple ``(source_node, edge, target_node)``.

    Raises:
        ValueError: If either endpoint cannot be found in ``nodes``.
    """

    def first_node_with_uuid(node_uuid):
        for candidate in nodes:
            if candidate.uuid == node_uuid:
                return candidate
        return None

    source_node = first_node_with_uuid(edge.source_node_uuid)
    target_node = first_node_with_uuid(edge.target_node_uuid)

    if source_node and target_node:
        return (source_node, edge, target_node)
    raise ValueError(f'Source or target node not found for edge {edge.uuid}')
def prepare_edges_for_invalidation(
    existing_edges: list[EntityEdge],
    new_edges: list[EntityEdge],
    nodes: list[EntityNode],
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
    """Attach source and target nodes to existing and new edges.

    Edges whose source or target node is not present in ``nodes`` are
    silently dropped from the corresponding result list.

    Args:
        existing_edges: Edges that may later be invalidated.
        new_edges: Newly extracted edges.
        nodes: Nodes used to resolve the edges' endpoint uuids.

    Returns:
        A pair ``(existing_triplets, new_triplets)``; each entry is a
        ``(source_node, edge, target_node)`` tuple, in input order.
    """
    # Index the nodes once instead of scanning the whole list twice per edge.
    # setdefault keeps the first node seen for a uuid, matching a front-to-back scan.
    nodes_by_uuid: dict = {}
    for node in nodes:
        nodes_by_uuid.setdefault(node.uuid, node)

    def attach_nodes(edges: list[EntityEdge]) -> list[NodeEdgeNodeTriplet]:
        triplets: list[NodeEdgeNodeTriplet] = []
        for edge in edges:
            source_node = nodes_by_uuid.get(edge.source_node_uuid)
            target_node = nodes_by_uuid.get(edge.target_node_uuid)
            if source_node and target_node:
                triplets.append((source_node, edge, target_node))
        return triplets

    return attach_nodes(existing_edges), attach_nodes(new_edges)
async def invalidate_edges(
    llm_client: LLMClient,
    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
    new_edges: list[NodeEdgeNodeTriplet],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    """Ask the LLM which existing edges are invalidated and apply the result.

    The prompt context is built by ``prepare_invalidation_context``, sent
    through ``prompt_library.invalidate_edges.v1``, and the
    ``'invalidated_edges'`` entry of the response is handed to
    ``process_edge_invalidation_llm_response``.

    Returns:
        The edges returned by ``process_edge_invalidation_llm_response``.
    """
    context = prepare_invalidation_context(
        existing_edges_pending_invalidation,
        new_edges,
        current_episode,
        previous_episodes,
    )
    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))

    # `or []` also covers a response that carries an explicit null for the key,
    # which would otherwise be iterated as None downstream.
    edges_to_invalidate = llm_response.get('invalidated_edges') or []

    return process_edge_invalidation_llm_response(
        edges_to_invalidate, existing_edges_pending_invalidation
    )
def extract_date_strings_from_edge(edge: EntityEdge) -> str:
    """Render an edge's validity window as text.

    Returns:
        ``'Start Date: <iso>'`` when ``valid_at`` is set, followed by
        ``' (End Date: <iso>)'`` when ``invalid_at`` is set. The result is
        an empty string when neither is set.
    """
    valid_from, valid_until = edge.valid_at, edge.invalid_at
    pieces = []
    if valid_from:
        pieces.append(f'Start Date: {valid_from.isoformat()}')
    if valid_until:
        pieces.append(f' (End Date: {valid_until.isoformat()})')
    return ''.join(pieces)
def prepare_invalidation_context(
    existing_edges: list[NodeEdgeNodeTriplet],
    new_edges: list[NodeEdgeNodeTriplet],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> dict:
    """Assemble the prompt context used for edge invalidation.

    Returns:
        A dict with the keys ``'existing_edges'`` and ``'new_edges'`` (one
        descriptive string per triplet, ordered by the edge's ``created_at``
        descending), ``'current_episode'`` (the episode content) and
        ``'previous_episodes'`` (list of episode contents).
    """

    def describe(triplets: list[NodeEdgeNodeTriplet]) -> list[str]:
        newest_first = sorted(triplets, key=lambda triplet: triplet[1].created_at, reverse=True)
        descriptions = []
        for source_node, edge, target_node in newest_first:
            descriptions.append(
                f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
            )
        return descriptions

    existing_descriptions = describe(existing_edges)
    new_descriptions = describe(new_edges)
    earlier_contents = [episode.content for episode in previous_episodes]

    return {
        'existing_edges': existing_descriptions,
        'new_edges': new_descriptions,
        'current_episode': current_episode.content,
        'previous_episodes': earlier_contents,
    }
def process_edge_invalidation_llm_response(
    edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
) -> List[EntityEdge]:
    """Apply the LLM's invalidation verdicts to the matching existing edges.

    For each entry whose ``'edge_uuid'`` matches an existing edge, the edge's
    ``expired_at`` is set to the current time and its ``fact`` is replaced by
    the entry's ``'fact'``. Entries that match no existing edge are ignored.

    Args:
        edges_to_invalidate: Dicts expected to carry ``'edge_uuid'`` and ``'fact'``.
        existing_edges: Triplets whose edge element may be updated in place.

    Returns:
        The edges that were updated, in the order they were processed.
    """
    invalidated_edges = []
    for edge_to_invalidate in edges_to_invalidate:
        edge_uuid = edge_to_invalidate.get('edge_uuid')
        updated_fact = edge_to_invalidate.get('fact')

        edge_to_update = next(
            (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
            None,
        )
        if not edge_to_update:
            continue

        # Validate before mutating: a missing fact must not leave the edge
        # expired but carrying its old fact text.
        if updated_fact is None:
            logger.warning(
                f'Skipping invalidation of edge {edge_uuid}: LLM response entry has no fact'
            )
            continue

        edge_to_update.expired_at = datetime.now()
        edge_to_update.fact = updated_fact
        invalidated_edges.append(edge_to_update)
        logger.info(
            f'Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {updated_fact}'
        )
    return invalidated_edges
async def extract_edge_dates(
llm_client: LLMClient,
@ -152,7 +34,6 @@ async def extract_edge_dates(
previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None]:
context = {
'edge_name': edge.name,
'edge_fact': edge.fact,
'current_episode': current_episode.content,
'previous_episodes': [ep.content for ep in previous_episodes],
@ -162,25 +43,22 @@ async def extract_edge_dates(
valid_at = llm_response.get('valid_at')
invalid_at = llm_response.get('invalid_at')
explanation = llm_response.get('explanation', '')
valid_at_datetime = None
invalid_at_datetime = None
if valid_at and valid_at != '':
if valid_at:
try:
valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
except ValueError as e:
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
if invalid_at and invalid_at != '':
if invalid_at:
try:
invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
except ValueError as e:
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
logger.info(f'Edge date extraction explanation: {explanation}')
return valid_at_datetime, invalid_at_datetime

View file

@ -0,0 +1,242 @@
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from pytest import MonkeyPatch
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodicNode
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
@pytest.fixture
def mock_llm_client():
    """Provide a ``MagicMock`` that stands in for the LLM client."""
    llm_client_double = MagicMock()
    return llm_client_double
@pytest.fixture
def mock_extracted_edge():
    """Provide a freshly extracted edge with no validity dates set."""
    edge_fields = {
        'source_node_uuid': 'source_uuid',
        'target_node_uuid': 'target_uuid',
        'name': 'test_edge',
        'group_id': 'group_1',
        'fact': 'Test fact',
        'episodes': ['episode_1'],
        'created_at': datetime.now(),
        'valid_at': None,
        'invalid_at': None,
    }
    return EntityEdge(**edge_fields)
@pytest.fixture
def mock_related_edges():
    """Provide a one-element list holding an edge created and valid one day ago."""
    related_edge = EntityEdge(
        source_node_uuid='source_uuid_2',
        target_node_uuid='target_uuid_2',
        name='related_edge',
        group_id='group_1',
        fact='Related fact',
        episodes=['episode_2'],
        created_at=datetime.now() - timedelta(days=1),
        valid_at=datetime.now() - timedelta(days=1),
        invalid_at=None,
    )
    return [related_edge]
@pytest.fixture
def mock_existing_edges():
    """Provide a one-element list holding an edge created and valid two days ago."""
    existing_edge = EntityEdge(
        source_node_uuid='source_uuid_3',
        target_node_uuid='target_uuid_3',
        name='existing_edge',
        group_id='group_1',
        fact='Existing fact',
        episodes=['episode_3'],
        created_at=datetime.now() - timedelta(days=2),
        valid_at=datetime.now() - timedelta(days=2),
        invalid_at=None,
    )
    return [existing_edge]
@pytest.fixture
def mock_current_episode():
    """Provide the episode treated as "current", valid as of now."""
    episode_fields = {
        'uuid': 'episode_1',
        'content': 'Current episode content',
        'valid_at': datetime.now(),
        'name': 'Current Episode',
        'group_id': 'group_1',
        'source': 'message',
        'source_description': 'Test source description',
    }
    return EpisodicNode(**episode_fields)
@pytest.fixture
def mock_previous_episodes():
    """Provide a one-element list holding an episode valid one day ago."""
    earlier_episode = EpisodicNode(
        uuid='episode_2',
        content='Previous episode content',
        valid_at=datetime.now() - timedelta(days=1),
        name='Previous Episode',
        group_id='group_1',
        source='message',
        source_description='Test source description',
    )
    return [earlier_episode]
@pytest.mark.asyncio
async def test_resolve_extracted_edge_no_changes(
    mock_llm_client,
    mock_extracted_edge,
    mock_related_edges,
    mock_existing_edges,
    mock_current_episode,
    mock_previous_episodes,
    monkeypatch: MonkeyPatch,
):
    """With no dates and no contradictions, the edge resolves to itself.

    The three helpers used by ``resolve_extracted_edge`` are replaced with
    async mocks; the test asserts each one is called exactly once.
    """
    dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
    extract_dates_mock = AsyncMock(return_value=(None, None))
    get_contradictions_mock = AsyncMock(return_value=[])

    # Swap the helpers inside the edge_operations module for the mocks above.
    replacements = (
        ('graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock),
        (
            'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates',
            extract_dates_mock,
        ),
        (
            'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
            get_contradictions_mock,
        ),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        mock_extracted_edge,
        mock_related_edges,
        mock_existing_edges,
        mock_current_episode,
        mock_previous_episodes,
    )
    resolved_edge, invalidated_edges = outcome

    assert resolved_edge.uuid == mock_extracted_edge.uuid
    assert invalidated_edges == []
    dedupe_mock.assert_called_once()
    extract_dates_mock.assert_called_once()
    get_contradictions_mock.assert_called_once()
@pytest.mark.asyncio
async def test_resolve_extracted_edge_with_dates(
    mock_llm_client,
    mock_extracted_edge,
    mock_related_edges,
    mock_existing_edges,
    mock_current_episode,
    mock_previous_episodes,
    monkeypatch: MonkeyPatch,
):
    """Dates returned by the patched date extraction end up on the resolved edge.

    Also asserts that ``expired_at`` is set on the resolved edge and that no
    edges are invalidated.
    """
    valid_at = datetime.now() - timedelta(days=1)
    invalid_at = datetime.now() + timedelta(days=1)

    dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
    extract_dates_mock = AsyncMock(return_value=(valid_at, invalid_at))
    get_contradictions_mock = AsyncMock(return_value=[])

    # Swap the helpers inside the edge_operations module for the mocks above.
    replacements = (
        ('graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock),
        (
            'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates',
            extract_dates_mock,
        ),
        (
            'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
            get_contradictions_mock,
        ),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        mock_extracted_edge,
        mock_related_edges,
        mock_existing_edges,
        mock_current_episode,
        mock_previous_episodes,
    )
    resolved_edge, invalidated_edges = outcome

    assert resolved_edge.valid_at == valid_at
    assert resolved_edge.invalid_at == invalid_at
    assert resolved_edge.expired_at is not None
    assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_with_invalidation(
    mock_llm_client,
    mock_extracted_edge,
    mock_related_edges,
    mock_existing_edges,
    mock_current_episode,
    mock_previous_episodes,
    monkeypatch: MonkeyPatch,
):
    """A contradicted older edge is invalidated at the new edge's ``valid_at``.

    The patched contradiction lookup returns one candidate that became valid
    a day before the extracted edge; the test asserts the candidate comes
    back with ``invalid_at`` equal to the extracted edge's ``valid_at`` and
    with ``expired_at`` set.
    """
    valid_at = datetime.now() - timedelta(days=1)
    mock_extracted_edge.valid_at = valid_at

    candidate_fields = {
        'source_node_uuid': 'source_uuid_4',
        'target_node_uuid': 'target_uuid_4',
        'name': 'invalidation_candidate',
        'group_id': 'group_1',
        'fact': 'Invalidation candidate fact',
        'episodes': ['episode_4'],
        'created_at': datetime.now(),
        'valid_at': datetime.now() - timedelta(days=2),
        'invalid_at': None,
    }
    invalidation_candidate = EntityEdge(**candidate_fields)

    dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
    extract_dates_mock = AsyncMock(return_value=(None, None))
    get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])

    # Swap the helpers inside the edge_operations module for the mocks above.
    replacements = (
        ('graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock),
        (
            'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates',
            extract_dates_mock,
        ),
        (
            'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
            get_contradictions_mock,
        ),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        mock_extracted_edge,
        mock_related_edges,
        mock_existing_edges,
        mock_current_episode,
        mock_previous_episodes,
    )
    resolved_edge, invalidated_edges = outcome

    assert resolved_edge.uuid == mock_extracted_edge.uuid
    assert len(invalidated_edges) == 1
    only_invalidated = invalidated_edges[0]
    assert only_invalidated.uuid == invalidation_candidate.uuid
    assert only_invalidated.invalid_at == valid_at
    assert only_invalidated.expired_at is not None
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -1,368 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from datetime import datetime, timedelta
import pytest
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.temporal_operations import (
extract_date_strings_from_edge,
prepare_edges_for_invalidation,
prepare_invalidation_context,
)
# Helper function to create test data
def create_test_data():
    """Build the shared fixture: three person nodes, two existing and two new edges.

    Returns:
        A dict with the keys ``'nodes'``, ``'existing_edges'`` and ``'new_edges'``.
    """
    now = datetime.now()

    def make_node(uuid, name):
        return EntityNode(uuid=uuid, name=name, labels=['Person'], created_at=now, group_id='1')

    def make_edge(uuid, source_uuid, target_uuid, name, fact):
        return EntityEdge(
            uuid=uuid,
            source_node_uuid=source_uuid,
            target_node_uuid=target_uuid,
            name=name,
            fact=fact,
            created_at=now,
            group_id='1',
        )

    nodes = [make_node('1', 'Node1'), make_node('2', 'Node2'), make_node('3', 'Node3')]
    existing_edges = [
        make_edge('e1', '1', '2', 'KNOWS', 'Node1 knows Node2'),
        make_edge('e2', '2', '3', 'LIKES', 'Node2 likes Node3'),
    ]
    new_edges = [
        make_edge('e3', '1', '3', 'WORKS_WITH', 'Node1 works with Node3'),
        make_edge('e4', '1', '2', 'DISLIKES', 'Node1 dislikes Node2'),
    ]

    return {
        'nodes': nodes,
        'existing_edges': existing_edges,
        'new_edges': new_edges,
    }
def test_prepare_edges_for_invalidation_basic():
    """All edges resolve when every endpoint node is supplied."""
    fixture = create_test_data()

    pending, incoming = prepare_edges_for_invalidation(
        fixture['existing_edges'], fixture['new_edges'], fixture['nodes']
    )

    assert len(pending) == 2
    assert len(incoming) == 2

    # Every result entry must be a (node, edge, node) triplet.
    expected_types = (EntityNode, EntityEdge, EntityNode)
    for triplet in pending + incoming:
        for position, expected_type in enumerate(expected_types):
            assert isinstance(triplet[position], expected_type)
def test_prepare_edges_for_invalidation_no_existing_edges():
    """An empty existing-edge list yields no pending triplets; new edges still resolve."""
    fixture = create_test_data()

    pending, incoming = prepare_edges_for_invalidation(
        [], fixture['new_edges'], fixture['nodes']
    )

    assert len(pending) == 0
    assert len(incoming) == 2
def test_prepare_edges_for_invalidation_no_new_edges():
    """An empty new-edge list yields no new triplets; existing edges still resolve."""
    fixture = create_test_data()

    pending, incoming = prepare_edges_for_invalidation(
        fixture['existing_edges'], [], fixture['nodes']
    )

    assert len(pending) == 2
    assert len(incoming) == 0
def test_prepare_edges_for_invalidation_missing_nodes():
    """Edges that reference a node missing from the node list are dropped."""
    fixture = create_test_data()
    # Drop the last node so edges touching it can no longer be resolved.
    nodes_without_last = fixture['nodes'][:-1]

    pending, incoming = prepare_edges_for_invalidation(
        fixture['existing_edges'], fixture['new_edges'], nodes_without_last
    )

    assert len(pending) == 1
    assert len(incoming) == 1
def test_prepare_invalidation_context():
    """The context has all four keys and each edge string mentions its parts."""
    now = datetime.now()

    def make_node(uuid, name):
        return EntityNode(uuid=uuid, name=name, labels=['Person'], created_at=now, group_id='1')

    def make_episode(name, content, moment, description):
        return EpisodicNode(
            name=name,
            content=content,
            created_at=moment,
            valid_at=moment,
            source=EpisodeType.message,
            source_description=description,
            group_id='1',
        )

    node1 = make_node('1', 'Node1')
    node2 = make_node('2', 'Node2')
    node3 = make_node('3', 'Node3')

    edge1 = EntityEdge(
        uuid='e1',
        source_node_uuid='1',
        target_node_uuid='2',
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    edge2 = EntityEdge(
        uuid='e2',
        source_node_uuid='2',
        target_node_uuid='3',
        name='LIKES',
        fact='Node2 likes Node3',
        created_at=now,
        group_id='1',
    )

    # One (node, edge, node) triplet on each side of the comparison.
    existing_edges = [(node1, edge1, node2)]
    new_edges = [(node2, edge2, node3)]

    current_episode = make_episode(
        'Current Episode',
        'This is the current episode content.',
        now,
        'Test episode for unit testing',
    )
    previous_episodes = [
        make_episode(
            'Previous Episode 1',
            'This is the content of previous episode 1.',
            now - timedelta(days=1),
            'Test previous episode 1 for unit testing',
        ),
        make_episode(
            'Previous Episode 2',
            'This is the content of previous episode 2.',
            now - timedelta(days=2),
            'Test previous episode 2 for unit testing',
        ),
    ]

    result = prepare_invalidation_context(
        existing_edges, new_edges, current_episode, previous_episodes
    )

    assert isinstance(result, dict)
    for key in ('existing_edges', 'new_edges', 'current_episode', 'previous_episodes'):
        assert key in result
    assert len(result['existing_edges']) == 1
    assert len(result['new_edges']) == 1
    assert result['current_episode'] == current_episode.content
    assert len(result['previous_episodes']) == 2

    # The existing-edge description must mention the edge and both endpoints.
    existing_edge_str = result['existing_edges'][0]
    for fragment in (edge1.uuid, node1.name, edge1.name, node2.name, edge1.fact):
        assert fragment in existing_edge_str

    # Same for the new-edge description.
    new_edge_str = result['new_edges'][0]
    for fragment in (edge2.uuid, node2.name, edge2.name, node3.name, edge2.fact):
        assert fragment in new_edge_str
def test_prepare_invalidation_context_empty_input():
    """Empty edge and episode lists still produce a fully keyed context."""
    now = datetime.now()
    episode_fields = {
        'name': 'Current Episode',
        'content': 'Empty episode',
        'created_at': now,
        'valid_at': now,
        'source': EpisodeType.message,
        'source_description': 'Test empty episode for unit testing',
        'group_id': '1',
    }
    current_episode = EpisodicNode(**episode_fields)

    result = prepare_invalidation_context([], [], current_episode, [])

    assert isinstance(result, dict)
    for key in ('existing_edges', 'new_edges', 'current_episode', 'previous_episodes'):
        assert key in result
    assert len(result['existing_edges']) == 0
    assert len(result['new_edges']) == 0
    assert result['current_episode'] == current_episode.content
    assert len(result['previous_episodes']) == 0
def test_prepare_invalidation_context_sorting():
    """Existing edges are listed newest-first by ``created_at``."""
    now = datetime.now()

    def make_node(uuid, name):
        return EntityNode(uuid=uuid, name=name, labels=['Person'], created_at=now, group_id='1')

    def make_episode(name, content, moment, description):
        return EpisodicNode(
            name=name,
            content=content,
            created_at=moment,
            valid_at=moment,
            source=EpisodeType.message,
            source_description=description,
            group_id='1',
        )

    node1 = make_node('1', 'Node1')
    node2 = make_node('2', 'Node2')

    # edge2 is created one hour after edge1, so it should be listed first.
    edge1 = EntityEdge(
        uuid='e1',
        source_node_uuid='1',
        target_node_uuid='2',
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    edge2 = EntityEdge(
        uuid='e2',
        source_node_uuid='2',
        target_node_uuid='1',
        name='LIKES',
        fact='Node2 likes Node1',
        created_at=now + timedelta(hours=1),
        group_id='1',
    )
    existing_edges = [(node1, edge1, node2), (node2, edge2, node1)]

    current_episode = make_episode(
        'Current Episode',
        'This is the current episode content.',
        now,
        'Test episode for unit testing',
    )
    previous_episodes = [
        make_episode(
            'Previous Episode',
            'This is the content of a previous episode.',
            now - timedelta(days=1),
            'Test previous episode for unit testing',
        ),
    ]

    result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)

    listed_edges = result['existing_edges']
    assert len(listed_edges) == 2
    assert edge2.uuid in listed_edges[0]  # The newer edge should be first
    assert edge1.uuid in listed_edges[1]  # The older edge should be second
    assert result['current_episode'] == current_episode.content
    assert len(result['previous_episodes']) == 1
    assert result['previous_episodes'][0] == previous_episodes[0].content
class TestExtractDateStringsFromEdge(unittest.TestCase):
    """Cover every combination of ``valid_at`` / ``invalid_at`` being set or unset."""

    def generate_entity_edge(self, valid_at, invalid_at):
        """Create a KNOWS edge between nodes '1' and '2' with the given dates."""
        edge_fields = {
            'source_node_uuid': '1',
            'target_node_uuid': '2',
            'name': 'KNOWS',
            'fact': 'Node1 knows Node2',
            'created_at': datetime.now(),
            'valid_at': valid_at,
            'invalid_at': invalid_at,
            'group_id': '1',
        }
        return EntityEdge(**edge_fields)

    def test_both_dates_present(self):
        start = datetime(2024, 1, 1, 12, 0)
        end = datetime(2024, 1, 2, 12, 0)
        self.assertEqual(
            extract_date_strings_from_edge(self.generate_entity_edge(start, end)),
            'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)',
        )

    def test_only_valid_at_present(self):
        start = datetime(2024, 1, 1, 12, 0)
        self.assertEqual(
            extract_date_strings_from_edge(self.generate_entity_edge(start, None)),
            'Start Date: 2024-01-01T12:00:00',
        )

    def test_only_invalid_at_present(self):
        end = datetime(2024, 1, 2, 12, 0)
        self.assertEqual(
            extract_date_strings_from_edge(self.generate_entity_edge(None, end)),
            ' (End Date: 2024-01-02T12:00:00)',
        )

    def test_no_dates_present(self):
        self.assertEqual(
            extract_date_strings_from_edge(self.generate_entity_edge(None, None)),
            '',
        )
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -19,12 +19,14 @@ from datetime import datetime, timedelta
import pytest
from dotenv import load_dotenv
from pytz import UTC
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.temporal_operations import (
invalidate_edges,
extract_edge_dates,
get_edge_contradictions,
)
load_dotenv()
@ -43,31 +45,26 @@ def setup_llm_client():
def create_test_data():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)
# Create edges
edge1 = EntityEdge(
existing_edge = EntityEdge(
uuid='e1',
source_node_uuid='1',
target_node_uuid='2',
name='LIKES',
fact='Alice likes Bob',
created_at=now - timedelta(days=1),
group_id='1',
)
edge2 = EntityEdge(
new_edge = EntityEdge(
uuid='e2',
source_node_uuid='1',
target_node_uuid='2',
name='DISLIKES',
fact='Alice dislikes Bob',
created_at=now,
group_id='1',
)
existing_edge = (node1, edge1, node2)
new_edge = (node1, edge2, node2)
# Create current episode
current_episode = EpisodicNode(
name='Current Episode',
@ -97,46 +94,40 @@ def create_test_data():
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges():
async def test_get_edge_contradictions():
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [existing_edge])
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge[1].uuid
assert invalidated_edges[0].expired_at is not None
assert invalidated_edges[0].uuid == existing_edge.uuid
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_invalidation():
existing_edge, _, current_episode, previous_episodes = create_test_data()
async def test_get_edge_contradictions_no_contradictions():
_, new_edge, current_episode, previous_episodes = create_test_data()
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [])
assert len(invalidated_edges) == 0
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_existing():
existing_edge1, new_edge = create_test_data()
existing_edge2, _ = create_test_data()
existing_edge2[1].uuid = 'e3'
existing_edge2[1].name = 'KNOWS'
existing_edge2[1].fact = 'Alice knows Bob'
async def test_get_edge_contradictions_multiple_existing():
existing_edge1, new_edge, _, _ = create_test_data()
existing_edge2, _, _, _ = create_test_data()
existing_edge2.uuid = 'e3'
existing_edge2.name = 'KNOWS'
existing_edge2.fact = 'Alice knows Bob'
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge1, existing_edge2], [new_edge]
invalidated_edges = await get_edge_contradictions(
setup_llm_client(), new_edge, [existing_edge1, existing_edge2]
)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge1[1].uuid
assert invalidated_edges[0].expired_at is not None
assert invalidated_edges[0].uuid == existing_edge1.uuid
# Helper function to create more complex test data
@ -152,7 +143,7 @@ def create_complex_test_data():
)
# Create edges
edge1 = EntityEdge(
existing_edge1 = EntityEdge(
uuid='e1',
source_node_uuid='1',
target_node_uuid='2',
@ -161,7 +152,7 @@ def create_complex_test_data():
group_id='1',
created_at=now - timedelta(days=5),
)
edge2 = EntityEdge(
existing_edge2 = EntityEdge(
uuid='e2',
source_node_uuid='1',
target_node_uuid='3',
@ -170,7 +161,7 @@ def create_complex_test_data():
group_id='1',
created_at=now - timedelta(days=3),
)
edge3 = EntityEdge(
existing_edge3 = EntityEdge(
uuid='e3',
source_node_uuid='2',
target_node_uuid='4',
@ -180,10 +171,6 @@ def create_complex_test_data():
created_at=now - timedelta(days=2),
)
existing_edge1 = (node1, edge1, node2)
existing_edge2 = (node1, edge2, node3)
existing_edge3 = (node2, edge3, node4)
return [existing_edge1, existing_edge2, existing_edge3], [
node1,
node2,
@ -198,118 +185,61 @@ async def test_invalidate_edges_complex():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that contradicts an existing one
new_edge = (
nodes[0],
EntityEdge(
uuid='e4',
source_node_uuid='1',
target_node_uuid='2',
name='DISLIKES',
fact='Alice dislikes Bob',
group_id='1',
created_at=datetime.now(),
),
nodes[1],
new_edge = EntityEdge(
uuid='e4',
source_node_uuid='1',
target_node_uuid='2',
name='DISLIKES',
fact='Alice dislikes Bob',
group_id='1',
created_at=datetime.now(),
)
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == 'e1'
assert invalidated_edges[0].expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_temporal_update():
async def test_get_edge_contradictions_temporal_update():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that updates an existing one with new information
new_edge = (
nodes[1],
EntityEdge(
uuid='e5',
source_node_uuid='2',
target_node_uuid='4',
name='LEFT_JOB',
fact='Bob left his job at Company XYZ',
group_id='1',
created_at=datetime.now(),
),
nodes[3],
new_edge = EntityEdge(
uuid='e5',
source_node_uuid='2',
target_node_uuid='4',
name='LEFT_JOB',
fact='Bob no longer works at at Company XYZ',
group_id='1',
created_at=datetime.now(),
)
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == 'e3'
assert invalidated_edges[0].expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_invalidations():
    """Two new edges submitted together should invalidate edges 'e1' and 'e2'.

    Integration test: it calls the LLM client returned by ``setup_llm_client()``.
    """
    existing_edges, nodes = create_complex_test_data()

    # (uuid, target node uuid, relation name, fact, index of the target node)
    edge_specs = [
        ('e6', '2', 'ENEMIES_WITH', 'Alice and Bob are now enemies', 1),
        ('e7', '3', 'ENDED_FRIENDSHIP', 'Alice ended her friendship with Charlie', 2),
    ]

    # Create new edges that invalidate multiple existing edges; each one is
    # wrapped as a (source node, edge, target node) triple starting at nodes[0].
    # NOTE(review): the neighbouring tests pass bare EntityEdge objects to
    # get_edge_contradictions - confirm invalidate_edges still takes triples.
    new_edges = [
        (
            nodes[0],
            EntityEdge(
                uuid=edge_uuid,
                source_node_uuid='1',
                target_node_uuid=target_uuid,
                name=relation,
                fact=fact,
                group_id='1',
                created_at=datetime.now(),
            ),
            nodes[target_index],
        )
        for edge_uuid, target_uuid, relation, fact, target_index in edge_specs
    ]

    invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, new_edges)

    assert len(invalidated_edges) == 2
    assert {edge.uuid for edge in invalidated_edges} == {'e1', 'e2'}
    assert all(edge.expired_at is not None for edge in invalidated_edges)
@pytest.mark.asyncio
@pytest.mark.integration
async def test_get_edge_contradictions_no_effect():
    """Check that an unrelated APPLIED_TO fact contradicts none of the existing edges.

    Integration test: it calls the LLM client returned by ``setup_llm_client()``.
    """
    # Only the edges are needed; the node list from the fixture is unused here.
    existing_edges, _ = create_complex_test_data()

    # Create a new edge that doesn't invalidate any existing edges
    new_edge = EntityEdge(
        uuid='e8',
        source_node_uuid='3',
        target_node_uuid='4',
        name='APPLIED_TO',
        fact='Charlie applied to Company XYZ',
        group_id='1',
        created_at=datetime.now(),
    )

    invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)

    assert len(invalidated_edges) == 0
@ -320,28 +250,82 @@ async def test_invalidate_edges_partial_update():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that partially updates an existing one
new_edge = (
nodes[1],
EntityEdge(
uuid='e9',
source_node_uuid='2',
target_node_uuid='4',
name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ',
group_id='1',
created_at=datetime.now(),
),
nodes[3],
new_edge = EntityEdge(
uuid='e9',
source_node_uuid='2',
target_node_uuid='4',
name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ',
group_id='1',
created_at=datetime.now(),
)
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
    """Build a three-message thread used by the edge-date extraction test.

    Returns:
        An ``(episode, previous_episodes)`` pair. ``episode`` is the newest
        message, stamped with the current UTC time; ``previous_episodes`` holds
        the two earlier messages, stamped two days and one day before it.
    """
    now = datetime.now(UTC)

    def message(name: str, content: str, timestamp: datetime) -> EpisodicNode:
        # created_at and valid_at always coincide in this fixture.
        return EpisodicNode(
            name=name,
            content=content,
            created_at=timestamp,
            valid_at=timestamp,
            source=EpisodeType.message,
            source_description='Test previous episode for unit testing',
            group_id='1',
        )

    history = [
        message('Previous Episode 1', 'Bob: I work at XYZ company', now - timedelta(days=2)),
        message('Previous Episode 2', "Alice: That's really cool!", now - timedelta(days=1)),
    ]
    latest = message(
        'Previous Episode',
        'Bob: It was cool, but I no longer work at company XYZ',
        now,
    )
    return latest, history
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_empty_inputs():
    """With no existing edges and no new edges, nothing is reported as invalidated."""
    # NOTE(review): the neighbouring tests exercise get_edge_contradictions -
    # confirm invalidate_edges is still exported before relying on this test.
    invalidated_edges = await invalidate_edges(setup_llm_client(), [], [])

    assert len(invalidated_edges) == 0


@pytest.mark.asyncio
@pytest.mark.integration
async def test_extract_edge_dates():
    """Check the dates extracted for a present-tense "no longer works at" fact.

    The expected result is a ``valid_at`` equal to the current episode's
    ``valid_at`` and no ``invalid_at``.

    Integration test: it calls the LLM client returned by ``setup_llm_client()``.
    """
    episode, previous_episodes = create_data_for_temporal_extraction()

    # Edge restating, in the present tense, what Bob says in the current episode.
    new_edge = EntityEdge(
        uuid='e9',
        source_node_uuid='2',
        target_node_uuid='4',
        name='LEFT_JOB',
        fact='Bob no longer works at Company XYZ',
        group_id='1',
        created_at=datetime.now(UTC),
    )

    valid_at, invalid_at = await extract_edge_dates(
        setup_llm_client(), new_edge, episode, previous_episodes
    )

    # NOTE(review): presumably episode.valid_at is what the extraction prompt
    # receives as its reference timestamp - confirm in extract_edge_dates.
    assert valid_at == episode.valid_at
    assert invalid_at is None
# Let this module be executed directly: hand this file to pytest for collection.
if __name__ == '__main__':
    pytest.main(args=[__file__])