parent
377225eec5
commit
e15c872900
8 changed files with 380 additions and 653 deletions
|
|
@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
|||
role='user',
|
||||
content=f"""
|
||||
Edge:
|
||||
Edge Name: {context['edge_name']}
|
||||
Fact: {context['edge_fact']}
|
||||
|
||||
Current Episode: {context['current_episode']}
|
||||
|
|
@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
|||
Guidelines:
|
||||
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
|
||||
2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
|
||||
3. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
|
||||
4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
|
||||
5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
|
||||
6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
|
||||
7. If only a year is mentioned, use January 1st of that year at 00:00:00.
|
||||
3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
|
||||
4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
|
||||
5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
|
||||
6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
|
||||
7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
|
||||
8. If only a year is mentioned, use January 1st of that year at 00:00:00.
|
||||
9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
|
||||
Respond with a JSON object:
|
||||
{{
|
||||
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
|
||||
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
|
||||
"explanation": "Brief explanation of why these dates were chosen or why they were set to null"
|
||||
"valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
|
||||
"invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
|
||||
}}
|
||||
""",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
|||
Message(
|
||||
role='user',
|
||||
content=f"""
|
||||
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
|
||||
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
|
||||
|
||||
Existing Edges:
|
||||
{context['existing_edges']}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from .graph_data_operations import (
|
|||
retrieve_episodes,
|
||||
)
|
||||
from .node_operations import extract_nodes
|
||||
from .temporal_operations import invalidate_edges
|
||||
|
||||
__all__ = [
|
||||
'extract_edges',
|
||||
|
|
@ -12,5 +11,4 @@ __all__ = [
|
|||
'extract_nodes',
|
||||
'clear_data',
|
||||
'retrieve_episodes',
|
||||
'invalidate_edges',
|
||||
]
|
||||
|
|
|
|||
|
|
@ -122,12 +122,6 @@ async def extract_edges(
|
|||
return edges
|
||||
|
||||
|
||||
def create_edge_identifier(
|
||||
source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
|
||||
) -> str:
|
||||
return f'{source_node.name}-{edge.name}-{target_node.name}'
|
||||
|
||||
|
||||
async def dedupe_extracted_edges(
|
||||
llm_client: LLMClient,
|
||||
extracted_edges: list[EntityEdge],
|
||||
|
|
@ -251,11 +245,11 @@ async def resolve_extracted_edge(
|
|||
if (
|
||||
edge.invalid_at is not None
|
||||
and resolved_edge.valid_at is not None
|
||||
and edge.invalid_at < resolved_edge.valid_at
|
||||
and edge.invalid_at <= resolved_edge.valid_at
|
||||
) or (
|
||||
edge.valid_at is not None
|
||||
and resolved_edge.invalid_at is not None
|
||||
and resolved_edge.invalid_at < edge.valid_at
|
||||
and resolved_edge.invalid_at <= edge.valid_at
|
||||
):
|
||||
continue
|
||||
# New edge invalidates edge
|
||||
|
|
|
|||
|
|
@ -21,129 +21,11 @@ from typing import List
|
|||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.nodes import EpisodicNode
|
||||
from graphiti_core.prompts import prompt_library
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
|
||||
|
||||
|
||||
def extract_node_and_edge_triplets(
|
||||
edges: list[EntityEdge], nodes: list[EntityNode]
|
||||
) -> list[NodeEdgeNodeTriplet]:
|
||||
return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]
|
||||
|
||||
|
||||
def extract_node_edge_node_triplet(
|
||||
edge: EntityEdge, nodes: list[EntityNode]
|
||||
) -> NodeEdgeNodeTriplet:
|
||||
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
|
||||
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
|
||||
if not source_node or not target_node:
|
||||
raise ValueError(f'Source or target node not found for edge {edge.uuid}')
|
||||
return (source_node, edge, target_node)
|
||||
|
||||
|
||||
def prepare_edges_for_invalidation(
|
||||
existing_edges: list[EntityEdge],
|
||||
new_edges: list[EntityEdge],
|
||||
nodes: list[EntityNode],
|
||||
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
|
||||
existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
|
||||
new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
|
||||
|
||||
for edge_list, result_list in [
|
||||
(existing_edges, existing_edges_pending_invalidation),
|
||||
(new_edges, new_edges_with_nodes),
|
||||
]:
|
||||
for edge in edge_list:
|
||||
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
|
||||
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
|
||||
|
||||
if source_node and target_node:
|
||||
result_list.append((source_node, edge, target_node))
|
||||
|
||||
return existing_edges_pending_invalidation, new_edges_with_nodes
|
||||
|
||||
|
||||
async def invalidate_edges(
|
||||
llm_client: LLMClient,
|
||||
existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
|
||||
new_edges: list[NodeEdgeNodeTriplet],
|
||||
current_episode: EpisodicNode,
|
||||
previous_episodes: list[EpisodicNode],
|
||||
) -> list[EntityEdge]:
|
||||
invalidated_edges = [] # TODO: this is not yet used?
|
||||
|
||||
context = prepare_invalidation_context(
|
||||
existing_edges_pending_invalidation,
|
||||
new_edges,
|
||||
current_episode,
|
||||
previous_episodes,
|
||||
)
|
||||
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
|
||||
|
||||
edges_to_invalidate = llm_response.get('invalidated_edges', [])
|
||||
invalidated_edges = process_edge_invalidation_llm_response(
|
||||
edges_to_invalidate, existing_edges_pending_invalidation
|
||||
)
|
||||
|
||||
return invalidated_edges
|
||||
|
||||
|
||||
def extract_date_strings_from_edge(edge: EntityEdge) -> str:
|
||||
start = edge.valid_at
|
||||
end = edge.invalid_at
|
||||
date_string = f'Start Date: {start.isoformat()}' if start else ''
|
||||
if end:
|
||||
date_string += f' (End Date: {end.isoformat()})'
|
||||
return date_string
|
||||
|
||||
|
||||
def prepare_invalidation_context(
|
||||
existing_edges: list[NodeEdgeNodeTriplet],
|
||||
new_edges: list[NodeEdgeNodeTriplet],
|
||||
current_episode: EpisodicNode,
|
||||
previous_episodes: list[EpisodicNode],
|
||||
) -> dict:
|
||||
return {
|
||||
'existing_edges': [
|
||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
|
||||
for source_node, edge, target_node in sorted(
|
||||
existing_edges, key=lambda x: (x[1].created_at), reverse=True
|
||||
)
|
||||
],
|
||||
'new_edges': [
|
||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
|
||||
for source_node, edge, target_node in sorted(
|
||||
new_edges, key=lambda x: (x[1].created_at), reverse=True
|
||||
)
|
||||
],
|
||||
'current_episode': current_episode.content,
|
||||
'previous_episodes': [episode.content for episode in previous_episodes],
|
||||
}
|
||||
|
||||
|
||||
def process_edge_invalidation_llm_response(
|
||||
edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
|
||||
) -> List[EntityEdge]:
|
||||
invalidated_edges = []
|
||||
for edge_to_invalidate in edges_to_invalidate:
|
||||
edge_uuid = edge_to_invalidate['edge_uuid']
|
||||
edge_to_update = next(
|
||||
(edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
|
||||
None,
|
||||
)
|
||||
if edge_to_update:
|
||||
edge_to_update.expired_at = datetime.now()
|
||||
edge_to_update.fact = edge_to_invalidate['fact']
|
||||
invalidated_edges.append(edge_to_update)
|
||||
logger.info(
|
||||
f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
|
||||
)
|
||||
return invalidated_edges
|
||||
|
||||
|
||||
async def extract_edge_dates(
|
||||
llm_client: LLMClient,
|
||||
|
|
@ -152,7 +34,6 @@ async def extract_edge_dates(
|
|||
previous_episodes: List[EpisodicNode],
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
context = {
|
||||
'edge_name': edge.name,
|
||||
'edge_fact': edge.fact,
|
||||
'current_episode': current_episode.content,
|
||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||
|
|
@ -162,25 +43,22 @@ async def extract_edge_dates(
|
|||
|
||||
valid_at = llm_response.get('valid_at')
|
||||
invalid_at = llm_response.get('invalid_at')
|
||||
explanation = llm_response.get('explanation', '')
|
||||
|
||||
valid_at_datetime = None
|
||||
invalid_at_datetime = None
|
||||
|
||||
if valid_at and valid_at != '':
|
||||
if valid_at:
|
||||
try:
|
||||
valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
|
||||
except ValueError as e:
|
||||
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
|
||||
|
||||
if invalid_at and invalid_at != '':
|
||||
if invalid_at:
|
||||
try:
|
||||
invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
|
||||
except ValueError as e:
|
||||
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
|
||||
|
||||
logger.info(f'Edge date extraction explanation: {explanation}')
|
||||
|
||||
return valid_at_datetime, invalid_at_datetime
|
||||
|
||||
|
||||
|
|
|
|||
242
tests/utils/maintenance/test_edge_operations.py
Normal file
242
tests/utils/maintenance/test_edge_operations.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import EpisodicNode
|
||||
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_client():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_extracted_edge():
|
||||
return EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='Test fact',
|
||||
episodes=['episode_1'],
|
||||
created_at=datetime.now(),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_related_edges():
|
||||
return [
|
||||
EntityEdge(
|
||||
source_node_uuid='source_uuid_2',
|
||||
target_node_uuid='target_uuid_2',
|
||||
name='related_edge',
|
||||
group_id='group_1',
|
||||
fact='Related fact',
|
||||
episodes=['episode_2'],
|
||||
created_at=datetime.now() - timedelta(days=1),
|
||||
valid_at=datetime.now() - timedelta(days=1),
|
||||
invalid_at=None,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_existing_edges():
|
||||
return [
|
||||
EntityEdge(
|
||||
source_node_uuid='source_uuid_3',
|
||||
target_node_uuid='target_uuid_3',
|
||||
name='existing_edge',
|
||||
group_id='group_1',
|
||||
fact='Existing fact',
|
||||
episodes=['episode_3'],
|
||||
created_at=datetime.now() - timedelta(days=2),
|
||||
valid_at=datetime.now() - timedelta(days=2),
|
||||
invalid_at=None,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_episode():
|
||||
return EpisodicNode(
|
||||
uuid='episode_1',
|
||||
content='Current episode content',
|
||||
valid_at=datetime.now(),
|
||||
name='Current Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='Test source description',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_previous_episodes():
|
||||
return [
|
||||
EpisodicNode(
|
||||
uuid='episode_2',
|
||||
content='Previous episode content',
|
||||
valid_at=datetime.now() - timedelta(days=1),
|
||||
name='Previous Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='Test source description',
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_no_changes(
|
||||
mock_llm_client,
|
||||
mock_extracted_edge,
|
||||
mock_related_edges,
|
||||
mock_existing_edges,
|
||||
mock_current_episode,
|
||||
mock_previous_episodes,
|
||||
monkeypatch: MonkeyPatch,
|
||||
):
|
||||
# Mock the function calls
|
||||
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
||||
extract_dates_mock = AsyncMock(return_value=(None, None))
|
||||
get_contradictions_mock = AsyncMock(return_value=[])
|
||||
|
||||
# Patch the function calls
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
||||
get_contradictions_mock,
|
||||
)
|
||||
|
||||
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||
mock_llm_client,
|
||||
mock_extracted_edge,
|
||||
mock_related_edges,
|
||||
mock_existing_edges,
|
||||
mock_current_episode,
|
||||
mock_previous_episodes,
|
||||
)
|
||||
|
||||
assert resolved_edge.uuid == mock_extracted_edge.uuid
|
||||
assert invalidated_edges == []
|
||||
dedupe_mock.assert_called_once()
|
||||
extract_dates_mock.assert_called_once()
|
||||
get_contradictions_mock.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_with_dates(
|
||||
mock_llm_client,
|
||||
mock_extracted_edge,
|
||||
mock_related_edges,
|
||||
mock_existing_edges,
|
||||
mock_current_episode,
|
||||
mock_previous_episodes,
|
||||
monkeypatch: MonkeyPatch,
|
||||
):
|
||||
valid_at = datetime.now() - timedelta(days=1)
|
||||
invalid_at = datetime.now() + timedelta(days=1)
|
||||
|
||||
# Mock the function calls
|
||||
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
||||
extract_dates_mock = AsyncMock(return_value=(valid_at, invalid_at))
|
||||
get_contradictions_mock = AsyncMock(return_value=[])
|
||||
|
||||
# Patch the function calls
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
||||
get_contradictions_mock,
|
||||
)
|
||||
|
||||
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||
mock_llm_client,
|
||||
mock_extracted_edge,
|
||||
mock_related_edges,
|
||||
mock_existing_edges,
|
||||
mock_current_episode,
|
||||
mock_previous_episodes,
|
||||
)
|
||||
|
||||
assert resolved_edge.valid_at == valid_at
|
||||
assert resolved_edge.invalid_at == invalid_at
|
||||
assert resolved_edge.expired_at is not None
|
||||
assert invalidated_edges == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_with_invalidation(
|
||||
mock_llm_client,
|
||||
mock_extracted_edge,
|
||||
mock_related_edges,
|
||||
mock_existing_edges,
|
||||
mock_current_episode,
|
||||
mock_previous_episodes,
|
||||
monkeypatch: MonkeyPatch,
|
||||
):
|
||||
valid_at = datetime.now() - timedelta(days=1)
|
||||
mock_extracted_edge.valid_at = valid_at
|
||||
|
||||
invalidation_candidate = EntityEdge(
|
||||
source_node_uuid='source_uuid_4',
|
||||
target_node_uuid='target_uuid_4',
|
||||
name='invalidation_candidate',
|
||||
group_id='group_1',
|
||||
fact='Invalidation candidate fact',
|
||||
episodes=['episode_4'],
|
||||
created_at=datetime.now(),
|
||||
valid_at=datetime.now() - timedelta(days=2),
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
# Mock the function calls
|
||||
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
||||
extract_dates_mock = AsyncMock(return_value=(None, None))
|
||||
get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])
|
||||
|
||||
# Patch the function calls
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
||||
get_contradictions_mock,
|
||||
)
|
||||
|
||||
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||
mock_llm_client,
|
||||
mock_extracted_edge,
|
||||
mock_related_edges,
|
||||
mock_existing_edges,
|
||||
mock_current_episode,
|
||||
mock_previous_episodes,
|
||||
)
|
||||
|
||||
assert resolved_edge.uuid == mock_extracted_edge.uuid
|
||||
assert len(invalidated_edges) == 1
|
||||
assert invalidated_edges[0].uuid == invalidation_candidate.uuid
|
||||
assert invalidated_edges[0].invalid_at == valid_at
|
||||
assert invalidated_edges[0].expired_at is not None
|
||||
|
||||
|
||||
# Run the tests
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -1,368 +0,0 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||
extract_date_strings_from_edge,
|
||||
prepare_edges_for_invalidation,
|
||||
prepare_invalidation_context,
|
||||
)
|
||||
|
||||
|
||||
# Helper function to create test data
|
||||
def create_test_data():
|
||||
now = datetime.now()
|
||||
|
||||
# Create nodes
|
||||
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
|
||||
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
|
||||
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
|
||||
|
||||
# Create edges
|
||||
existing_edge1 = EntityEdge(
|
||||
uuid='e1',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='KNOWS',
|
||||
fact='Node1 knows Node2',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
existing_edge2 = EntityEdge(
|
||||
uuid='e2',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='3',
|
||||
name='LIKES',
|
||||
fact='Node2 likes Node3',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
new_edge1 = EntityEdge(
|
||||
uuid='e3',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='3',
|
||||
name='WORKS_WITH',
|
||||
fact='Node1 works with Node3',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
new_edge2 = EntityEdge(
|
||||
uuid='e4',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='DISLIKES',
|
||||
fact='Node1 dislikes Node2',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
return {
|
||||
'nodes': [node1, node2, node3],
|
||||
'existing_edges': [existing_edge1, existing_edge2],
|
||||
'new_edges': [new_edge1, new_edge2],
|
||||
}
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_basic():
|
||||
test_data = create_test_data()
|
||||
|
||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
||||
test_data['existing_edges'], test_data['new_edges'], test_data['nodes']
|
||||
)
|
||||
|
||||
assert len(existing_edges_pending_invalidation) == 2
|
||||
assert len(new_edges_with_nodes) == 2
|
||||
|
||||
# Check if the edges are correctly associated with nodes
|
||||
for edge_with_nodes in existing_edges_pending_invalidation + new_edges_with_nodes:
|
||||
assert isinstance(edge_with_nodes[0], EntityNode)
|
||||
assert isinstance(edge_with_nodes[1], EntityEdge)
|
||||
assert isinstance(edge_with_nodes[2], EntityNode)
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_no_existing_edges():
|
||||
test_data = create_test_data()
|
||||
|
||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
||||
[], test_data['new_edges'], test_data['nodes']
|
||||
)
|
||||
|
||||
assert len(existing_edges_pending_invalidation) == 0
|
||||
assert len(new_edges_with_nodes) == 2
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_no_new_edges():
|
||||
test_data = create_test_data()
|
||||
|
||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
||||
test_data['existing_edges'], [], test_data['nodes']
|
||||
)
|
||||
|
||||
assert len(existing_edges_pending_invalidation) == 2
|
||||
assert len(new_edges_with_nodes) == 0
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_missing_nodes():
|
||||
test_data = create_test_data()
|
||||
|
||||
# Remove one node to simulate a missing node scenario
|
||||
nodes = test_data['nodes'][:-1]
|
||||
|
||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
||||
test_data['existing_edges'], test_data['new_edges'], nodes
|
||||
)
|
||||
|
||||
assert len(existing_edges_pending_invalidation) == 1
|
||||
assert len(new_edges_with_nodes) == 1
|
||||
|
||||
|
||||
def test_prepare_invalidation_context():
|
||||
now = datetime.now()
|
||||
|
||||
# Create nodes
|
||||
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
|
||||
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
|
||||
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
|
||||
|
||||
# Create edges
|
||||
edge1 = EntityEdge(
|
||||
uuid='e1',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='KNOWS',
|
||||
fact='Node1 knows Node2',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
edge2 = EntityEdge(
|
||||
uuid='e2',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='3',
|
||||
name='LIKES',
|
||||
fact='Node2 likes Node3',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
# Create NodeEdgeNodeTriplet objects
|
||||
existing_edge = (node1, edge1, node2)
|
||||
new_edge = (node2, edge2, node3)
|
||||
|
||||
# Prepare test input
|
||||
existing_edges = [existing_edge]
|
||||
new_edges = [new_edge]
|
||||
|
||||
# Create a current episode and previous episodes
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
content='This is the current episode content.',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source=EpisodeType.message,
|
||||
source_description='Test episode for unit testing',
|
||||
group_id='1',
|
||||
)
|
||||
previous_episodes = [
|
||||
EpisodicNode(
|
||||
name='Previous Episode 1',
|
||||
content='This is the content of previous episode 1.',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode 1 for unit testing',
|
||||
group_id='1',
|
||||
),
|
||||
EpisodicNode(
|
||||
name='Previous Episode 2',
|
||||
content='This is the content of previous episode 2.',
|
||||
created_at=now - timedelta(days=2),
|
||||
valid_at=now - timedelta(days=2),
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode 2 for unit testing',
|
||||
group_id='1',
|
||||
),
|
||||
]
|
||||
|
||||
# Call the function
|
||||
result = prepare_invalidation_context(
|
||||
existing_edges, new_edges, current_episode, previous_episodes
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, dict)
|
||||
assert 'existing_edges' in result
|
||||
assert 'new_edges' in result
|
||||
assert 'current_episode' in result
|
||||
assert 'previous_episodes' in result
|
||||
assert len(result['existing_edges']) == 1
|
||||
assert len(result['new_edges']) == 1
|
||||
assert result['current_episode'] == current_episode.content
|
||||
assert len(result['previous_episodes']) == 2
|
||||
|
||||
# Check the format of the existing edge
|
||||
existing_edge_str = result['existing_edges'][0]
|
||||
assert edge1.uuid in existing_edge_str
|
||||
assert node1.name in existing_edge_str
|
||||
assert edge1.name in existing_edge_str
|
||||
assert node2.name in existing_edge_str
|
||||
assert edge1.fact in existing_edge_str
|
||||
|
||||
# Check the format of the new edge
|
||||
new_edge_str = result['new_edges'][0]
|
||||
assert edge2.uuid in new_edge_str
|
||||
assert node2.name in new_edge_str
|
||||
assert edge2.name in new_edge_str
|
||||
assert node3.name in new_edge_str
|
||||
assert edge2.fact in new_edge_str
|
||||
|
||||
|
||||
def test_prepare_invalidation_context_empty_input():
|
||||
now = datetime.now()
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
content='Empty episode',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source=EpisodeType.message,
|
||||
source_description='Test empty episode for unit testing',
|
||||
group_id='1',
|
||||
)
|
||||
result = prepare_invalidation_context([], [], current_episode, [])
|
||||
assert isinstance(result, dict)
|
||||
assert 'existing_edges' in result
|
||||
assert 'new_edges' in result
|
||||
assert 'current_episode' in result
|
||||
assert 'previous_episodes' in result
|
||||
assert len(result['existing_edges']) == 0
|
||||
assert len(result['new_edges']) == 0
|
||||
assert result['current_episode'] == current_episode.content
|
||||
assert len(result['previous_episodes']) == 0
|
||||
|
||||
|
||||
def test_prepare_invalidation_context_sorting():
|
||||
now = datetime.now()
|
||||
|
||||
# Create nodes
|
||||
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
|
||||
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
|
||||
|
||||
# Create edges with different timestamps
|
||||
edge1 = EntityEdge(
|
||||
uuid='e1',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='KNOWS',
|
||||
fact='Node1 knows Node2',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
edge2 = EntityEdge(
|
||||
uuid='e2',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='1',
|
||||
name='LIKES',
|
||||
fact='Node2 likes Node1',
|
||||
created_at=now + timedelta(hours=1),
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
edge_with_nodes1 = (node1, edge1, node2)
|
||||
edge_with_nodes2 = (node2, edge2, node1)
|
||||
|
||||
# Prepare test input
|
||||
existing_edges = [edge_with_nodes1, edge_with_nodes2]
|
||||
|
||||
# Create a current episode and previous episodes
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
content='This is the current episode content.',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source=EpisodeType.message,
|
||||
source_description='Test episode for unit testing',
|
||||
group_id='1',
|
||||
)
|
||||
previous_episodes = [
|
||||
EpisodicNode(
|
||||
name='Previous Episode',
|
||||
content='This is the content of a previous episode.',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode for unit testing',
|
||||
group_id='1',
|
||||
),
|
||||
]
|
||||
|
||||
# Call the function
|
||||
result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)
|
||||
|
||||
# Assert the result
|
||||
assert len(result['existing_edges']) == 2
|
||||
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
|
||||
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
|
||||
assert result['current_episode'] == current_episode.content
|
||||
assert len(result['previous_episodes']) == 1
|
||||
assert result['previous_episodes'][0] == previous_episodes[0].content
|
||||
|
||||
|
||||
class TestExtractDateStringsFromEdge(unittest.TestCase):
|
||||
def generate_entity_edge(self, valid_at, invalid_at):
|
||||
return EntityEdge(
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='KNOWS',
|
||||
fact='Node1 knows Node2',
|
||||
created_at=datetime.now(),
|
||||
valid_at=valid_at,
|
||||
invalid_at=invalid_at,
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
def test_both_dates_present(self):
|
||||
edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), datetime(2024, 1, 2, 12, 0))
|
||||
result = extract_date_strings_from_edge(edge)
|
||||
expected = 'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)'
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_only_valid_at_present(self):
|
||||
edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), None)
|
||||
result = extract_date_strings_from_edge(edge)
|
||||
expected = 'Start Date: 2024-01-01T12:00:00'
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_only_invalid_at_present(self):
|
||||
edge = self.generate_entity_edge(None, datetime(2024, 1, 2, 12, 0))
|
||||
result = extract_date_strings_from_edge(edge)
|
||||
expected = ' (End Date: 2024-01-02T12:00:00)'
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_no_dates_present(self):
|
||||
edge = self.generate_entity_edge(None, None)
|
||||
result = extract_date_strings_from_edge(edge)
|
||||
expected = ''
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
# Run the tests
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -19,12 +19,14 @@ from datetime import datetime, timedelta
|
|||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from pytz import UTC
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||
invalidate_edges,
|
||||
extract_edge_dates,
|
||||
get_edge_contradictions,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
|
@ -43,31 +45,26 @@ def setup_llm_client():
|
|||
def create_test_data():
|
||||
now = datetime.now()
|
||||
|
||||
# Create nodes
|
||||
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
|
||||
node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)
|
||||
|
||||
# Create edges
|
||||
edge1 = EntityEdge(
|
||||
existing_edge = EntityEdge(
|
||||
uuid='e1',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='LIKES',
|
||||
fact='Alice likes Bob',
|
||||
created_at=now - timedelta(days=1),
|
||||
group_id='1',
|
||||
)
|
||||
edge2 = EntityEdge(
|
||||
new_edge = EntityEdge(
|
||||
uuid='e2',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='DISLIKES',
|
||||
fact='Alice dislikes Bob',
|
||||
created_at=now,
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
existing_edge = (node1, edge1, node2)
|
||||
new_edge = (node1, edge2, node2)
|
||||
|
||||
# Create current episode
|
||||
current_episode = EpisodicNode(
|
||||
name='Current Episode',
|
||||
|
|
@ -97,46 +94,40 @@ def create_test_data():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges():
|
||||
async def test_get_edge_contradictions():
|
||||
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
|
||||
|
||||
invalidated_edges = await invalidate_edges(
|
||||
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
|
||||
)
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [existing_edge])
|
||||
|
||||
assert len(invalidated_edges) == 1
|
||||
assert invalidated_edges[0].uuid == existing_edge[1].uuid
|
||||
assert invalidated_edges[0].expired_at is not None
|
||||
assert invalidated_edges[0].uuid == existing_edge.uuid
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges_no_invalidation():
|
||||
existing_edge, _, current_episode, previous_episodes = create_test_data()
|
||||
async def test_get_edge_contradictions_no_contradictions():
|
||||
_, new_edge, current_episode, previous_episodes = create_test_data()
|
||||
|
||||
invalidated_edges = await invalidate_edges(
|
||||
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
|
||||
)
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [])
|
||||
|
||||
assert len(invalidated_edges) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges_multiple_existing():
|
||||
existing_edge1, new_edge = create_test_data()
|
||||
existing_edge2, _ = create_test_data()
|
||||
existing_edge2[1].uuid = 'e3'
|
||||
existing_edge2[1].name = 'KNOWS'
|
||||
existing_edge2[1].fact = 'Alice knows Bob'
|
||||
async def test_get_edge_contradictions_multiple_existing():
|
||||
existing_edge1, new_edge, _, _ = create_test_data()
|
||||
existing_edge2, _, _, _ = create_test_data()
|
||||
existing_edge2.uuid = 'e3'
|
||||
existing_edge2.name = 'KNOWS'
|
||||
existing_edge2.fact = 'Alice knows Bob'
|
||||
|
||||
invalidated_edges = await invalidate_edges(
|
||||
setup_llm_client(), [existing_edge1, existing_edge2], [new_edge]
|
||||
invalidated_edges = await get_edge_contradictions(
|
||||
setup_llm_client(), new_edge, [existing_edge1, existing_edge2]
|
||||
)
|
||||
|
||||
assert len(invalidated_edges) == 1
|
||||
assert invalidated_edges[0].uuid == existing_edge1[1].uuid
|
||||
assert invalidated_edges[0].expired_at is not None
|
||||
assert invalidated_edges[0].uuid == existing_edge1.uuid
|
||||
|
||||
|
||||
# Helper function to create more complex test data
|
||||
|
|
@ -152,7 +143,7 @@ def create_complex_test_data():
|
|||
)
|
||||
|
||||
# Create edges
|
||||
edge1 = EntityEdge(
|
||||
existing_edge1 = EntityEdge(
|
||||
uuid='e1',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
|
|
@ -161,7 +152,7 @@ def create_complex_test_data():
|
|||
group_id='1',
|
||||
created_at=now - timedelta(days=5),
|
||||
)
|
||||
edge2 = EntityEdge(
|
||||
existing_edge2 = EntityEdge(
|
||||
uuid='e2',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='3',
|
||||
|
|
@ -170,7 +161,7 @@ def create_complex_test_data():
|
|||
group_id='1',
|
||||
created_at=now - timedelta(days=3),
|
||||
)
|
||||
edge3 = EntityEdge(
|
||||
existing_edge3 = EntityEdge(
|
||||
uuid='e3',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='4',
|
||||
|
|
@ -180,10 +171,6 @@ def create_complex_test_data():
|
|||
created_at=now - timedelta(days=2),
|
||||
)
|
||||
|
||||
existing_edge1 = (node1, edge1, node2)
|
||||
existing_edge2 = (node1, edge2, node3)
|
||||
existing_edge3 = (node2, edge3, node4)
|
||||
|
||||
return [existing_edge1, existing_edge2, existing_edge3], [
|
||||
node1,
|
||||
node2,
|
||||
|
|
@ -198,118 +185,61 @@ async def test_invalidate_edges_complex():
|
|||
existing_edges, nodes = create_complex_test_data()
|
||||
|
||||
# Create a new edge that contradicts an existing one
|
||||
new_edge = (
|
||||
nodes[0],
|
||||
EntityEdge(
|
||||
uuid='e4',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='DISLIKES',
|
||||
fact='Alice dislikes Bob',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
),
|
||||
nodes[1],
|
||||
new_edge = EntityEdge(
|
||||
uuid='e4',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='DISLIKES',
|
||||
fact='Alice dislikes Bob',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
|
||||
assert len(invalidated_edges) == 1
|
||||
assert invalidated_edges[0].uuid == 'e1'
|
||||
assert invalidated_edges[0].expired_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges_temporal_update():
|
||||
async def test_get_edge_contradictions_temporal_update():
|
||||
existing_edges, nodes = create_complex_test_data()
|
||||
|
||||
# Create a new edge that updates an existing one with new information
|
||||
new_edge = (
|
||||
nodes[1],
|
||||
EntityEdge(
|
||||
uuid='e5',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='4',
|
||||
name='LEFT_JOB',
|
||||
fact='Bob left his job at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
),
|
||||
nodes[3],
|
||||
new_edge = EntityEdge(
|
||||
uuid='e5',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='4',
|
||||
name='LEFT_JOB',
|
||||
fact='Bob no longer works at at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
|
||||
assert len(invalidated_edges) == 1
|
||||
assert invalidated_edges[0].uuid == 'e3'
|
||||
assert invalidated_edges[0].expired_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges_multiple_invalidations():
|
||||
existing_edges, nodes = create_complex_test_data()
|
||||
|
||||
# Create new edges that invalidate multiple existing edges
|
||||
new_edge1 = (
|
||||
nodes[0],
|
||||
EntityEdge(
|
||||
uuid='e6',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='2',
|
||||
name='ENEMIES_WITH',
|
||||
fact='Alice and Bob are now enemies',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
),
|
||||
nodes[1],
|
||||
)
|
||||
new_edge2 = (
|
||||
nodes[0],
|
||||
EntityEdge(
|
||||
uuid='e7',
|
||||
source_node_uuid='1',
|
||||
target_node_uuid='3',
|
||||
name='ENDED_FRIENDSHIP',
|
||||
fact='Alice ended her friendship with Charlie',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
),
|
||||
nodes[2],
|
||||
)
|
||||
|
||||
invalidated_edges = await invalidate_edges(
|
||||
setup_llm_client(), existing_edges, [new_edge1, new_edge2]
|
||||
)
|
||||
|
||||
assert len(invalidated_edges) == 2
|
||||
assert set(edge.uuid for edge in invalidated_edges) == {'e1', 'e2'}
|
||||
for edge in invalidated_edges:
|
||||
assert edge.expired_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges_no_effect():
|
||||
async def test_get_edge_contradictions_no_effect():
|
||||
existing_edges, nodes = create_complex_test_data()
|
||||
|
||||
# Create a new edge that doesn't invalidate any existing edges
|
||||
new_edge = (
|
||||
nodes[2],
|
||||
EntityEdge(
|
||||
uuid='e8',
|
||||
source_node_uuid='3',
|
||||
target_node_uuid='4',
|
||||
name='APPLIED_TO',
|
||||
fact='Charlie applied to Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
),
|
||||
nodes[3],
|
||||
new_edge = EntityEdge(
|
||||
uuid='e8',
|
||||
source_node_uuid='3',
|
||||
target_node_uuid='4',
|
||||
name='APPLIED_TO',
|
||||
fact='Charlie applied to Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
|
||||
assert len(invalidated_edges) == 0
|
||||
|
||||
|
|
@ -320,28 +250,82 @@ async def test_invalidate_edges_partial_update():
|
|||
existing_edges, nodes = create_complex_test_data()
|
||||
|
||||
# Create a new edge that partially updates an existing one
|
||||
new_edge = (
|
||||
nodes[1],
|
||||
EntityEdge(
|
||||
uuid='e9',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='4',
|
||||
name='CHANGED_POSITION',
|
||||
fact='Bob changed his position at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
),
|
||||
nodes[3],
|
||||
new_edge = EntityEdge(
|
||||
uuid='e9',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='4',
|
||||
name='CHANGED_POSITION',
|
||||
fact='Bob changed his position at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
|
||||
assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated
|
||||
|
||||
|
||||
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
previous_episodes = [
|
||||
EpisodicNode(
|
||||
name='Previous Episode 1',
|
||||
content='Bob: I work at XYZ company',
|
||||
created_at=now - timedelta(days=2),
|
||||
valid_at=now - timedelta(days=2),
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode for unit testing',
|
||||
group_id='1',
|
||||
),
|
||||
EpisodicNode(
|
||||
name='Previous Episode 2',
|
||||
content="Alice: That's really cool!",
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode for unit testing',
|
||||
group_id='1',
|
||||
),
|
||||
]
|
||||
|
||||
episode = EpisodicNode(
|
||||
name='Previous Episode',
|
||||
content='Bob: It was cool, but I no longer work at company XYZ',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode for unit testing',
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
return episode, previous_episodes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_invalidate_edges_empty_inputs():
|
||||
invalidated_edges = await invalidate_edges(setup_llm_client(), [], [])
|
||||
async def test_extract_edge_dates():
|
||||
episode, previous_episodes = create_data_for_temporal_extraction()
|
||||
|
||||
assert len(invalidated_edges) == 0
|
||||
# Create a new edge that partially updates an existing one
|
||||
new_edge = EntityEdge(
|
||||
uuid='e9',
|
||||
source_node_uuid='2',
|
||||
target_node_uuid='4',
|
||||
name='LEFT_JOB',
|
||||
fact='Bob no longer works at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
valid_at, invalid_at = await extract_edge_dates(
|
||||
setup_llm_client(), new_edge, episode, previous_episodes
|
||||
)
|
||||
|
||||
assert valid_at == episode.valid_at
|
||||
assert invalid_at is None
|
||||
|
||||
|
||||
# Run the tests
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue