parent
377225eec5
commit
e15c872900
8 changed files with 380 additions and 653 deletions
|
|
@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
Edge:
|
Edge:
|
||||||
Edge Name: {context['edge_name']}
|
|
||||||
Fact: {context['edge_fact']}
|
Fact: {context['edge_fact']}
|
||||||
|
|
||||||
Current Episode: {context['current_episode']}
|
Current Episode: {context['current_episode']}
|
||||||
|
|
@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
|
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
|
||||||
2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
|
2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
|
||||||
3. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
|
3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
|
||||||
4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
|
4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
|
||||||
5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
|
5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
|
||||||
6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
|
6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
|
||||||
7. If only a year is mentioned, use January 1st of that year at 00:00:00.
|
7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
|
||||||
|
8. If only a year is mentioned, use January 1st of that year at 00:00:00.
|
||||||
9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
|
9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
|
||||||
Respond with a JSON object:
|
Respond with a JSON object:
|
||||||
{{
|
{{
|
||||||
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
|
"valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
|
||||||
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
|
"invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
|
||||||
"explanation": "Brief explanation of why these dates were chosen or why they were set to null"
|
|
||||||
}}
|
}}
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
Message(
|
Message(
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
|
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
|
||||||
|
|
||||||
Existing Edges:
|
Existing Edges:
|
||||||
{context['existing_edges']}
|
{context['existing_edges']}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from .graph_data_operations import (
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
from .node_operations import extract_nodes
|
from .node_operations import extract_nodes
|
||||||
from .temporal_operations import invalidate_edges
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'extract_edges',
|
'extract_edges',
|
||||||
|
|
@ -12,5 +11,4 @@ __all__ = [
|
||||||
'extract_nodes',
|
'extract_nodes',
|
||||||
'clear_data',
|
'clear_data',
|
||||||
'retrieve_episodes',
|
'retrieve_episodes',
|
||||||
'invalidate_edges',
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -122,12 +122,6 @@ async def extract_edges(
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
def create_edge_identifier(
|
|
||||||
source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
|
|
||||||
) -> str:
|
|
||||||
return f'{source_node.name}-{edge.name}-{target_node.name}'
|
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_extracted_edges(
|
async def dedupe_extracted_edges(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
extracted_edges: list[EntityEdge],
|
extracted_edges: list[EntityEdge],
|
||||||
|
|
@ -251,11 +245,11 @@ async def resolve_extracted_edge(
|
||||||
if (
|
if (
|
||||||
edge.invalid_at is not None
|
edge.invalid_at is not None
|
||||||
and resolved_edge.valid_at is not None
|
and resolved_edge.valid_at is not None
|
||||||
and edge.invalid_at < resolved_edge.valid_at
|
and edge.invalid_at <= resolved_edge.valid_at
|
||||||
) or (
|
) or (
|
||||||
edge.valid_at is not None
|
edge.valid_at is not None
|
||||||
and resolved_edge.invalid_at is not None
|
and resolved_edge.invalid_at is not None
|
||||||
and resolved_edge.invalid_at < edge.valid_at
|
and resolved_edge.invalid_at <= edge.valid_at
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
# New edge invalidates edge
|
# New edge invalidates edge
|
||||||
|
|
|
||||||
|
|
@ -21,129 +21,11 @@ from typing import List
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
|
|
||||||
|
|
||||||
|
|
||||||
def extract_node_and_edge_triplets(
|
|
||||||
edges: list[EntityEdge], nodes: list[EntityNode]
|
|
||||||
) -> list[NodeEdgeNodeTriplet]:
|
|
||||||
return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]
|
|
||||||
|
|
||||||
|
|
||||||
def extract_node_edge_node_triplet(
|
|
||||||
edge: EntityEdge, nodes: list[EntityNode]
|
|
||||||
) -> NodeEdgeNodeTriplet:
|
|
||||||
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
|
|
||||||
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
|
|
||||||
if not source_node or not target_node:
|
|
||||||
raise ValueError(f'Source or target node not found for edge {edge.uuid}')
|
|
||||||
return (source_node, edge, target_node)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_edges_for_invalidation(
|
|
||||||
existing_edges: list[EntityEdge],
|
|
||||||
new_edges: list[EntityEdge],
|
|
||||||
nodes: list[EntityNode],
|
|
||||||
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
|
|
||||||
existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
|
|
||||||
new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
|
|
||||||
|
|
||||||
for edge_list, result_list in [
|
|
||||||
(existing_edges, existing_edges_pending_invalidation),
|
|
||||||
(new_edges, new_edges_with_nodes),
|
|
||||||
]:
|
|
||||||
for edge in edge_list:
|
|
||||||
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
|
|
||||||
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
|
|
||||||
|
|
||||||
if source_node and target_node:
|
|
||||||
result_list.append((source_node, edge, target_node))
|
|
||||||
|
|
||||||
return existing_edges_pending_invalidation, new_edges_with_nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def invalidate_edges(
|
|
||||||
llm_client: LLMClient,
|
|
||||||
existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
|
|
||||||
new_edges: list[NodeEdgeNodeTriplet],
|
|
||||||
current_episode: EpisodicNode,
|
|
||||||
previous_episodes: list[EpisodicNode],
|
|
||||||
) -> list[EntityEdge]:
|
|
||||||
invalidated_edges = [] # TODO: this is not yet used?
|
|
||||||
|
|
||||||
context = prepare_invalidation_context(
|
|
||||||
existing_edges_pending_invalidation,
|
|
||||||
new_edges,
|
|
||||||
current_episode,
|
|
||||||
previous_episodes,
|
|
||||||
)
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
|
|
||||||
|
|
||||||
edges_to_invalidate = llm_response.get('invalidated_edges', [])
|
|
||||||
invalidated_edges = process_edge_invalidation_llm_response(
|
|
||||||
edges_to_invalidate, existing_edges_pending_invalidation
|
|
||||||
)
|
|
||||||
|
|
||||||
return invalidated_edges
|
|
||||||
|
|
||||||
|
|
||||||
def extract_date_strings_from_edge(edge: EntityEdge) -> str:
|
|
||||||
start = edge.valid_at
|
|
||||||
end = edge.invalid_at
|
|
||||||
date_string = f'Start Date: {start.isoformat()}' if start else ''
|
|
||||||
if end:
|
|
||||||
date_string += f' (End Date: {end.isoformat()})'
|
|
||||||
return date_string
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_invalidation_context(
|
|
||||||
existing_edges: list[NodeEdgeNodeTriplet],
|
|
||||||
new_edges: list[NodeEdgeNodeTriplet],
|
|
||||||
current_episode: EpisodicNode,
|
|
||||||
previous_episodes: list[EpisodicNode],
|
|
||||||
) -> dict:
|
|
||||||
return {
|
|
||||||
'existing_edges': [
|
|
||||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
|
|
||||||
for source_node, edge, target_node in sorted(
|
|
||||||
existing_edges, key=lambda x: (x[1].created_at), reverse=True
|
|
||||||
)
|
|
||||||
],
|
|
||||||
'new_edges': [
|
|
||||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
|
|
||||||
for source_node, edge, target_node in sorted(
|
|
||||||
new_edges, key=lambda x: (x[1].created_at), reverse=True
|
|
||||||
)
|
|
||||||
],
|
|
||||||
'current_episode': current_episode.content,
|
|
||||||
'previous_episodes': [episode.content for episode in previous_episodes],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def process_edge_invalidation_llm_response(
|
|
||||||
edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
|
|
||||||
) -> List[EntityEdge]:
|
|
||||||
invalidated_edges = []
|
|
||||||
for edge_to_invalidate in edges_to_invalidate:
|
|
||||||
edge_uuid = edge_to_invalidate['edge_uuid']
|
|
||||||
edge_to_update = next(
|
|
||||||
(edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if edge_to_update:
|
|
||||||
edge_to_update.expired_at = datetime.now()
|
|
||||||
edge_to_update.fact = edge_to_invalidate['fact']
|
|
||||||
invalidated_edges.append(edge_to_update)
|
|
||||||
logger.info(
|
|
||||||
f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
|
|
||||||
)
|
|
||||||
return invalidated_edges
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_edge_dates(
|
async def extract_edge_dates(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
|
|
@ -152,7 +34,6 @@ async def extract_edge_dates(
|
||||||
previous_episodes: List[EpisodicNode],
|
previous_episodes: List[EpisodicNode],
|
||||||
) -> tuple[datetime | None, datetime | None]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
context = {
|
context = {
|
||||||
'edge_name': edge.name,
|
|
||||||
'edge_fact': edge.fact,
|
'edge_fact': edge.fact,
|
||||||
'current_episode': current_episode.content,
|
'current_episode': current_episode.content,
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
|
|
@ -162,25 +43,22 @@ async def extract_edge_dates(
|
||||||
|
|
||||||
valid_at = llm_response.get('valid_at')
|
valid_at = llm_response.get('valid_at')
|
||||||
invalid_at = llm_response.get('invalid_at')
|
invalid_at = llm_response.get('invalid_at')
|
||||||
explanation = llm_response.get('explanation', '')
|
|
||||||
|
|
||||||
valid_at_datetime = None
|
valid_at_datetime = None
|
||||||
invalid_at_datetime = None
|
invalid_at_datetime = None
|
||||||
|
|
||||||
if valid_at and valid_at != '':
|
if valid_at:
|
||||||
try:
|
try:
|
||||||
valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
|
valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
|
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
|
||||||
|
|
||||||
if invalid_at and invalid_at != '':
|
if invalid_at:
|
||||||
try:
|
try:
|
||||||
invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
|
invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
|
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
|
||||||
|
|
||||||
logger.info(f'Edge date extraction explanation: {explanation}')
|
|
||||||
|
|
||||||
return valid_at_datetime, invalid_at_datetime
|
return valid_at_datetime, invalid_at_datetime
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
242
tests/utils/maintenance/test_edge_operations.py
Normal file
242
tests/utils/maintenance/test_edge_operations.py
Normal file
|
|
@ -0,0 +1,242 @@
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest import MonkeyPatch
|
||||||
|
|
||||||
|
from graphiti_core.edges import EntityEdge
|
||||||
|
from graphiti_core.nodes import EpisodicNode
|
||||||
|
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_client():
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_extracted_edge():
|
||||||
|
return EntityEdge(
|
||||||
|
source_node_uuid='source_uuid',
|
||||||
|
target_node_uuid='target_uuid',
|
||||||
|
name='test_edge',
|
||||||
|
group_id='group_1',
|
||||||
|
fact='Test fact',
|
||||||
|
episodes=['episode_1'],
|
||||||
|
created_at=datetime.now(),
|
||||||
|
valid_at=None,
|
||||||
|
invalid_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_related_edges():
|
||||||
|
return [
|
||||||
|
EntityEdge(
|
||||||
|
source_node_uuid='source_uuid_2',
|
||||||
|
target_node_uuid='target_uuid_2',
|
||||||
|
name='related_edge',
|
||||||
|
group_id='group_1',
|
||||||
|
fact='Related fact',
|
||||||
|
episodes=['episode_2'],
|
||||||
|
created_at=datetime.now() - timedelta(days=1),
|
||||||
|
valid_at=datetime.now() - timedelta(days=1),
|
||||||
|
invalid_at=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_existing_edges():
|
||||||
|
return [
|
||||||
|
EntityEdge(
|
||||||
|
source_node_uuid='source_uuid_3',
|
||||||
|
target_node_uuid='target_uuid_3',
|
||||||
|
name='existing_edge',
|
||||||
|
group_id='group_1',
|
||||||
|
fact='Existing fact',
|
||||||
|
episodes=['episode_3'],
|
||||||
|
created_at=datetime.now() - timedelta(days=2),
|
||||||
|
valid_at=datetime.now() - timedelta(days=2),
|
||||||
|
invalid_at=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_current_episode():
|
||||||
|
return EpisodicNode(
|
||||||
|
uuid='episode_1',
|
||||||
|
content='Current episode content',
|
||||||
|
valid_at=datetime.now(),
|
||||||
|
name='Current Episode',
|
||||||
|
group_id='group_1',
|
||||||
|
source='message',
|
||||||
|
source_description='Test source description',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_previous_episodes():
|
||||||
|
return [
|
||||||
|
EpisodicNode(
|
||||||
|
uuid='episode_2',
|
||||||
|
content='Previous episode content',
|
||||||
|
valid_at=datetime.now() - timedelta(days=1),
|
||||||
|
name='Previous Episode',
|
||||||
|
group_id='group_1',
|
||||||
|
source='message',
|
||||||
|
source_description='Test source description',
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_extracted_edge_no_changes(
|
||||||
|
mock_llm_client,
|
||||||
|
mock_extracted_edge,
|
||||||
|
mock_related_edges,
|
||||||
|
mock_existing_edges,
|
||||||
|
mock_current_episode,
|
||||||
|
mock_previous_episodes,
|
||||||
|
monkeypatch: MonkeyPatch,
|
||||||
|
):
|
||||||
|
# Mock the function calls
|
||||||
|
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
||||||
|
extract_dates_mock = AsyncMock(return_value=(None, None))
|
||||||
|
get_contradictions_mock = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
# Patch the function calls
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
||||||
|
get_contradictions_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||||
|
mock_llm_client,
|
||||||
|
mock_extracted_edge,
|
||||||
|
mock_related_edges,
|
||||||
|
mock_existing_edges,
|
||||||
|
mock_current_episode,
|
||||||
|
mock_previous_episodes,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved_edge.uuid == mock_extracted_edge.uuid
|
||||||
|
assert invalidated_edges == []
|
||||||
|
dedupe_mock.assert_called_once()
|
||||||
|
extract_dates_mock.assert_called_once()
|
||||||
|
get_contradictions_mock.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_extracted_edge_with_dates(
|
||||||
|
mock_llm_client,
|
||||||
|
mock_extracted_edge,
|
||||||
|
mock_related_edges,
|
||||||
|
mock_existing_edges,
|
||||||
|
mock_current_episode,
|
||||||
|
mock_previous_episodes,
|
||||||
|
monkeypatch: MonkeyPatch,
|
||||||
|
):
|
||||||
|
valid_at = datetime.now() - timedelta(days=1)
|
||||||
|
invalid_at = datetime.now() + timedelta(days=1)
|
||||||
|
|
||||||
|
# Mock the function calls
|
||||||
|
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
||||||
|
extract_dates_mock = AsyncMock(return_value=(valid_at, invalid_at))
|
||||||
|
get_contradictions_mock = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
# Patch the function calls
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
||||||
|
get_contradictions_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||||
|
mock_llm_client,
|
||||||
|
mock_extracted_edge,
|
||||||
|
mock_related_edges,
|
||||||
|
mock_existing_edges,
|
||||||
|
mock_current_episode,
|
||||||
|
mock_previous_episodes,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved_edge.valid_at == valid_at
|
||||||
|
assert resolved_edge.invalid_at == invalid_at
|
||||||
|
assert resolved_edge.expired_at is not None
|
||||||
|
assert invalidated_edges == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_extracted_edge_with_invalidation(
|
||||||
|
mock_llm_client,
|
||||||
|
mock_extracted_edge,
|
||||||
|
mock_related_edges,
|
||||||
|
mock_existing_edges,
|
||||||
|
mock_current_episode,
|
||||||
|
mock_previous_episodes,
|
||||||
|
monkeypatch: MonkeyPatch,
|
||||||
|
):
|
||||||
|
valid_at = datetime.now() - timedelta(days=1)
|
||||||
|
mock_extracted_edge.valid_at = valid_at
|
||||||
|
|
||||||
|
invalidation_candidate = EntityEdge(
|
||||||
|
source_node_uuid='source_uuid_4',
|
||||||
|
target_node_uuid='target_uuid_4',
|
||||||
|
name='invalidation_candidate',
|
||||||
|
group_id='group_1',
|
||||||
|
fact='Invalidation candidate fact',
|
||||||
|
episodes=['episode_4'],
|
||||||
|
created_at=datetime.now(),
|
||||||
|
valid_at=datetime.now() - timedelta(days=2),
|
||||||
|
invalid_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the function calls
|
||||||
|
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
||||||
|
extract_dates_mock = AsyncMock(return_value=(None, None))
|
||||||
|
get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])
|
||||||
|
|
||||||
|
# Patch the function calls
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
||||||
|
get_contradictions_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||||
|
mock_llm_client,
|
||||||
|
mock_extracted_edge,
|
||||||
|
mock_related_edges,
|
||||||
|
mock_existing_edges,
|
||||||
|
mock_current_episode,
|
||||||
|
mock_previous_episodes,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved_edge.uuid == mock_extracted_edge.uuid
|
||||||
|
assert len(invalidated_edges) == 1
|
||||||
|
assert invalidated_edges[0].uuid == invalidation_candidate.uuid
|
||||||
|
assert invalidated_edges[0].invalid_at == valid_at
|
||||||
|
assert invalidated_edges[0].expired_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
# Run the tests
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__])
|
||||||
|
|
@ -1,368 +0,0 @@
|
||||||
"""
|
|
||||||
Copyright 2024, Zep Software, Inc.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
||||||
extract_date_strings_from_edge,
|
|
||||||
prepare_edges_for_invalidation,
|
|
||||||
prepare_invalidation_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Helper function to create test data
|
|
||||||
def create_test_data():
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
# Create nodes
|
|
||||||
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
|
|
||||||
# Create edges
|
|
||||||
existing_edge1 = EntityEdge(
|
|
||||||
uuid='e1',
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='2',
|
|
||||||
name='KNOWS',
|
|
||||||
fact='Node1 knows Node2',
|
|
||||||
created_at=now,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
existing_edge2 = EntityEdge(
|
|
||||||
uuid='e2',
|
|
||||||
source_node_uuid='2',
|
|
||||||
target_node_uuid='3',
|
|
||||||
name='LIKES',
|
|
||||||
fact='Node2 likes Node3',
|
|
||||||
created_at=now,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
new_edge1 = EntityEdge(
|
|
||||||
uuid='e3',
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='3',
|
|
||||||
name='WORKS_WITH',
|
|
||||||
fact='Node1 works with Node3',
|
|
||||||
created_at=now,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
new_edge2 = EntityEdge(
|
|
||||||
uuid='e4',
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='2',
|
|
||||||
name='DISLIKES',
|
|
||||||
fact='Node1 dislikes Node2',
|
|
||||||
created_at=now,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'nodes': [node1, node2, node3],
|
|
||||||
'existing_edges': [existing_edge1, existing_edge2],
|
|
||||||
'new_edges': [new_edge1, new_edge2],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_edges_for_invalidation_basic():
|
|
||||||
test_data = create_test_data()
|
|
||||||
|
|
||||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
|
||||||
test_data['existing_edges'], test_data['new_edges'], test_data['nodes']
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(existing_edges_pending_invalidation) == 2
|
|
||||||
assert len(new_edges_with_nodes) == 2
|
|
||||||
|
|
||||||
# Check if the edges are correctly associated with nodes
|
|
||||||
for edge_with_nodes in existing_edges_pending_invalidation + new_edges_with_nodes:
|
|
||||||
assert isinstance(edge_with_nodes[0], EntityNode)
|
|
||||||
assert isinstance(edge_with_nodes[1], EntityEdge)
|
|
||||||
assert isinstance(edge_with_nodes[2], EntityNode)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_edges_for_invalidation_no_existing_edges():
|
|
||||||
test_data = create_test_data()
|
|
||||||
|
|
||||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
|
||||||
[], test_data['new_edges'], test_data['nodes']
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(existing_edges_pending_invalidation) == 0
|
|
||||||
assert len(new_edges_with_nodes) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_edges_for_invalidation_no_new_edges():
|
|
||||||
test_data = create_test_data()
|
|
||||||
|
|
||||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
|
||||||
test_data['existing_edges'], [], test_data['nodes']
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(existing_edges_pending_invalidation) == 2
|
|
||||||
assert len(new_edges_with_nodes) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_edges_for_invalidation_missing_nodes():
|
|
||||||
test_data = create_test_data()
|
|
||||||
|
|
||||||
# Remove one node to simulate a missing node scenario
|
|
||||||
nodes = test_data['nodes'][:-1]
|
|
||||||
|
|
||||||
existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
|
|
||||||
test_data['existing_edges'], test_data['new_edges'], nodes
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(existing_edges_pending_invalidation) == 1
|
|
||||||
assert len(new_edges_with_nodes) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_invalidation_context():
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
# Create nodes
|
|
||||||
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
|
|
||||||
# Create edges
|
|
||||||
edge1 = EntityEdge(
|
|
||||||
uuid='e1',
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='2',
|
|
||||||
name='KNOWS',
|
|
||||||
fact='Node1 knows Node2',
|
|
||||||
created_at=now,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
edge2 = EntityEdge(
|
|
||||||
uuid='e2',
|
|
||||||
source_node_uuid='2',
|
|
||||||
target_node_uuid='3',
|
|
||||||
name='LIKES',
|
|
||||||
fact='Node2 likes Node3',
|
|
||||||
created_at=now,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create NodeEdgeNodeTriplet objects
|
|
||||||
existing_edge = (node1, edge1, node2)
|
|
||||||
new_edge = (node2, edge2, node3)
|
|
||||||
|
|
||||||
# Prepare test input
|
|
||||||
existing_edges = [existing_edge]
|
|
||||||
new_edges = [new_edge]
|
|
||||||
|
|
||||||
# Create a current episode and previous episodes
|
|
||||||
current_episode = EpisodicNode(
|
|
||||||
name='Current Episode',
|
|
||||||
content='This is the current episode content.',
|
|
||||||
created_at=now,
|
|
||||||
valid_at=now,
|
|
||||||
source=EpisodeType.message,
|
|
||||||
source_description='Test episode for unit testing',
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
previous_episodes = [
|
|
||||||
EpisodicNode(
|
|
||||||
name='Previous Episode 1',
|
|
||||||
content='This is the content of previous episode 1.',
|
|
||||||
created_at=now - timedelta(days=1),
|
|
||||||
valid_at=now - timedelta(days=1),
|
|
||||||
source=EpisodeType.message,
|
|
||||||
source_description='Test previous episode 1 for unit testing',
|
|
||||||
group_id='1',
|
|
||||||
),
|
|
||||||
EpisodicNode(
|
|
||||||
name='Previous Episode 2',
|
|
||||||
content='This is the content of previous episode 2.',
|
|
||||||
created_at=now - timedelta(days=2),
|
|
||||||
valid_at=now - timedelta(days=2),
|
|
||||||
source=EpisodeType.message,
|
|
||||||
source_description='Test previous episode 2 for unit testing',
|
|
||||||
group_id='1',
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
result = prepare_invalidation_context(
|
|
||||||
existing_edges, new_edges, current_episode, previous_episodes
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert the result
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert 'existing_edges' in result
|
|
||||||
assert 'new_edges' in result
|
|
||||||
assert 'current_episode' in result
|
|
||||||
assert 'previous_episodes' in result
|
|
||||||
assert len(result['existing_edges']) == 1
|
|
||||||
assert len(result['new_edges']) == 1
|
|
||||||
assert result['current_episode'] == current_episode.content
|
|
||||||
assert len(result['previous_episodes']) == 2
|
|
||||||
|
|
||||||
# Check the format of the existing edge
|
|
||||||
existing_edge_str = result['existing_edges'][0]
|
|
||||||
assert edge1.uuid in existing_edge_str
|
|
||||||
assert node1.name in existing_edge_str
|
|
||||||
assert edge1.name in existing_edge_str
|
|
||||||
assert node2.name in existing_edge_str
|
|
||||||
assert edge1.fact in existing_edge_str
|
|
||||||
|
|
||||||
# Check the format of the new edge
|
|
||||||
new_edge_str = result['new_edges'][0]
|
|
||||||
assert edge2.uuid in new_edge_str
|
|
||||||
assert node2.name in new_edge_str
|
|
||||||
assert edge2.name in new_edge_str
|
|
||||||
assert node3.name in new_edge_str
|
|
||||||
assert edge2.fact in new_edge_str
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_invalidation_context_empty_input():
|
|
||||||
now = datetime.now()
|
|
||||||
current_episode = EpisodicNode(
|
|
||||||
name='Current Episode',
|
|
||||||
content='Empty episode',
|
|
||||||
created_at=now,
|
|
||||||
valid_at=now,
|
|
||||||
source=EpisodeType.message,
|
|
||||||
source_description='Test empty episode for unit testing',
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
result = prepare_invalidation_context([], [], current_episode, [])
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert 'existing_edges' in result
|
|
||||||
assert 'new_edges' in result
|
|
||||||
assert 'current_episode' in result
|
|
||||||
assert 'previous_episodes' in result
|
|
||||||
assert len(result['existing_edges']) == 0
|
|
||||||
assert len(result['new_edges']) == 0
|
|
||||||
assert result['current_episode'] == current_episode.content
|
|
||||||
assert len(result['previous_episodes']) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_invalidation_context_sorting():
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
# Create nodes
|
|
||||||
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
|
|
||||||
|
|
||||||
# Create edges with different timestamps
|
|
||||||
edge1 = EntityEdge(
|
|
||||||
uuid='e1',
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='2',
|
|
||||||
name='KNOWS',
|
|
||||||
fact='Node1 knows Node2',
|
|
||||||
created_at=now,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
edge2 = EntityEdge(
|
|
||||||
uuid='e2',
|
|
||||||
source_node_uuid='2',
|
|
||||||
target_node_uuid='1',
|
|
||||||
name='LIKES',
|
|
||||||
fact='Node2 likes Node1',
|
|
||||||
created_at=now + timedelta(hours=1),
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
|
|
||||||
edge_with_nodes1 = (node1, edge1, node2)
|
|
||||||
edge_with_nodes2 = (node2, edge2, node1)
|
|
||||||
|
|
||||||
# Prepare test input
|
|
||||||
existing_edges = [edge_with_nodes1, edge_with_nodes2]
|
|
||||||
|
|
||||||
# Create a current episode and previous episodes
|
|
||||||
current_episode = EpisodicNode(
|
|
||||||
name='Current Episode',
|
|
||||||
content='This is the current episode content.',
|
|
||||||
created_at=now,
|
|
||||||
valid_at=now,
|
|
||||||
source=EpisodeType.message,
|
|
||||||
source_description='Test episode for unit testing',
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
previous_episodes = [
|
|
||||||
EpisodicNode(
|
|
||||||
name='Previous Episode',
|
|
||||||
content='This is the content of a previous episode.',
|
|
||||||
created_at=now - timedelta(days=1),
|
|
||||||
valid_at=now - timedelta(days=1),
|
|
||||||
source=EpisodeType.message,
|
|
||||||
source_description='Test previous episode for unit testing',
|
|
||||||
group_id='1',
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)
|
|
||||||
|
|
||||||
# Assert the result
|
|
||||||
assert len(result['existing_edges']) == 2
|
|
||||||
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
|
|
||||||
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
|
|
||||||
assert result['current_episode'] == current_episode.content
|
|
||||||
assert len(result['previous_episodes']) == 1
|
|
||||||
assert result['previous_episodes'][0] == previous_episodes[0].content
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractDateStringsFromEdge(unittest.TestCase):
|
|
||||||
def generate_entity_edge(self, valid_at, invalid_at):
|
|
||||||
return EntityEdge(
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='2',
|
|
||||||
name='KNOWS',
|
|
||||||
fact='Node1 knows Node2',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
valid_at=valid_at,
|
|
||||||
invalid_at=invalid_at,
|
|
||||||
group_id='1',
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_both_dates_present(self):
|
|
||||||
edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), datetime(2024, 1, 2, 12, 0))
|
|
||||||
result = extract_date_strings_from_edge(edge)
|
|
||||||
expected = 'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)'
|
|
||||||
self.assertEqual(result, expected)
|
|
||||||
|
|
||||||
def test_only_valid_at_present(self):
|
|
||||||
edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), None)
|
|
||||||
result = extract_date_strings_from_edge(edge)
|
|
||||||
expected = 'Start Date: 2024-01-01T12:00:00'
|
|
||||||
self.assertEqual(result, expected)
|
|
||||||
|
|
||||||
def test_only_invalid_at_present(self):
|
|
||||||
edge = self.generate_entity_edge(None, datetime(2024, 1, 2, 12, 0))
|
|
||||||
result = extract_date_strings_from_edge(edge)
|
|
||||||
expected = ' (End Date: 2024-01-02T12:00:00)'
|
|
||||||
self.assertEqual(result, expected)
|
|
||||||
|
|
||||||
def test_no_dates_present(self):
|
|
||||||
edge = self.generate_entity_edge(None, None)
|
|
||||||
result = extract_date_strings_from_edge(edge)
|
|
||||||
expected = ''
|
|
||||||
self.assertEqual(result, expected)
|
|
||||||
|
|
||||||
|
|
||||||
# Run the tests
|
|
||||||
if __name__ == '__main__':
|
|
||||||
pytest.main([__file__])
|
|
||||||
|
|
@ -19,12 +19,14 @@ from datetime import datetime, timedelta
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from pytz import UTC
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||||
invalidate_edges,
|
extract_edge_dates,
|
||||||
|
get_edge_contradictions,
|
||||||
)
|
)
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
@ -43,31 +45,26 @@ def setup_llm_client():
|
||||||
def create_test_data():
|
def create_test_data():
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
# Create nodes
|
|
||||||
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
|
|
||||||
node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)
|
|
||||||
|
|
||||||
# Create edges
|
# Create edges
|
||||||
edge1 = EntityEdge(
|
existing_edge = EntityEdge(
|
||||||
uuid='e1',
|
uuid='e1',
|
||||||
source_node_uuid='1',
|
source_node_uuid='1',
|
||||||
target_node_uuid='2',
|
target_node_uuid='2',
|
||||||
name='LIKES',
|
name='LIKES',
|
||||||
fact='Alice likes Bob',
|
fact='Alice likes Bob',
|
||||||
created_at=now - timedelta(days=1),
|
created_at=now - timedelta(days=1),
|
||||||
|
group_id='1',
|
||||||
)
|
)
|
||||||
edge2 = EntityEdge(
|
new_edge = EntityEdge(
|
||||||
uuid='e2',
|
uuid='e2',
|
||||||
source_node_uuid='1',
|
source_node_uuid='1',
|
||||||
target_node_uuid='2',
|
target_node_uuid='2',
|
||||||
name='DISLIKES',
|
name='DISLIKES',
|
||||||
fact='Alice dislikes Bob',
|
fact='Alice dislikes Bob',
|
||||||
created_at=now,
|
created_at=now,
|
||||||
|
group_id='1',
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_edge = (node1, edge1, node2)
|
|
||||||
new_edge = (node1, edge2, node2)
|
|
||||||
|
|
||||||
# Create current episode
|
# Create current episode
|
||||||
current_episode = EpisodicNode(
|
current_episode = EpisodicNode(
|
||||||
name='Current Episode',
|
name='Current Episode',
|
||||||
|
|
@ -97,46 +94,40 @@ def create_test_data():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges():
|
async def test_get_edge_contradictions():
|
||||||
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
|
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [existing_edge])
|
||||||
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(invalidated_edges) == 1
|
assert len(invalidated_edges) == 1
|
||||||
assert invalidated_edges[0].uuid == existing_edge[1].uuid
|
assert invalidated_edges[0].uuid == existing_edge.uuid
|
||||||
assert invalidated_edges[0].expired_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges_no_invalidation():
|
async def test_get_edge_contradictions_no_contradictions():
|
||||||
existing_edge, _, current_episode, previous_episodes = create_test_data()
|
_, new_edge, current_episode, previous_episodes = create_test_data()
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [])
|
||||||
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(invalidated_edges) == 0
|
assert len(invalidated_edges) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges_multiple_existing():
|
async def test_get_edge_contradictions_multiple_existing():
|
||||||
existing_edge1, new_edge = create_test_data()
|
existing_edge1, new_edge, _, _ = create_test_data()
|
||||||
existing_edge2, _ = create_test_data()
|
existing_edge2, _, _, _ = create_test_data()
|
||||||
existing_edge2[1].uuid = 'e3'
|
existing_edge2.uuid = 'e3'
|
||||||
existing_edge2[1].name = 'KNOWS'
|
existing_edge2.name = 'KNOWS'
|
||||||
existing_edge2[1].fact = 'Alice knows Bob'
|
existing_edge2.fact = 'Alice knows Bob'
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(
|
invalidated_edges = await get_edge_contradictions(
|
||||||
setup_llm_client(), [existing_edge1, existing_edge2], [new_edge]
|
setup_llm_client(), new_edge, [existing_edge1, existing_edge2]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(invalidated_edges) == 1
|
assert len(invalidated_edges) == 1
|
||||||
assert invalidated_edges[0].uuid == existing_edge1[1].uuid
|
assert invalidated_edges[0].uuid == existing_edge1.uuid
|
||||||
assert invalidated_edges[0].expired_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
# Helper function to create more complex test data
|
# Helper function to create more complex test data
|
||||||
|
|
@ -152,7 +143,7 @@ def create_complex_test_data():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create edges
|
# Create edges
|
||||||
edge1 = EntityEdge(
|
existing_edge1 = EntityEdge(
|
||||||
uuid='e1',
|
uuid='e1',
|
||||||
source_node_uuid='1',
|
source_node_uuid='1',
|
||||||
target_node_uuid='2',
|
target_node_uuid='2',
|
||||||
|
|
@ -161,7 +152,7 @@ def create_complex_test_data():
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=now - timedelta(days=5),
|
created_at=now - timedelta(days=5),
|
||||||
)
|
)
|
||||||
edge2 = EntityEdge(
|
existing_edge2 = EntityEdge(
|
||||||
uuid='e2',
|
uuid='e2',
|
||||||
source_node_uuid='1',
|
source_node_uuid='1',
|
||||||
target_node_uuid='3',
|
target_node_uuid='3',
|
||||||
|
|
@ -170,7 +161,7 @@ def create_complex_test_data():
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=now - timedelta(days=3),
|
created_at=now - timedelta(days=3),
|
||||||
)
|
)
|
||||||
edge3 = EntityEdge(
|
existing_edge3 = EntityEdge(
|
||||||
uuid='e3',
|
uuid='e3',
|
||||||
source_node_uuid='2',
|
source_node_uuid='2',
|
||||||
target_node_uuid='4',
|
target_node_uuid='4',
|
||||||
|
|
@ -180,10 +171,6 @@ def create_complex_test_data():
|
||||||
created_at=now - timedelta(days=2),
|
created_at=now - timedelta(days=2),
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_edge1 = (node1, edge1, node2)
|
|
||||||
existing_edge2 = (node1, edge2, node3)
|
|
||||||
existing_edge3 = (node2, edge3, node4)
|
|
||||||
|
|
||||||
return [existing_edge1, existing_edge2, existing_edge3], [
|
return [existing_edge1, existing_edge2, existing_edge3], [
|
||||||
node1,
|
node1,
|
||||||
node2,
|
node2,
|
||||||
|
|
@ -198,118 +185,61 @@ async def test_invalidate_edges_complex():
|
||||||
existing_edges, nodes = create_complex_test_data()
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
# Create a new edge that contradicts an existing one
|
# Create a new edge that contradicts an existing one
|
||||||
new_edge = (
|
new_edge = EntityEdge(
|
||||||
nodes[0],
|
uuid='e4',
|
||||||
EntityEdge(
|
source_node_uuid='1',
|
||||||
uuid='e4',
|
target_node_uuid='2',
|
||||||
source_node_uuid='1',
|
name='DISLIKES',
|
||||||
target_node_uuid='2',
|
fact='Alice dislikes Bob',
|
||||||
name='DISLIKES',
|
group_id='1',
|
||||||
fact='Alice dislikes Bob',
|
created_at=datetime.now(),
|
||||||
group_id='1',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
),
|
|
||||||
nodes[1],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
||||||
assert len(invalidated_edges) == 1
|
assert len(invalidated_edges) == 1
|
||||||
assert invalidated_edges[0].uuid == 'e1'
|
assert invalidated_edges[0].uuid == 'e1'
|
||||||
assert invalidated_edges[0].expired_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges_temporal_update():
|
async def test_get_edge_contradictions_temporal_update():
|
||||||
existing_edges, nodes = create_complex_test_data()
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
# Create a new edge that updates an existing one with new information
|
# Create a new edge that updates an existing one with new information
|
||||||
new_edge = (
|
new_edge = EntityEdge(
|
||||||
nodes[1],
|
uuid='e5',
|
||||||
EntityEdge(
|
source_node_uuid='2',
|
||||||
uuid='e5',
|
target_node_uuid='4',
|
||||||
source_node_uuid='2',
|
name='LEFT_JOB',
|
||||||
target_node_uuid='4',
|
fact='Bob no longer works at at Company XYZ',
|
||||||
name='LEFT_JOB',
|
group_id='1',
|
||||||
fact='Bob left his job at Company XYZ',
|
created_at=datetime.now(),
|
||||||
group_id='1',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
),
|
|
||||||
nodes[3],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
||||||
assert len(invalidated_edges) == 1
|
assert len(invalidated_edges) == 1
|
||||||
assert invalidated_edges[0].uuid == 'e3'
|
assert invalidated_edges[0].uuid == 'e3'
|
||||||
assert invalidated_edges[0].expired_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges_multiple_invalidations():
|
async def test_get_edge_contradictions_no_effect():
|
||||||
existing_edges, nodes = create_complex_test_data()
|
|
||||||
|
|
||||||
# Create new edges that invalidate multiple existing edges
|
|
||||||
new_edge1 = (
|
|
||||||
nodes[0],
|
|
||||||
EntityEdge(
|
|
||||||
uuid='e6',
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='2',
|
|
||||||
name='ENEMIES_WITH',
|
|
||||||
fact='Alice and Bob are now enemies',
|
|
||||||
group_id='1',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
),
|
|
||||||
nodes[1],
|
|
||||||
)
|
|
||||||
new_edge2 = (
|
|
||||||
nodes[0],
|
|
||||||
EntityEdge(
|
|
||||||
uuid='e7',
|
|
||||||
source_node_uuid='1',
|
|
||||||
target_node_uuid='3',
|
|
||||||
name='ENDED_FRIENDSHIP',
|
|
||||||
fact='Alice ended her friendship with Charlie',
|
|
||||||
group_id='1',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
),
|
|
||||||
nodes[2],
|
|
||||||
)
|
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(
|
|
||||||
setup_llm_client(), existing_edges, [new_edge1, new_edge2]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(invalidated_edges) == 2
|
|
||||||
assert set(edge.uuid for edge in invalidated_edges) == {'e1', 'e2'}
|
|
||||||
for edge in invalidated_edges:
|
|
||||||
assert edge.expired_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.integration
|
|
||||||
async def test_invalidate_edges_no_effect():
|
|
||||||
existing_edges, nodes = create_complex_test_data()
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
# Create a new edge that doesn't invalidate any existing edges
|
# Create a new edge that doesn't invalidate any existing edges
|
||||||
new_edge = (
|
new_edge = EntityEdge(
|
||||||
nodes[2],
|
uuid='e8',
|
||||||
EntityEdge(
|
source_node_uuid='3',
|
||||||
uuid='e8',
|
target_node_uuid='4',
|
||||||
source_node_uuid='3',
|
name='APPLIED_TO',
|
||||||
target_node_uuid='4',
|
fact='Charlie applied to Company XYZ',
|
||||||
name='APPLIED_TO',
|
group_id='1',
|
||||||
fact='Charlie applied to Company XYZ',
|
created_at=datetime.now(),
|
||||||
group_id='1',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
),
|
|
||||||
nodes[3],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
||||||
assert len(invalidated_edges) == 0
|
assert len(invalidated_edges) == 0
|
||||||
|
|
||||||
|
|
@ -320,28 +250,82 @@ async def test_invalidate_edges_partial_update():
|
||||||
existing_edges, nodes = create_complex_test_data()
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
# Create a new edge that partially updates an existing one
|
# Create a new edge that partially updates an existing one
|
||||||
new_edge = (
|
new_edge = EntityEdge(
|
||||||
nodes[1],
|
uuid='e9',
|
||||||
EntityEdge(
|
source_node_uuid='2',
|
||||||
uuid='e9',
|
target_node_uuid='4',
|
||||||
source_node_uuid='2',
|
name='CHANGED_POSITION',
|
||||||
target_node_uuid='4',
|
fact='Bob changed his position at Company XYZ',
|
||||||
name='CHANGED_POSITION',
|
group_id='1',
|
||||||
fact='Bob changed his position at Company XYZ',
|
created_at=datetime.now(),
|
||||||
group_id='1',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
),
|
|
||||||
nodes[3],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge])
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
||||||
assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated
|
assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated
|
||||||
|
|
||||||
|
|
||||||
|
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
previous_episodes = [
|
||||||
|
EpisodicNode(
|
||||||
|
name='Previous Episode 1',
|
||||||
|
content='Bob: I work at XYZ company',
|
||||||
|
created_at=now - timedelta(days=2),
|
||||||
|
valid_at=now - timedelta(days=2),
|
||||||
|
source=EpisodeType.message,
|
||||||
|
source_description='Test previous episode for unit testing',
|
||||||
|
group_id='1',
|
||||||
|
),
|
||||||
|
EpisodicNode(
|
||||||
|
name='Previous Episode 2',
|
||||||
|
content="Alice: That's really cool!",
|
||||||
|
created_at=now - timedelta(days=1),
|
||||||
|
valid_at=now - timedelta(days=1),
|
||||||
|
source=EpisodeType.message,
|
||||||
|
source_description='Test previous episode for unit testing',
|
||||||
|
group_id='1',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
episode = EpisodicNode(
|
||||||
|
name='Previous Episode',
|
||||||
|
content='Bob: It was cool, but I no longer work at company XYZ',
|
||||||
|
created_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
source=EpisodeType.message,
|
||||||
|
source_description='Test previous episode for unit testing',
|
||||||
|
group_id='1',
|
||||||
|
)
|
||||||
|
|
||||||
|
return episode, previous_episodes
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
async def test_invalidate_edges_empty_inputs():
|
async def test_extract_edge_dates():
|
||||||
invalidated_edges = await invalidate_edges(setup_llm_client(), [], [])
|
episode, previous_episodes = create_data_for_temporal_extraction()
|
||||||
|
|
||||||
assert len(invalidated_edges) == 0
|
# Create a new edge that partially updates an existing one
|
||||||
|
new_edge = EntityEdge(
|
||||||
|
uuid='e9',
|
||||||
|
source_node_uuid='2',
|
||||||
|
target_node_uuid='4',
|
||||||
|
name='LEFT_JOB',
|
||||||
|
fact='Bob no longer works at Company XYZ',
|
||||||
|
group_id='1',
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_at, invalid_at = await extract_edge_dates(
|
||||||
|
setup_llm_client(), new_edge, episode, previous_episodes
|
||||||
|
)
|
||||||
|
|
||||||
|
assert valid_at == episode.valid_at
|
||||||
|
assert invalid_at is None
|
||||||
|
|
||||||
|
|
||||||
|
# Run the tests
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue