From 445dccc021166a4baac8bca13ea227d6e9f39011 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:36:04 -0800 Subject: [PATCH] refactor: use `utc_now()` for consistent UTC datetime handling (#234) * ensure utc timezones * fix: dep cycle --------- Co-authored-by: paulpaliychuk --- graphiti_core/graphiti.py | 13 +++--- graphiti_core/nodes.py | 5 ++- graphiti_core/utils/__init__.py | 15 ------- graphiti_core/utils/bulk_utils.py | 11 +++-- graphiti_core/utils/datetime_utils.py | 42 +++++++++++++++++++ .../utils/maintenance/community_operations.py | 6 +-- .../utils/maintenance/edge_operations.py | 26 +++++++----- .../utils/maintenance/node_operations.py | 4 +- .../utils/maintenance/temporal_operations.py | 9 ++-- server/graph_service/dto/common.py | 5 +-- server/graph_service/routers/ingest.py | 2 +- .../test_temporal_operations_int.py | 19 +++++---- 12 files changed, 97 insertions(+), 60 deletions(-) create mode 100644 graphiti_core/utils/datetime_utils.py diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 0f428f51..574d6189 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -16,7 +16,7 @@ limitations under the License. import asyncio import logging -from datetime import datetime, timezone +from datetime import datetime from time import time from dotenv import load_dotenv @@ -43,10 +43,6 @@ from graphiti_core.search.search_utils import ( get_relevant_edges, get_relevant_nodes, ) -from graphiti_core.utils import ( - build_episodic_edges, - retrieve_episodes, -) from graphiti_core.utils.bulk_utils import ( RawEpisode, add_nodes_and_edges_bulk, @@ -57,12 +53,14 @@ from graphiti_core.utils.bulk_utils import ( resolve_edge_pointers, retrieve_previous_episodes_bulk, ) +from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.community_operations import ( build_communities, remove_communities, update_community, ) from graphiti_core.utils.maintenance.edge_operations import ( + build_episodic_edges, dedupe_extracted_edge, extract_edges, resolve_edge_contradictions, @@ -71,6 +69,7 @@ from graphiti_core.utils.maintenance.edge_operations import ( from graphiti_core.utils.maintenance.graph_data_operations import ( EPISODE_WINDOW_LEN, build_indices_and_constraints, + retrieve_episodes, ) from graphiti_core.utils.maintenance.node_operations import ( extract_nodes, @@ -313,7 +312,7 @@ class Graphiti: start = time() entity_edges: list[EntityEdge] = [] - now = datetime.now(timezone.utc) + now = utc_now() previous_episodes = await self.retrieve_episodes( reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id] @@ -522,7 +521,7 @@ class Graphiti: """ try: start = time() - now = datetime.now(timezone.utc) + now = utc_now() episodes = [ EpisodicNode( diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index a4dece28..3d635060 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -16,7 +16,7 @@ limitations under the License. import logging from abc import ABC, abstractmethod -from datetime import datetime, timezone +from datetime import datetime from enum import Enum from time import time from typing import Any @@ -34,6 +34,7 @@ from graphiti_core.models.nodes.node_db_queries import ( ENTITY_NODE_SAVE, EPISODIC_NODE_SAVE, ) +from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) @@ -79,7 +80,7 @@ class Node(BaseModel, ABC): name: str = Field(description='name of the node') group_id: str = Field(description='partition of the graph') labels: list[str] = Field(default_factory=list) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: utc_now()) @abstractmethod async def save(self, driver: AsyncDriver): ... diff --git a/graphiti_core/utils/__init__.py b/graphiti_core/utils/__init__.py index 54642349..e69de29b 100644 --- a/graphiti_core/utils/__init__.py +++ b/graphiti_core/utils/__init__.py @@ -1,15 +0,0 @@ -from .maintenance import ( - build_episodic_edges, - clear_data, - extract_edges, - extract_nodes, - retrieve_episodes, -) - -__all__ = [ - 'extract_edges', - 'build_episodic_edges', - 'extract_nodes', - 'clear_data', - 'retrieve_episodes', -] diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 7b4ebf02..5deb224d 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -18,7 +18,7 @@ import asyncio import logging import typing from collections import defaultdict -from datetime import datetime, timezone +from datetime import datetime from math import ceil from neo4j import AsyncDriver, AsyncManagedTransaction @@ -37,14 +37,17 @@ from graphiti_core.models.nodes.node_db_queries import ( ) from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes -from graphiti_core.utils import retrieve_episodes +from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.edge_operations import ( build_episodic_edges, dedupe_edge_list, dedupe_extracted_edges, extract_edges, ) -from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN +from graphiti_core.utils.maintenance.graph_data_operations import ( + EPISODE_WINDOW_LEN, + retrieve_episodes, +) from graphiti_core.utils.maintenance.node_operations import ( dedupe_extracted_nodes, dedupe_node_list, @@ -385,7 +388,7 @@ async def extract_edge_dates_bulk( edge.valid_at = valid_at edge.invalid_at = invalid_at if edge.invalid_at: - edge.expired_at = datetime.now(timezone.utc) + edge.expired_at = utc_now() return edges diff --git a/graphiti_core/utils/datetime_utils.py b/graphiti_core/utils/datetime_utils.py new file mode 100644 index 00000000..71550108 --- /dev/null +++ b/graphiti_core/utils/datetime_utils.py @@ -0,0 +1,42 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from datetime import datetime, timezone + + +def utc_now() -> datetime: + """Returns the current UTC datetime with timezone information.""" + return datetime.now(timezone.utc) + + +def ensure_utc(dt: datetime | None) -> datetime | None: + """ + Ensures a datetime is timezone-aware and in UTC. + If the datetime is naive (no timezone), assumes it's in UTC. + If the datetime has a different timezone, converts it to UTC. + Returns None if input is None. + """ + if dt is None: + return None + + if dt.tzinfo is None: + # If datetime is naive, assume it's UTC + return dt.replace(tzinfo=timezone.utc) + elif dt.tzinfo != timezone.utc: + # If datetime has a different timezone, convert to UTC + return dt.astimezone(timezone.utc) + + return dt diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index fc71f707..a585c149 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -1,7 +1,6 @@ import asyncio import logging from collections import defaultdict -from datetime import datetime, timezone from neo4j import AsyncDriver from pydantic import BaseModel @@ -17,6 +16,7 @@ from graphiti_core.nodes import ( ) from graphiti_core.prompts import prompt_library from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription +from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.edge_operations import build_community_edges MAX_COMMUNITY_BUILD_CONCURRENCY = 10 @@ -180,7 +180,7 @@ async def build_community( summary = summaries[0] name = await generate_summary_description(llm_client, summary) - now = datetime.now(timezone.utc) + now = utc_now() community_node = CommunityNode( name=name, group_id=community_cluster[0].group_id, @@ -307,7 +307,7 @@ async def update_community( community.name = new_name if is_new: - community_edge = (build_community_edges([entity], community, datetime.now(timezone.utc)))[0] + community_edge = (build_community_edges([entity], community, utc_now()))[0] await community_edge.save(driver) await community.generate_name_embedding(embedder) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 1279cf14..e3fa4f7a 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -16,7 +16,7 @@ limitations under the License. import asyncio import logging -from datetime import datetime, timezone +from datetime import datetime from time import time from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge @@ -26,6 +26,7 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts +from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.temporal_operations import ( extract_edge_dates, get_edge_contradictions, @@ -132,7 +133,7 @@ async def extract_edges( group_id=group_id, fact=edge_data.get('fact', ''), episodes=[episode.uuid], - created_at=datetime.now(timezone.utc), + created_at=utc_now(), valid_at=None, invalid_at=None, ) @@ -251,9 +252,7 @@ def resolve_edge_contradictions( and edge.valid_at < resolved_edge.valid_at ): edge.invalid_at = resolved_edge.valid_at - edge.expired_at = ( - edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc) - ) + edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now() invalidated_edges.append(edge) return invalidated_edges @@ -273,11 +272,12 @@ async def resolve_extracted_edge( get_edge_contradictions(llm_client, extracted_edge, existing_edges), ) - now = datetime.now(timezone.utc) + now = utc_now() - resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at - resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at - if invalid_at is not None and resolved_edge.expired_at is None: + resolved_edge.valid_at = valid_at if valid_at else resolved_edge.valid_at + resolved_edge.invalid_at = invalid_at if invalid_at else resolved_edge.invalid_at + + if invalid_at and not resolved_edge.expired_at: resolved_edge.expired_at = now # Determine if the new_edge needs to be expired @@ -285,8 +285,12 @@ async def resolve_extracted_edge( invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at)) for candidate in invalidation_candidates: if ( - candidate.valid_at is not None and resolved_edge.valid_at is not None - ) and candidate.valid_at > resolved_edge.valid_at: + candidate.valid_at + and resolved_edge.valid_at + and candidate.valid_at.tzinfo + and resolved_edge.valid_at.tzinfo + and candidate.valid_at > resolved_edge.valid_at + ): # Expire new edge since we have information about more recent events resolved_edge.invalid_at = candidate.valid_at resolved_edge.expired_at = now diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 08835023..22201e22 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -16,7 +16,6 @@ limitations under the License. import asyncio import logging -from datetime import datetime, timezone from time import time from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS @@ -26,6 +25,7 @@ from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities from graphiti_core.prompts.summarize_nodes import Summary +from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) @@ -155,7 +155,7 @@ async def extract_nodes( group_id=episode.group_id, labels=['Entity'], summary='', - created_at=datetime.now(timezone.utc), + created_at=utc_now(), ) new_nodes.append(new_node) logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') diff --git a/graphiti_core/utils/maintenance/temporal_operations.py b/graphiti_core/utils/maintenance/temporal_operations.py index 6028ecfb..12cdf0e9 100644 --- a/graphiti_core/utils/maintenance/temporal_operations.py +++ b/graphiti_core/utils/maintenance/temporal_operations.py @@ -9,7 +9,7 @@ You may obtain a copy of the License at Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ @@ -24,6 +24,7 @@ from graphiti_core.nodes import EpisodicNode from graphiti_core.prompts import prompt_library from graphiti_core.prompts.extract_edge_dates import EdgeDates from graphiti_core.prompts.invalidate_edges import InvalidatedEdges +from graphiti_core.utils.datetime_utils import ensure_utc logger = logging.getLogger(__name__) @@ -52,13 +53,15 @@ async def extract_edge_dates( if valid_at: try: - valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00')) + valid_at_datetime = ensure_utc(datetime.fromisoformat(valid_at.replace('Z', '+00:00'))) except ValueError as e: logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}') if invalid_at: try: - invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00')) + invalid_at_datetime = ensure_utc( + datetime.fromisoformat(invalid_at.replace('Z', '+00:00')) + ) except ValueError as e: logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}') diff --git a/server/graph_service/dto/common.py b/server/graph_service/dto/common.py index 9d9e4b76..5103470e 100644 --- a/server/graph_service/dto/common.py +++ b/server/graph_service/dto/common.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import Literal +from graphiti_core.utils.datetime_utils import utc_now from pydantic import BaseModel, Field @@ -21,9 +22,7 @@ class Message(BaseModel): role: str | None = Field( description='The custom role of the message to be used alongside role_type (user name, bot name, etc.)', ) - timestamp: datetime = Field( - default_factory=datetime.now, description='The timestamp of the message' - ) + timestamp: datetime = Field(default_factory=utc_now, description='The timestamp of the message') source_description: str = Field( default='', description='The description of the source of the message' ) diff --git a/server/graph_service/routers/ingest.py b/server/graph_service/routers/ingest.py index 8337c796..bd03a9f7 100644 --- a/server/graph_service/routers/ingest.py +++ b/server/graph_service/routers/ingest.py @@ -4,7 +4,7 @@ from functools import partial from fastapi import APIRouter, FastAPI, status from graphiti_core.nodes import EpisodeType # type: ignore -from graphiti_core.utils import clear_data # type: ignore +from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result from graph_service.zep_graphiti import ZepGraphitiDep diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index 9f34e4a6..7bb30ba0 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -15,7 +15,7 @@ limitations under the License. """ import os -from datetime import datetime, timedelta, timezone +from datetime import timedelta import pytest from dotenv import load_dotenv @@ -23,6 +23,7 @@ from dotenv import load_dotenv from graphiti_core.edges import EntityEdge from graphiti_core.llm_client import LLMConfig, OpenAIClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode +from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.temporal_operations import ( extract_edge_dates, get_edge_contradictions, @@ -42,7 +43,7 @@ def setup_llm_client(): def create_test_data(): - now = datetime.now(timezone.utc) + now = utc_now() # Create edges existing_edge = EntityEdge( @@ -131,7 +132,7 @@ async def test_get_edge_contradictions_multiple_existing(): # Helper function to create more complex test data def create_complex_test_data(): - now = datetime.now(timezone.utc) + now = utc_now() # Create nodes node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1') @@ -191,7 +192,7 @@ async def test_invalidate_edges_complex(): name='DISLIKES', fact='Alice dislikes Bob', group_id='1', - created_at=datetime.now(timezone.utc), + created_at=utc_now(), ) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) @@ -213,7 +214,7 @@ async def test_get_edge_contradictions_temporal_update(): name='LEFT_JOB', fact='Bob no longer works at at Company XYZ', group_id='1', - created_at=datetime.now(timezone.utc), + created_at=utc_now(), ) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) @@ -235,7 +236,7 @@ async def test_get_edge_contradictions_no_effect(): name='APPLIED_TO', fact='Charlie applied to Company XYZ', group_id='1', - created_at=datetime.now(timezone.utc), + created_at=utc_now(), ) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) @@ -256,7 +257,7 @@ async def test_invalidate_edges_partial_update(): name='CHANGED_POSITION', fact='Bob changed his position at Company XYZ', group_id='1', - created_at=datetime.now(timezone.utc), + created_at=utc_now(), ) invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges) @@ -265,7 +266,7 @@ async def test_invalidate_edges_partial_update(): def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]: - now = datetime.now(timezone.utc) + now = utc_now() previous_episodes = [ EpisodicNode( @@ -314,7 +315,7 @@ async def test_extract_edge_dates(): name='LEFT_JOB', fact='Bob no longer works at Company XYZ', group_id='1', - created_at=datetime.now(timezone.utc), + created_at=utc_now(), ) valid_at, invalid_at = await extract_edge_dates(