add_fact endpoint (#207)

* add_fact endpoint

* bump version

* add edge invalidation

* update
This commit is contained in:
Preston Rasmussen 2024-11-06 09:12:21 -05:00 committed by GitHub
parent 6536401c8c
commit 3199e893ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 196 additions and 87 deletions

View file

@ -19,7 +19,7 @@ import json
import logging import logging
import os import os
import sys import sys
from datetime import datetime from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
@ -78,7 +78,7 @@ async def add_messages(client: Graphiti):
name=f'Message {i}', name=f'Message {i}',
episode_body=message, episode_body=message,
source=EpisodeType.message, source=EpisodeType.message,
reference_time=datetime.now(), reference_time=datetime.now(timezone.utc),
source_description='Shoe conversation', source_description='Shoe conversation',
) )
@ -105,7 +105,7 @@ async def ingest_products_data(client: Graphiti):
content=str(product), content=str(product),
source_description='Allbirds products', source_description='Allbirds products',
source=EpisodeType.json, source=EpisodeType.json,
reference_time=datetime.now(), reference_time=datetime.now(timezone.utc),
) )
for i, product in enumerate(products) for i, product in enumerate(products)
] ]

View file

@ -15,7 +15,7 @@ limitations under the License.
""" """
import json import json
from datetime import datetime from datetime import datetime, timezone
from pydantic import BaseModel from pydantic import BaseModel
@ -45,7 +45,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
ParsedMscMessage( ParsedMscMessage(
speaker_name=speakers[speaker_idx], speaker_name=speakers[speaker_idx],
content=content, content=content,
actual_timestamp=datetime.now(), actual_timestamp=datetime.now(timezone.utc),
group_id=str(i), group_id=str(i),
) )
) )
@ -60,7 +60,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
ParsedMscMessage( ParsedMscMessage(
speaker_name=speakers[speaker_idx], speaker_name=speakers[speaker_idx],
content=content, content=content,
actual_timestamp=datetime.now(), actual_timestamp=datetime.now(timezone.utc),
group_id=str(i), group_id=str(i),
) )
) )

View file

@ -1,6 +1,6 @@
import os import os
import re import re
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import List from typing import List
from pydantic import BaseModel from pydantic import BaseModel
@ -61,7 +61,7 @@ def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[Par
break break
# Calculate the start time # Calculate the start time
now = datetime.now() now = datetime.now(timezone.utc)
podcast_start_time = now - last_timestamp podcast_start_time = now - last_timestamp
for message in messages: for message in messages:

View file

@ -18,7 +18,7 @@ import asyncio
import logging import logging
import os import os
import sys import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from dotenv import load_dotenv from dotenv import load_dotenv
@ -63,7 +63,7 @@ async def main():
messages = get_wizard_of_oz_messages() messages = get_wizard_of_oz_messages()
print(messages) print(messages)
print(len(messages)) print(len(messages))
now = datetime.now() now = datetime.now(timezone.utc)
# episodes: list[BulkEpisode] = [ # episodes: list[BulkEpisode] = [
# BulkEpisode( # BulkEpisode(
# name=f'Chapter {i + 1}', # name=f'Chapter {i + 1}',

View file

@ -16,7 +16,7 @@ limitations under the License.
import asyncio import asyncio
import logging import logging
from datetime import datetime from datetime import datetime, timezone
from time import time from time import time
from dotenv import load_dotenv from dotenv import load_dotenv
@ -65,7 +65,9 @@ from graphiti_core.utils.maintenance.community_operations import (
update_community, update_community,
) )
from graphiti_core.utils.maintenance.edge_operations import ( from graphiti_core.utils.maintenance.edge_operations import (
dedupe_extracted_edge,
extract_edges, extract_edges,
resolve_edge_contradictions,
resolve_extracted_edges, resolve_extracted_edges,
) )
from graphiti_core.utils.maintenance.graph_data_operations import ( from graphiti_core.utils.maintenance.graph_data_operations import (
@ -76,6 +78,7 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes, extract_nodes,
resolve_extracted_nodes, resolve_extracted_nodes,
) )
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -312,7 +315,7 @@ class Graphiti:
start = time() start = time()
entity_edges: list[EntityEdge] = [] entity_edges: list[EntityEdge] = []
now = datetime.now() now = datetime.now(timezone.utc)
previous_episodes = await self.retrieve_episodes( previous_episodes = await self.retrieve_episodes(
reference_time, last_n=3, group_ids=[group_id] reference_time, last_n=3, group_ids=[group_id]
@ -448,7 +451,6 @@ class Graphiti:
episode.entity_edges = [edge.uuid for edge in entity_edges] episode.entity_edges = [edge.uuid for edge in entity_edges]
# Future optimization would be using batch operations to save nodes and edges
if not self.store_raw_episode_content: if not self.store_raw_episode_content:
episode.content = '' episode.content = ''
@ -511,7 +513,7 @@ class Graphiti:
""" """
try: try:
start = time() start = time()
now = datetime.now() now = datetime.now(timezone.utc)
episodes = [ episodes = [
EpisodicNode( EpisodicNode(
@ -760,3 +762,36 @@ class Graphiti:
communities = await get_communities_by_nodes(self.driver, nodes) communities = await get_communities_by_nodes(self.driver, nodes)
return SearchResults(edges=edges, nodes=nodes, communities=communities) return SearchResults(edges=edges, nodes=nodes, communities=communities)
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
    """Add a single (source node, edge, target node) fact to the graph.

    Missing embeddings are generated for both nodes and the edge. The nodes
    are then resolved against relevant nodes already in the graph, the edge
    is deduplicated against relevant existing edges between the resolved
    endpoints, contradicted edges are invalidated, and everything is saved
    in one bulk write.

    Args:
        source_node: Node the edge starts from. Mutated in place if its
            ``name_embedding`` is missing.
        edge: The fact connecting the two nodes. Mutated in place: its
            ``fact_embedding`` is generated if missing, and its endpoint
            uuids are re-pointed at the resolved nodes.
        target_node: Node the edge points to. Mutated in place if its
            ``name_embedding`` is missing.
    """
    if source_node.name_embedding is None:
        await source_node.generate_name_embedding(self.embedder)
    if target_node.name_embedding is None:
        await target_node.generate_name_embedding(self.embedder)
    if edge.fact_embedding is None:
        await edge.generate_embedding(self.embedder)

    # Candidate lists are passed in the same order as the nodes themselves.
    resolved_nodes, _ = await resolve_extracted_nodes(
        self.llm_client,
        [source_node, target_node],
        [
            await get_relevant_nodes([source_node], self.driver),
            await get_relevant_nodes([target_node], self.driver),
        ],
    )
    # The rest of this method relies on index 0 being the source and index 1
    # being the target.
    resolved_source, resolved_target = resolved_nodes[0], resolved_nodes[1]

    # Node resolution may have swapped an input node for a different node
    # (with a different uuid). Re-point the edge so it is deduplicated and
    # persisted against the nodes that are actually saved below, instead of
    # keeping endpoints that are not part of `resolved_nodes`. Only endpoints
    # that referenced the input nodes are remapped.
    if edge.source_node_uuid == source_node.uuid:
        edge.source_node_uuid = resolved_source.uuid
    if edge.target_node_uuid == target_node.uuid:
        edge.target_node_uuid = resolved_target.uuid

    related_edges = await get_relevant_edges(
        self.driver,
        [edge],
        source_node_uuid=resolved_source.uuid,
        target_node_uuid=resolved_target.uuid,
    )

    resolved_edge = await dedupe_extracted_edge(self.llm_client, edge, related_edges)

    # Contradictions are searched among the same related edges, using the
    # edge as supplied by the caller (not the deduplicated one).
    contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
    invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)

    # The two empty lists are the leading list arguments of the bulk helper,
    # left empty here; presumably episodic nodes/edges -- TODO confirm
    # against add_nodes_and_edges_bulk's signature.
    await add_nodes_and_edges_bulk(
        self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
    )

View file

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
EPISODIC_EDGE_SAVE = """ EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid}) MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid}) MATCH (node:Entity {uuid: $entity_uuid})

View file

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
EPISODIC_NODE_SAVE = """ EPISODIC_NODE_SAVE = """
MERGE (n:Episodic {uuid: $uuid}) MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,

View file

@ -16,7 +16,7 @@ limitations under the License.
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime, timezone
from enum import Enum from enum import Enum
from time import time from time import time
from typing import Any from typing import Any
@ -78,7 +78,7 @@ class Node(BaseModel, ABC):
name: str = Field(description='name of the node') name: str = Field(description='name of the node')
group_id: str = Field(description='partition of the graph') group_id: str = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list) labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now()) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@abstractmethod @abstractmethod
async def save(self, driver: AsyncDriver): ... async def save(self, driver: AsyncDriver): ...

View file

@ -18,7 +18,7 @@ import asyncio
import logging import logging
import typing import typing
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime, timezone
from math import ceil from math import ceil
from neo4j import AsyncDriver, AsyncManagedTransaction from neo4j import AsyncDriver, AsyncManagedTransaction
@ -385,7 +385,7 @@ async def extract_edge_dates_bulk(
edge.valid_at = valid_at edge.valid_at = valid_at
edge.invalid_at = invalid_at edge.invalid_at = invalid_at
if edge.invalid_at: if edge.invalid_at:
edge.expired_at = datetime.now() edge.expired_at = datetime.now(timezone.utc)
return edges return edges

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime, timezone
from neo4j import AsyncDriver from neo4j import AsyncDriver
from pydantic import BaseModel from pydantic import BaseModel
@ -178,7 +178,7 @@ async def build_community(
summary = summaries[0] summary = summaries[0]
name = await generate_summary_description(llm_client, summary) name = await generate_summary_description(llm_client, summary)
now = datetime.now() now = datetime.now(timezone.utc)
community_node = CommunityNode( community_node = CommunityNode(
name=name, name=name,
group_id=community_cluster[0].group_id, group_id=community_cluster[0].group_id,

View file

@ -16,7 +16,7 @@ limitations under the License.
import asyncio import asyncio
import logging import logging
from datetime import datetime from datetime import datetime, timezone
from time import time from time import time
from typing import List from typing import List
@ -110,7 +110,7 @@ async def extract_edges(
group_id=group_id, group_id=group_id,
fact=edge_data['fact'], fact=edge_data['fact'],
episodes=[episode.uuid], episodes=[episode.uuid],
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
valid_at=None, valid_at=None,
invalid_at=None, invalid_at=None,
) )
@ -205,39 +205,9 @@ async def resolve_extracted_edges(
return resolved_edges, invalidated_edges return resolved_edges, invalidated_edges
async def resolve_extracted_edge( def resolve_edge_contradictions(
llm_client: LLMClient, resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
extracted_edge: EntityEdge, ) -> list[EntityEdge]:
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
current_episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
)
now = datetime.now()
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
if invalid_at is not None and resolved_edge.expired_at is None:
resolved_edge.expired_at = now
# Determine if the new_edge needs to be expired
if resolved_edge.expired_at is None:
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
for candidate in invalidation_candidates:
if (
candidate.valid_at is not None and resolved_edge.valid_at is not None
) and candidate.valid_at > resolved_edge.valid_at:
# Expire new edge since we have information about more recent events
resolved_edge.invalid_at = candidate.valid_at
resolved_edge.expired_at = now
break
# Determine which contradictory edges need to be expired # Determine which contradictory edges need to be expired
invalidated_edges: list[EntityEdge] = [] invalidated_edges: list[EntityEdge] = []
for edge in invalidation_candidates: for edge in invalidation_candidates:
@ -259,9 +229,50 @@ async def resolve_extracted_edge(
and edge.valid_at < resolved_edge.valid_at and edge.valid_at < resolved_edge.valid_at
): ):
edge.invalid_at = resolved_edge.valid_at edge.invalid_at = resolved_edge.valid_at
edge.expired_at = edge.expired_at if edge.expired_at is not None else now edge.expired_at = (
edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc)
)
invalidated_edges.append(edge) invalidated_edges.append(edge)
return invalidated_edges
async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
    """Resolve one extracted edge and work out which edges it invalidates.

    Deduplication, date extraction and contradiction lookup are awaited
    concurrently. The resolved edge then receives any extracted dates and is
    expired if it carries an invalidation date or if a contradicting edge
    became valid after it.

    Returns:
        A ``(resolved_edge, invalidated_edges)`` tuple, where
        ``invalidated_edges`` is whatever ``resolve_edge_contradictions``
        returns for the contradiction candidates.
    """
    resolved_edge, edge_dates, invalidation_candidates = await asyncio.gather(
        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
    )
    valid_at, invalid_at = edge_dates

    now = datetime.now(timezone.utc)

    # Extracted dates win over whatever the resolved edge already carries;
    # a missing extracted date leaves the current value in place.
    resolved_edge.valid_at = resolved_edge.valid_at if valid_at is None else valid_at
    resolved_edge.invalid_at = resolved_edge.invalid_at if invalid_at is None else invalid_at
    if invalid_at is not None and resolved_edge.expired_at is None:
        resolved_edge.expired_at = now

    # Determine if the new edge needs to be expired
    if resolved_edge.expired_at is None:
        # Sorted in place: dated candidates first (ascending), undated last.
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
        newer_candidate = next(
            (
                candidate
                for candidate in invalidation_candidates
                if candidate.valid_at is not None
                and resolved_edge.valid_at is not None
                and candidate.valid_at > resolved_edge.valid_at
            ),
            None,
        )
        if newer_candidate is not None:
            # Expire new edge since we have information about more recent events
            resolved_edge.invalid_at = newer_candidate.valid_at
            resolved_edge.expired_at = now

    # Determine which contradictory edges need to be expired
    invalidated_edges = resolve_edge_contradictions(resolved_edge, invalidation_candidates)

    return resolved_edge, invalidated_edges

View file

@ -16,7 +16,7 @@ limitations under the License.
import asyncio import asyncio
import logging import logging
from datetime import datetime from datetime import datetime, timezone
from time import time from time import time
from typing import Any from typing import Any
@ -113,7 +113,7 @@ async def extract_nodes(
group_id=episode.group_id, group_id=episode.group_id,
labels=node_data['labels'], labels=node_data['labels'],
summary=node_data['summary'], summary=node_data['summary'],
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
) )
new_nodes.append(new_node) new_nodes.append(new_node)
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "graphiti-core" name = "graphiti-core"
version = "0.3.21" version = "0.4.0"
description = "A temporal graph building library" description = "A temporal graph building library"
authors = [ authors = [
"Paul Paliychuk <paul@getzep.com>", "Paul Paliychuk <paul@getzep.com>",

View file

@ -1,4 +1,4 @@
from datetime import datetime from datetime import datetime, timezone
from fastapi import APIRouter, status from fastapi import APIRouter, status
@ -36,7 +36,7 @@ async def get_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK) @router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep): async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
episodes = await graphiti.retrieve_episodes( episodes = await graphiti.retrieve_episodes(
group_ids=[group_id], last_n=last_n, reference_time=datetime.now() group_ids=[group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
) )
return episodes return episodes

View file

@ -18,7 +18,7 @@ import asyncio
import logging import logging
import os import os
import sys import sys
from datetime import datetime from datetime import datetime, timezone
import pytest import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
@ -66,7 +66,39 @@ def setup_logging():
async def test_graphiti_init(): async def test_graphiti_init():
logger = setup_logging() logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None) now = datetime.now(timezone.utc)
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id='test',
)
bob_node = EntityNode(
name='Bob',
labels=[],
created_at=now,
summary='Bob summary',
group_id='test',
)
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
group_id='test',
)
await graphiti.add_triplet(alice_node, entity_edge, bob_node)
episodes = await graphiti.retrieve_episodes(datetime.now(timezone.utc), group_ids=None)
episode_uuids = [episode.uuid for episode in episodes] episode_uuids = [episode.uuid for episode in episodes]
results = await graphiti._search( results = await graphiti._search(
@ -92,7 +124,7 @@ async def test_graph_integration():
embedder = client.embedder embedder = client.embedder
driver = client.driver driver = client.driver
now = datetime.now() now = datetime.now(timezone.utc)
episode = EpisodicNode( episode = EpisodicNode(
name='test_episode', name='test_episode',
labels=[], labels=[],

View file

@ -1,4 +1,4 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@ -23,7 +23,7 @@ def mock_extracted_edge():
group_id='group_1', group_id='group_1',
fact='Test fact', fact='Test fact',
episodes=['episode_1'], episodes=['episode_1'],
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
valid_at=None, valid_at=None,
invalid_at=None, invalid_at=None,
) )
@ -39,8 +39,8 @@ def mock_related_edges():
group_id='group_1', group_id='group_1',
fact='Related fact', fact='Related fact',
episodes=['episode_2'], episodes=['episode_2'],
created_at=datetime.now() - timedelta(days=1), created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=datetime.now() - timedelta(days=1), valid_at=datetime.now(timezone.utc) - timedelta(days=1),
invalid_at=None, invalid_at=None,
) )
] ]
@ -56,8 +56,8 @@ def mock_existing_edges():
group_id='group_1', group_id='group_1',
fact='Existing fact', fact='Existing fact',
episodes=['episode_3'], episodes=['episode_3'],
created_at=datetime.now() - timedelta(days=2), created_at=datetime.now(timezone.utc) - timedelta(days=2),
valid_at=datetime.now() - timedelta(days=2), valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None, invalid_at=None,
) )
] ]
@ -68,7 +68,7 @@ def mock_current_episode():
return EpisodicNode( return EpisodicNode(
uuid='episode_1', uuid='episode_1',
content='Current episode content', content='Current episode content',
valid_at=datetime.now(), valid_at=datetime.now(timezone.utc),
name='Current Episode', name='Current Episode',
group_id='group_1', group_id='group_1',
source='message', source='message',
@ -82,7 +82,7 @@ def mock_previous_episodes():
EpisodicNode( EpisodicNode(
uuid='episode_2', uuid='episode_2',
content='Previous episode content', content='Previous episode content',
valid_at=datetime.now() - timedelta(days=1), valid_at=datetime.now(timezone.utc) - timedelta(days=1),
name='Previous Episode', name='Previous Episode',
group_id='group_1', group_id='group_1',
source='message', source='message',
@ -144,8 +144,8 @@ async def test_resolve_extracted_edge_with_dates(
mock_previous_episodes, mock_previous_episodes,
monkeypatch: MonkeyPatch, monkeypatch: MonkeyPatch,
): ):
valid_at = datetime.now() - timedelta(days=1) valid_at = datetime.now(timezone.utc) - timedelta(days=1)
invalid_at = datetime.now() + timedelta(days=1) invalid_at = datetime.now(timezone.utc) + timedelta(days=1)
# Mock the function calls # Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge) dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
@ -189,7 +189,7 @@ async def test_resolve_extracted_edge_with_invalidation(
mock_previous_episodes, mock_previous_episodes,
monkeypatch: MonkeyPatch, monkeypatch: MonkeyPatch,
): ):
valid_at = datetime.now() - timedelta(days=1) valid_at = datetime.now(timezone.utc) - timedelta(days=1)
mock_extracted_edge.valid_at = valid_at mock_extracted_edge.valid_at = valid_at
invalidation_candidate = EntityEdge( invalidation_candidate = EntityEdge(
@ -199,8 +199,8 @@ async def test_resolve_extracted_edge_with_invalidation(
group_id='group_1', group_id='group_1',
fact='Invalidation candidate fact', fact='Invalidation candidate fact',
episodes=['episode_4'], episodes=['episode_4'],
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
valid_at=datetime.now() - timedelta(days=2), valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None, invalid_at=None,
) )

View file

@ -15,11 +15,10 @@ limitations under the License.
""" """
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
import pytest import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from pytz import UTC
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMConfig, OpenAIClient from graphiti_core.llm_client import LLMConfig, OpenAIClient
@ -43,7 +42,7 @@ def setup_llm_client():
def create_test_data(): def create_test_data():
now = datetime.now() now = datetime.now(timezone.utc)
# Create edges # Create edges
existing_edge = EntityEdge( existing_edge = EntityEdge(
@ -132,7 +131,7 @@ async def test_get_edge_contradictions_multiple_existing():
# Helper function to create more complex test data # Helper function to create more complex test data
def create_complex_test_data(): def create_complex_test_data():
now = datetime.now() now = datetime.now(timezone.utc)
# Create nodes # Create nodes
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1') node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
@ -192,7 +191,7 @@ async def test_invalidate_edges_complex():
name='DISLIKES', name='DISLIKES',
fact='Alice dislikes Bob', fact='Alice dislikes Bob',
group_id='1', group_id='1',
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
) )
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -214,7 +213,7 @@ async def test_get_edge_contradictions_temporal_update():
name='LEFT_JOB', name='LEFT_JOB',
fact='Bob no longer works at at Company XYZ', fact='Bob no longer works at at Company XYZ',
group_id='1', group_id='1',
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
) )
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -236,7 +235,7 @@ async def test_get_edge_contradictions_no_effect():
name='APPLIED_TO', name='APPLIED_TO',
fact='Charlie applied to Company XYZ', fact='Charlie applied to Company XYZ',
group_id='1', group_id='1',
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
) )
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -257,7 +256,7 @@ async def test_invalidate_edges_partial_update():
name='CHANGED_POSITION', name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ', fact='Bob changed his position at Company XYZ',
group_id='1', group_id='1',
created_at=datetime.now(), created_at=datetime.now(timezone.utc),
) )
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -266,7 +265,7 @@ async def test_invalidate_edges_partial_update():
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]: def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
now = datetime.now(UTC) now = datetime.now(timezone.utc)
previous_episodes = [ previous_episodes = [
EpisodicNode( EpisodicNode(
@ -315,7 +314,7 @@ async def test_extract_edge_dates():
name='LEFT_JOB', name='LEFT_JOB',
fact='Bob no longer works at Company XYZ', fact='Bob no longer works at Company XYZ',
group_id='1', group_id='1',
created_at=datetime.now(UTC), created_at=datetime.now(timezone.utc),
) )
valid_at, invalid_at = await extract_edge_dates( valid_at, invalid_at = await extract_edge_dates(