refactor: use utc_now() for consistent UTC datetime handling (#234)
* ensure utc timezones * fix: dep cycle --------- Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
This commit is contained in:
parent
732b2f328d
commit
445dccc021
12 changed files with 97 additions and 60 deletions
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -43,10 +43,6 @@ from graphiti_core.search.search_utils import (
|
||||||
get_relevant_edges,
|
get_relevant_edges,
|
||||||
get_relevant_nodes,
|
get_relevant_nodes,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils import (
|
|
||||||
build_episodic_edges,
|
|
||||||
retrieve_episodes,
|
|
||||||
)
|
|
||||||
from graphiti_core.utils.bulk_utils import (
|
from graphiti_core.utils.bulk_utils import (
|
||||||
RawEpisode,
|
RawEpisode,
|
||||||
add_nodes_and_edges_bulk,
|
add_nodes_and_edges_bulk,
|
||||||
|
|
@ -57,12 +53,14 @@ from graphiti_core.utils.bulk_utils import (
|
||||||
resolve_edge_pointers,
|
resolve_edge_pointers,
|
||||||
retrieve_previous_episodes_bulk,
|
retrieve_previous_episodes_bulk,
|
||||||
)
|
)
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
from graphiti_core.utils.maintenance.community_operations import (
|
from graphiti_core.utils.maintenance.community_operations import (
|
||||||
build_communities,
|
build_communities,
|
||||||
remove_communities,
|
remove_communities,
|
||||||
update_community,
|
update_community,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
|
build_episodic_edges,
|
||||||
dedupe_extracted_edge,
|
dedupe_extracted_edge,
|
||||||
extract_edges,
|
extract_edges,
|
||||||
resolve_edge_contradictions,
|
resolve_edge_contradictions,
|
||||||
|
|
@ -71,6 +69,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
EPISODE_WINDOW_LEN,
|
EPISODE_WINDOW_LEN,
|
||||||
build_indices_and_constraints,
|
build_indices_and_constraints,
|
||||||
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.node_operations import (
|
from graphiti_core.utils.maintenance.node_operations import (
|
||||||
extract_nodes,
|
extract_nodes,
|
||||||
|
|
@ -313,7 +312,7 @@ class Graphiti:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
entity_edges: list[EntityEdge] = []
|
entity_edges: list[EntityEdge] = []
|
||||||
now = datetime.now(timezone.utc)
|
now = utc_now()
|
||||||
|
|
||||||
previous_episodes = await self.retrieve_episodes(
|
previous_episodes = await self.retrieve_episodes(
|
||||||
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
|
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
|
||||||
|
|
@ -522,7 +521,7 @@ class Graphiti:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
start = time()
|
start = time()
|
||||||
now = datetime.now(timezone.utc)
|
now = utc_now()
|
||||||
|
|
||||||
episodes = [
|
episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -34,6 +34,7 @@ from graphiti_core.models.nodes.node_db_queries import (
|
||||||
ENTITY_NODE_SAVE,
|
ENTITY_NODE_SAVE,
|
||||||
EPISODIC_NODE_SAVE,
|
EPISODIC_NODE_SAVE,
|
||||||
)
|
)
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -79,7 +80,7 @@ class Node(BaseModel, ABC):
|
||||||
name: str = Field(description='name of the node')
|
name: str = Field(description='name of the node')
|
||||||
group_id: str = Field(description='partition of the graph')
|
group_id: str = Field(description='partition of the graph')
|
||||||
labels: list[str] = Field(default_factory=list)
|
labels: list[str] = Field(default_factory=list)
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
created_at: datetime = Field(default_factory=lambda: utc_now())
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: AsyncDriver): ...
|
||||||
|
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
from .maintenance import (
|
|
||||||
build_episodic_edges,
|
|
||||||
clear_data,
|
|
||||||
extract_edges,
|
|
||||||
extract_nodes,
|
|
||||||
retrieve_episodes,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'extract_edges',
|
|
||||||
'build_episodic_edges',
|
|
||||||
'extract_nodes',
|
|
||||||
'clear_data',
|
|
||||||
'retrieve_episodes',
|
|
||||||
]
|
|
||||||
|
|
@ -18,7 +18,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
from neo4j import AsyncDriver, AsyncManagedTransaction
|
from neo4j import AsyncDriver, AsyncManagedTransaction
|
||||||
|
|
@ -37,14 +37,17 @@ from graphiti_core.models.nodes.node_db_queries import (
|
||||||
)
|
)
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
|
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
|
||||||
from graphiti_core.utils import retrieve_episodes
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
build_episodic_edges,
|
build_episodic_edges,
|
||||||
dedupe_edge_list,
|
dedupe_edge_list,
|
||||||
dedupe_extracted_edges,
|
dedupe_extracted_edges,
|
||||||
extract_edges,
|
extract_edges,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
|
EPISODE_WINDOW_LEN,
|
||||||
|
retrieve_episodes,
|
||||||
|
)
|
||||||
from graphiti_core.utils.maintenance.node_operations import (
|
from graphiti_core.utils.maintenance.node_operations import (
|
||||||
dedupe_extracted_nodes,
|
dedupe_extracted_nodes,
|
||||||
dedupe_node_list,
|
dedupe_node_list,
|
||||||
|
|
@ -385,7 +388,7 @@ async def extract_edge_dates_bulk(
|
||||||
edge.valid_at = valid_at
|
edge.valid_at = valid_at
|
||||||
edge.invalid_at = invalid_at
|
edge.invalid_at = invalid_at
|
||||||
if edge.invalid_at:
|
if edge.invalid_at:
|
||||||
edge.expired_at = datetime.now(timezone.utc)
|
edge.expired_at = utc_now()
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
|
||||||
42
graphiti_core/utils/datetime_utils.py
Normal file
42
graphiti_core/utils/datetime_utils.py
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
def utc_now() -> datetime:
|
||||||
|
"""Returns the current UTC datetime with timezone information."""
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_utc(dt: datetime | None) -> datetime | None:
|
||||||
|
"""
|
||||||
|
Ensures a datetime is timezone-aware and in UTC.
|
||||||
|
If the datetime is naive (no timezone), assumes it's in UTC.
|
||||||
|
If the datetime has a different timezone, converts it to UTC.
|
||||||
|
Returns None if input is None.
|
||||||
|
"""
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
# If datetime is naive, assume it's UTC
|
||||||
|
return dt.replace(tzinfo=timezone.utc)
|
||||||
|
elif dt.tzinfo != timezone.utc:
|
||||||
|
# If datetime has a different timezone, convert to UTC
|
||||||
|
return dt.astimezone(timezone.utc)
|
||||||
|
|
||||||
|
return dt
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -17,6 +16,7 @@ from graphiti_core.nodes import (
|
||||||
)
|
)
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
|
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
||||||
|
|
||||||
MAX_COMMUNITY_BUILD_CONCURRENCY = 10
|
MAX_COMMUNITY_BUILD_CONCURRENCY = 10
|
||||||
|
|
@ -180,7 +180,7 @@ async def build_community(
|
||||||
|
|
||||||
summary = summaries[0]
|
summary = summaries[0]
|
||||||
name = await generate_summary_description(llm_client, summary)
|
name = await generate_summary_description(llm_client, summary)
|
||||||
now = datetime.now(timezone.utc)
|
now = utc_now()
|
||||||
community_node = CommunityNode(
|
community_node = CommunityNode(
|
||||||
name=name,
|
name=name,
|
||||||
group_id=community_cluster[0].group_id,
|
group_id=community_cluster[0].group_id,
|
||||||
|
|
@ -307,7 +307,7 @@ async def update_community(
|
||||||
community.name = new_name
|
community.name = new_name
|
||||||
|
|
||||||
if is_new:
|
if is_new:
|
||||||
community_edge = (build_community_edges([entity], community, datetime.now(timezone.utc)))[0]
|
community_edge = (build_community_edges([entity], community, utc_now()))[0]
|
||||||
await community_edge.save(driver)
|
await community_edge.save(driver)
|
||||||
|
|
||||||
await community.generate_name_embedding(embedder)
|
await community.generate_name_embedding(embedder)
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||||
|
|
@ -26,6 +26,7 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
||||||
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||||
extract_edge_dates,
|
extract_edge_dates,
|
||||||
get_edge_contradictions,
|
get_edge_contradictions,
|
||||||
|
|
@ -132,7 +133,7 @@ async def extract_edges(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
fact=edge_data.get('fact', ''),
|
fact=edge_data.get('fact', ''),
|
||||||
episodes=[episode.uuid],
|
episodes=[episode.uuid],
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=utc_now(),
|
||||||
valid_at=None,
|
valid_at=None,
|
||||||
invalid_at=None,
|
invalid_at=None,
|
||||||
)
|
)
|
||||||
|
|
@ -251,9 +252,7 @@ def resolve_edge_contradictions(
|
||||||
and edge.valid_at < resolved_edge.valid_at
|
and edge.valid_at < resolved_edge.valid_at
|
||||||
):
|
):
|
||||||
edge.invalid_at = resolved_edge.valid_at
|
edge.invalid_at = resolved_edge.valid_at
|
||||||
edge.expired_at = (
|
edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
|
||||||
edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
invalidated_edges.append(edge)
|
invalidated_edges.append(edge)
|
||||||
|
|
||||||
return invalidated_edges
|
return invalidated_edges
|
||||||
|
|
@ -273,11 +272,12 @@ async def resolve_extracted_edge(
|
||||||
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
|
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
|
||||||
)
|
)
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = utc_now()
|
||||||
|
|
||||||
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
|
resolved_edge.valid_at = valid_at if valid_at else resolved_edge.valid_at
|
||||||
resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
|
resolved_edge.invalid_at = invalid_at if invalid_at else resolved_edge.invalid_at
|
||||||
if invalid_at is not None and resolved_edge.expired_at is None:
|
|
||||||
|
if invalid_at and not resolved_edge.expired_at:
|
||||||
resolved_edge.expired_at = now
|
resolved_edge.expired_at = now
|
||||||
|
|
||||||
# Determine if the new_edge needs to be expired
|
# Determine if the new_edge needs to be expired
|
||||||
|
|
@ -285,8 +285,12 @@ async def resolve_extracted_edge(
|
||||||
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
||||||
for candidate in invalidation_candidates:
|
for candidate in invalidation_candidates:
|
||||||
if (
|
if (
|
||||||
candidate.valid_at is not None and resolved_edge.valid_at is not None
|
candidate.valid_at
|
||||||
) and candidate.valid_at > resolved_edge.valid_at:
|
and resolved_edge.valid_at
|
||||||
|
and candidate.valid_at.tzinfo
|
||||||
|
and resolved_edge.valid_at.tzinfo
|
||||||
|
and candidate.valid_at > resolved_edge.valid_at
|
||||||
|
):
|
||||||
# Expire new edge since we have information about more recent events
|
# Expire new edge since we have information about more recent events
|
||||||
resolved_edge.invalid_at = candidate.valid_at
|
resolved_edge.invalid_at = candidate.valid_at
|
||||||
resolved_edge.expired_at = now
|
resolved_edge.expired_at = now
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
|
||||||
|
|
@ -26,6 +25,7 @@ from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||||
from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
|
from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
|
||||||
from graphiti_core.prompts.summarize_nodes import Summary
|
from graphiti_core.prompts.summarize_nodes import Summary
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -155,7 +155,7 @@ async def extract_nodes(
|
||||||
group_id=episode.group_id,
|
group_id=episode.group_id,
|
||||||
labels=['Entity'],
|
labels=['Entity'],
|
||||||
summary='',
|
summary='',
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=utc_now(),
|
||||||
)
|
)
|
||||||
new_nodes.append(new_node)
|
new_nodes.append(new_node)
|
||||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ You may obtain a copy of the License at
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
@ -24,6 +24,7 @@ from graphiti_core.nodes import EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.extract_edge_dates import EdgeDates
|
from graphiti_core.prompts.extract_edge_dates import EdgeDates
|
||||||
from graphiti_core.prompts.invalidate_edges import InvalidatedEdges
|
from graphiti_core.prompts.invalidate_edges import InvalidatedEdges
|
||||||
|
from graphiti_core.utils.datetime_utils import ensure_utc
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -52,13 +53,15 @@ async def extract_edge_dates(
|
||||||
|
|
||||||
if valid_at:
|
if valid_at:
|
||||||
try:
|
try:
|
||||||
valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
|
valid_at_datetime = ensure_utc(datetime.fromisoformat(valid_at.replace('Z', '+00:00')))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
|
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
|
||||||
|
|
||||||
if invalid_at:
|
if invalid_at:
|
||||||
try:
|
try:
|
||||||
invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
|
invalid_at_datetime = ensure_utc(
|
||||||
|
datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
|
||||||
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
|
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,9 +22,7 @@ class Message(BaseModel):
|
||||||
role: str | None = Field(
|
role: str | None = Field(
|
||||||
description='The custom role of the message to be used alongside role_type (user name, bot name, etc.)',
|
description='The custom role of the message to be used alongside role_type (user name, bot name, etc.)',
|
||||||
)
|
)
|
||||||
timestamp: datetime = Field(
|
timestamp: datetime = Field(default_factory=utc_now, description='The timestamp of the message')
|
||||||
default_factory=datetime.now, description='The timestamp of the message'
|
|
||||||
)
|
|
||||||
source_description: str = Field(
|
source_description: str = Field(
|
||||||
default='', description='The description of the source of the message'
|
default='', description='The description of the source of the message'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from functools import partial
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI, status
|
from fastapi import APIRouter, FastAPI, status
|
||||||
from graphiti_core.nodes import EpisodeType # type: ignore
|
from graphiti_core.nodes import EpisodeType # type: ignore
|
||||||
from graphiti_core.utils import clear_data # type: ignore
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore
|
||||||
|
|
||||||
from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
|
from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
|
||||||
from graph_service.zep_graphiti import ZepGraphitiDep
|
from graph_service.zep_graphiti import ZepGraphitiDep
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import timedelta
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -23,6 +23,7 @@ from dotenv import load_dotenv
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||||
extract_edge_dates,
|
extract_edge_dates,
|
||||||
get_edge_contradictions,
|
get_edge_contradictions,
|
||||||
|
|
@ -42,7 +43,7 @@ def setup_llm_client():
|
||||||
|
|
||||||
|
|
||||||
def create_test_data():
|
def create_test_data():
|
||||||
now = datetime.now(timezone.utc)
|
now = utc_now()
|
||||||
|
|
||||||
# Create edges
|
# Create edges
|
||||||
existing_edge = EntityEdge(
|
existing_edge = EntityEdge(
|
||||||
|
|
@ -131,7 +132,7 @@ async def test_get_edge_contradictions_multiple_existing():
|
||||||
|
|
||||||
# Helper function to create more complex test data
|
# Helper function to create more complex test data
|
||||||
def create_complex_test_data():
|
def create_complex_test_data():
|
||||||
now = datetime.now(timezone.utc)
|
now = utc_now()
|
||||||
|
|
||||||
# Create nodes
|
# Create nodes
|
||||||
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
|
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
|
||||||
|
|
@ -191,7 +192,7 @@ async def test_invalidate_edges_complex():
|
||||||
name='DISLIKES',
|
name='DISLIKES',
|
||||||
fact='Alice dislikes Bob',
|
fact='Alice dislikes Bob',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=utc_now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -213,7 +214,7 @@ async def test_get_edge_contradictions_temporal_update():
|
||||||
name='LEFT_JOB',
|
name='LEFT_JOB',
|
||||||
fact='Bob no longer works at at Company XYZ',
|
fact='Bob no longer works at at Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=utc_now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -235,7 +236,7 @@ async def test_get_edge_contradictions_no_effect():
|
||||||
name='APPLIED_TO',
|
name='APPLIED_TO',
|
||||||
fact='Charlie applied to Company XYZ',
|
fact='Charlie applied to Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=utc_now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -256,7 +257,7 @@ async def test_invalidate_edges_partial_update():
|
||||||
name='CHANGED_POSITION',
|
name='CHANGED_POSITION',
|
||||||
fact='Bob changed his position at Company XYZ',
|
fact='Bob changed his position at Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=utc_now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -265,7 +266,7 @@ async def test_invalidate_edges_partial_update():
|
||||||
|
|
||||||
|
|
||||||
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
|
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
|
||||||
now = datetime.now(timezone.utc)
|
now = utc_now()
|
||||||
|
|
||||||
previous_episodes = [
|
previous_episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
|
|
@ -314,7 +315,7 @@ async def test_extract_edge_dates():
|
||||||
name='LEFT_JOB',
|
name='LEFT_JOB',
|
||||||
fact='Bob no longer works at Company XYZ',
|
fact='Bob no longer works at Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=utc_now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_at, invalid_at = await extract_edge_dates(
|
valid_at, invalid_at = await extract_edge_dates(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue