add_fact endpoint (#207)
* add_fact endpoint * bump version * add edge invalidation * update
This commit is contained in:
parent
6536401c8c
commit
3199e893ed
17 changed files with 196 additions and 87 deletions
|
|
@ -19,7 +19,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -78,7 +78,7 @@ async def add_messages(client: Graphiti):
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
episode_body=message,
|
episode_body=message,
|
||||||
source=EpisodeType.message,
|
source=EpisodeType.message,
|
||||||
reference_time=datetime.now(),
|
reference_time=datetime.now(timezone.utc),
|
||||||
source_description='Shoe conversation',
|
source_description='Shoe conversation',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -105,7 +105,7 @@ async def ingest_products_data(client: Graphiti):
|
||||||
content=str(product),
|
content=str(product),
|
||||||
source_description='Allbirds products',
|
source_description='Allbirds products',
|
||||||
source=EpisodeType.json,
|
source=EpisodeType.json,
|
||||||
reference_time=datetime.now(),
|
reference_time=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
for i, product in enumerate(products)
|
for i, product in enumerate(products)
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -45,7 +45,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
|
||||||
ParsedMscMessage(
|
ParsedMscMessage(
|
||||||
speaker_name=speakers[speaker_idx],
|
speaker_name=speakers[speaker_idx],
|
||||||
content=content,
|
content=content,
|
||||||
actual_timestamp=datetime.now(),
|
actual_timestamp=datetime.now(timezone.utc),
|
||||||
group_id=str(i),
|
group_id=str(i),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -60,7 +60,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
|
||||||
ParsedMscMessage(
|
ParsedMscMessage(
|
||||||
speaker_name=speakers[speaker_idx],
|
speaker_name=speakers[speaker_idx],
|
||||||
content=content,
|
content=content,
|
||||||
actual_timestamp=datetime.now(),
|
actual_timestamp=datetime.now(timezone.utc),
|
||||||
group_id=str(i),
|
group_id=str(i),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -61,7 +61,7 @@ def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[Par
|
||||||
break
|
break
|
||||||
|
|
||||||
# Calculate the start time
|
# Calculate the start time
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
podcast_start_time = now - last_timestamp
|
podcast_start_time = now - last_timestamp
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
@ -63,7 +63,7 @@ async def main():
|
||||||
messages = get_wizard_of_oz_messages()
|
messages = get_wizard_of_oz_messages()
|
||||||
print(messages)
|
print(messages)
|
||||||
print(len(messages))
|
print(len(messages))
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
# episodes: list[BulkEpisode] = [
|
# episodes: list[BulkEpisode] = [
|
||||||
# BulkEpisode(
|
# BulkEpisode(
|
||||||
# name=f'Chapter {i + 1}',
|
# name=f'Chapter {i + 1}',
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -65,7 +65,9 @@ from graphiti_core.utils.maintenance.community_operations import (
|
||||||
update_community,
|
update_community,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
|
dedupe_extracted_edge,
|
||||||
extract_edges,
|
extract_edges,
|
||||||
|
resolve_edge_contradictions,
|
||||||
resolve_extracted_edges,
|
resolve_extracted_edges,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
|
|
@ -76,6 +78,7 @@ from graphiti_core.utils.maintenance.node_operations import (
|
||||||
extract_nodes,
|
extract_nodes,
|
||||||
resolve_extracted_nodes,
|
resolve_extracted_nodes,
|
||||||
)
|
)
|
||||||
|
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -312,7 +315,7 @@ class Graphiti:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
entity_edges: list[EntityEdge] = []
|
entity_edges: list[EntityEdge] = []
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
previous_episodes = await self.retrieve_episodes(
|
previous_episodes = await self.retrieve_episodes(
|
||||||
reference_time, last_n=3, group_ids=[group_id]
|
reference_time, last_n=3, group_ids=[group_id]
|
||||||
|
|
@ -448,7 +451,6 @@ class Graphiti:
|
||||||
|
|
||||||
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
||||||
|
|
||||||
# Future optimization would be using batch operations to save nodes and edges
|
|
||||||
if not self.store_raw_episode_content:
|
if not self.store_raw_episode_content:
|
||||||
episode.content = ''
|
episode.content = ''
|
||||||
|
|
||||||
|
|
@ -511,7 +513,7 @@ class Graphiti:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
start = time()
|
start = time()
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
episodes = [
|
episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
|
|
@ -760,3 +762,36 @@ class Graphiti:
|
||||||
communities = await get_communities_by_nodes(self.driver, nodes)
|
communities = await get_communities_by_nodes(self.driver, nodes)
|
||||||
|
|
||||||
return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
||||||
|
|
||||||
|
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
|
||||||
|
if source_node.name_embedding is None:
|
||||||
|
await source_node.generate_name_embedding(self.embedder)
|
||||||
|
if target_node.name_embedding is None:
|
||||||
|
await target_node.generate_name_embedding(self.embedder)
|
||||||
|
if edge.fact_embedding is None:
|
||||||
|
await edge.generate_embedding(self.embedder)
|
||||||
|
|
||||||
|
resolved_nodes, _ = await resolve_extracted_nodes(
|
||||||
|
self.llm_client,
|
||||||
|
[source_node, target_node],
|
||||||
|
[
|
||||||
|
await get_relevant_nodes([source_node], self.driver),
|
||||||
|
await get_relevant_nodes([target_node], self.driver),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
related_edges = await get_relevant_edges(
|
||||||
|
self.driver,
|
||||||
|
[edge],
|
||||||
|
source_node_uuid=resolved_nodes[0].uuid,
|
||||||
|
target_node_uuid=resolved_nodes[1].uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_edge = await dedupe_extracted_edge(self.llm_client, edge, related_edges)
|
||||||
|
|
||||||
|
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
|
||||||
|
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
||||||
|
|
||||||
|
await add_nodes_and_edges_bulk(
|
||||||
|
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
EPISODIC_EDGE_SAVE = """
|
EPISODIC_EDGE_SAVE = """
|
||||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||||
MATCH (node:Entity {uuid: $entity_uuid})
|
MATCH (node:Entity {uuid: $entity_uuid})
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
EPISODIC_NODE_SAVE = """
|
EPISODIC_NODE_SAVE = """
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -78,7 +78,7 @@ class Node(BaseModel, ABC):
|
||||||
name: str = Field(description='name of the node')
|
name: str = Field(description='name of the node')
|
||||||
group_id: str = Field(description='partition of the graph')
|
group_id: str = Field(description='partition of the graph')
|
||||||
labels: list[str] = Field(default_factory=list)
|
labels: list[str] = Field(default_factory=list)
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now())
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: AsyncDriver): ...
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
from neo4j import AsyncDriver, AsyncManagedTransaction
|
from neo4j import AsyncDriver, AsyncManagedTransaction
|
||||||
|
|
@ -385,7 +385,7 @@ async def extract_edge_dates_bulk(
|
||||||
edge.valid_at = valid_at
|
edge.valid_at = valid_at
|
||||||
edge.invalid_at = invalid_at
|
edge.invalid_at = invalid_at
|
||||||
if edge.invalid_at:
|
if edge.invalid_at:
|
||||||
edge.expired_at = datetime.now()
|
edge.expired_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -178,7 +178,7 @@ async def build_community(
|
||||||
|
|
||||||
summary = summaries[0]
|
summary = summaries[0]
|
||||||
name = await generate_summary_description(llm_client, summary)
|
name = await generate_summary_description(llm_client, summary)
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
community_node = CommunityNode(
|
community_node = CommunityNode(
|
||||||
name=name,
|
name=name,
|
||||||
group_id=community_cluster[0].group_id,
|
group_id=community_cluster[0].group_id,
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
@ -110,7 +110,7 @@ async def extract_edges(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
fact=edge_data['fact'],
|
fact=edge_data['fact'],
|
||||||
episodes=[episode.uuid],
|
episodes=[episode.uuid],
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
valid_at=None,
|
valid_at=None,
|
||||||
invalid_at=None,
|
invalid_at=None,
|
||||||
)
|
)
|
||||||
|
|
@ -205,39 +205,9 @@ async def resolve_extracted_edges(
|
||||||
return resolved_edges, invalidated_edges
|
return resolved_edges, invalidated_edges
|
||||||
|
|
||||||
|
|
||||||
async def resolve_extracted_edge(
|
def resolve_edge_contradictions(
|
||||||
llm_client: LLMClient,
|
resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
|
||||||
extracted_edge: EntityEdge,
|
) -> list[EntityEdge]:
|
||||||
related_edges: list[EntityEdge],
|
|
||||||
existing_edges: list[EntityEdge],
|
|
||||||
current_episode: EpisodicNode,
|
|
||||||
previous_episodes: list[EpisodicNode],
|
|
||||||
) -> tuple[EntityEdge, list[EntityEdge]]:
|
|
||||||
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
|
|
||||||
dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
|
|
||||||
extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
|
|
||||||
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
|
|
||||||
)
|
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
|
|
||||||
resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
|
|
||||||
if invalid_at is not None and resolved_edge.expired_at is None:
|
|
||||||
resolved_edge.expired_at = now
|
|
||||||
|
|
||||||
# Determine if the new_edge needs to be expired
|
|
||||||
if resolved_edge.expired_at is None:
|
|
||||||
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
|
||||||
for candidate in invalidation_candidates:
|
|
||||||
if (
|
|
||||||
candidate.valid_at is not None and resolved_edge.valid_at is not None
|
|
||||||
) and candidate.valid_at > resolved_edge.valid_at:
|
|
||||||
# Expire new edge since we have information about more recent events
|
|
||||||
resolved_edge.invalid_at = candidate.valid_at
|
|
||||||
resolved_edge.expired_at = now
|
|
||||||
break
|
|
||||||
|
|
||||||
# Determine which contradictory edges need to be expired
|
# Determine which contradictory edges need to be expired
|
||||||
invalidated_edges: list[EntityEdge] = []
|
invalidated_edges: list[EntityEdge] = []
|
||||||
for edge in invalidation_candidates:
|
for edge in invalidation_candidates:
|
||||||
|
|
@ -259,9 +229,50 @@ async def resolve_extracted_edge(
|
||||||
and edge.valid_at < resolved_edge.valid_at
|
and edge.valid_at < resolved_edge.valid_at
|
||||||
):
|
):
|
||||||
edge.invalid_at = resolved_edge.valid_at
|
edge.invalid_at = resolved_edge.valid_at
|
||||||
edge.expired_at = edge.expired_at if edge.expired_at is not None else now
|
edge.expired_at = (
|
||||||
|
edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
invalidated_edges.append(edge)
|
invalidated_edges.append(edge)
|
||||||
|
|
||||||
|
return invalidated_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_extracted_edge(
|
||||||
|
llm_client: LLMClient,
|
||||||
|
extracted_edge: EntityEdge,
|
||||||
|
related_edges: list[EntityEdge],
|
||||||
|
existing_edges: list[EntityEdge],
|
||||||
|
current_episode: EpisodicNode,
|
||||||
|
previous_episodes: list[EpisodicNode],
|
||||||
|
) -> tuple[EntityEdge, list[EntityEdge]]:
|
||||||
|
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
|
||||||
|
dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
|
||||||
|
extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
|
||||||
|
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
|
||||||
|
)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
|
||||||
|
resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
|
||||||
|
if invalid_at is not None and resolved_edge.expired_at is None:
|
||||||
|
resolved_edge.expired_at = now
|
||||||
|
|
||||||
|
# Determine if the new_edge needs to be expired
|
||||||
|
if resolved_edge.expired_at is None:
|
||||||
|
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
||||||
|
for candidate in invalidation_candidates:
|
||||||
|
if (
|
||||||
|
candidate.valid_at is not None and resolved_edge.valid_at is not None
|
||||||
|
) and candidate.valid_at > resolved_edge.valid_at:
|
||||||
|
# Expire new edge since we have information about more recent events
|
||||||
|
resolved_edge.invalid_at = candidate.valid_at
|
||||||
|
resolved_edge.expired_at = now
|
||||||
|
break
|
||||||
|
|
||||||
|
# Determine which contradictory edges need to be expired
|
||||||
|
invalidated_edges = resolve_edge_contradictions(resolved_edge, invalidation_candidates)
|
||||||
|
|
||||||
return resolved_edge, invalidated_edges
|
return resolved_edge, invalidated_edges
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -113,7 +113,7 @@ async def extract_nodes(
|
||||||
group_id=episode.group_id,
|
group_id=episode.group_id,
|
||||||
labels=node_data['labels'],
|
labels=node_data['labels'],
|
||||||
summary=node_data['summary'],
|
summary=node_data['summary'],
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
new_nodes.append(new_node)
|
new_nodes.append(new_node)
|
||||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.3.21"
|
version = "0.4.0"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status
|
||||||
|
|
||||||
|
|
@ -36,7 +36,7 @@ async def get_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
|
||||||
@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
|
@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
|
||||||
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
|
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
|
||||||
episodes = await graphiti.retrieve_episodes(
|
episodes = await graphiti.retrieve_episodes(
|
||||||
group_ids=[group_id], last_n=last_n, reference_time=datetime.now()
|
group_ids=[group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
|
||||||
)
|
)
|
||||||
return episodes
|
return episodes
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -66,7 +66,39 @@ def setup_logging():
|
||||||
async def test_graphiti_init():
|
async def test_graphiti_init():
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
alice_node = EntityNode(
|
||||||
|
name='Alice',
|
||||||
|
labels=[],
|
||||||
|
created_at=now,
|
||||||
|
summary='Alice summary',
|
||||||
|
group_id='test',
|
||||||
|
)
|
||||||
|
|
||||||
|
bob_node = EntityNode(
|
||||||
|
name='Bob',
|
||||||
|
labels=[],
|
||||||
|
created_at=now,
|
||||||
|
summary='Bob summary',
|
||||||
|
group_id='test',
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_edge = EntityEdge(
|
||||||
|
source_node_uuid=alice_node.uuid,
|
||||||
|
target_node_uuid=bob_node.uuid,
|
||||||
|
created_at=now,
|
||||||
|
name='likes',
|
||||||
|
fact='Alice likes Bob',
|
||||||
|
episodes=[],
|
||||||
|
expired_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
group_id='test',
|
||||||
|
)
|
||||||
|
|
||||||
|
await graphiti.add_triplet(alice_node, entity_edge, bob_node)
|
||||||
|
|
||||||
|
episodes = await graphiti.retrieve_episodes(datetime.now(timezone.utc), group_ids=None)
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
|
|
||||||
results = await graphiti._search(
|
results = await graphiti._search(
|
||||||
|
|
@ -92,7 +124,7 @@ async def test_graph_integration():
|
||||||
embedder = client.embedder
|
embedder = client.embedder
|
||||||
driver = client.driver
|
driver = client.driver
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
episode = EpisodicNode(
|
episode = EpisodicNode(
|
||||||
name='test_episode',
|
name='test_episode',
|
||||||
labels=[],
|
labels=[],
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -23,7 +23,7 @@ def mock_extracted_edge():
|
||||||
group_id='group_1',
|
group_id='group_1',
|
||||||
fact='Test fact',
|
fact='Test fact',
|
||||||
episodes=['episode_1'],
|
episodes=['episode_1'],
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
valid_at=None,
|
valid_at=None,
|
||||||
invalid_at=None,
|
invalid_at=None,
|
||||||
)
|
)
|
||||||
|
|
@ -39,8 +39,8 @@ def mock_related_edges():
|
||||||
group_id='group_1',
|
group_id='group_1',
|
||||||
fact='Related fact',
|
fact='Related fact',
|
||||||
episodes=['episode_2'],
|
episodes=['episode_2'],
|
||||||
created_at=datetime.now() - timedelta(days=1),
|
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||||
valid_at=datetime.now() - timedelta(days=1),
|
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||||
invalid_at=None,
|
invalid_at=None,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
@ -56,8 +56,8 @@ def mock_existing_edges():
|
||||||
group_id='group_1',
|
group_id='group_1',
|
||||||
fact='Existing fact',
|
fact='Existing fact',
|
||||||
episodes=['episode_3'],
|
episodes=['episode_3'],
|
||||||
created_at=datetime.now() - timedelta(days=2),
|
created_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||||
valid_at=datetime.now() - timedelta(days=2),
|
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||||
invalid_at=None,
|
invalid_at=None,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
@ -68,7 +68,7 @@ def mock_current_episode():
|
||||||
return EpisodicNode(
|
return EpisodicNode(
|
||||||
uuid='episode_1',
|
uuid='episode_1',
|
||||||
content='Current episode content',
|
content='Current episode content',
|
||||||
valid_at=datetime.now(),
|
valid_at=datetime.now(timezone.utc),
|
||||||
name='Current Episode',
|
name='Current Episode',
|
||||||
group_id='group_1',
|
group_id='group_1',
|
||||||
source='message',
|
source='message',
|
||||||
|
|
@ -82,7 +82,7 @@ def mock_previous_episodes():
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
uuid='episode_2',
|
uuid='episode_2',
|
||||||
content='Previous episode content',
|
content='Previous episode content',
|
||||||
valid_at=datetime.now() - timedelta(days=1),
|
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||||
name='Previous Episode',
|
name='Previous Episode',
|
||||||
group_id='group_1',
|
group_id='group_1',
|
||||||
source='message',
|
source='message',
|
||||||
|
|
@ -144,8 +144,8 @@ async def test_resolve_extracted_edge_with_dates(
|
||||||
mock_previous_episodes,
|
mock_previous_episodes,
|
||||||
monkeypatch: MonkeyPatch,
|
monkeypatch: MonkeyPatch,
|
||||||
):
|
):
|
||||||
valid_at = datetime.now() - timedelta(days=1)
|
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
invalid_at = datetime.now() + timedelta(days=1)
|
invalid_at = datetime.now(timezone.utc) + timedelta(days=1)
|
||||||
|
|
||||||
# Mock the function calls
|
# Mock the function calls
|
||||||
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
||||||
|
|
@ -189,7 +189,7 @@ async def test_resolve_extracted_edge_with_invalidation(
|
||||||
mock_previous_episodes,
|
mock_previous_episodes,
|
||||||
monkeypatch: MonkeyPatch,
|
monkeypatch: MonkeyPatch,
|
||||||
):
|
):
|
||||||
valid_at = datetime.now() - timedelta(days=1)
|
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
mock_extracted_edge.valid_at = valid_at
|
mock_extracted_edge.valid_at = valid_at
|
||||||
|
|
||||||
invalidation_candidate = EntityEdge(
|
invalidation_candidate = EntityEdge(
|
||||||
|
|
@ -199,8 +199,8 @@ async def test_resolve_extracted_edge_with_invalidation(
|
||||||
group_id='group_1',
|
group_id='group_1',
|
||||||
fact='Invalidation candidate fact',
|
fact='Invalidation candidate fact',
|
||||||
episodes=['episode_4'],
|
episodes=['episode_4'],
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
valid_at=datetime.now() - timedelta(days=2),
|
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||||
invalid_at=None,
|
invalid_at=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,10 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pytz import UTC
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||||
|
|
@ -43,7 +42,7 @@ def setup_llm_client():
|
||||||
|
|
||||||
|
|
||||||
def create_test_data():
|
def create_test_data():
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Create edges
|
# Create edges
|
||||||
existing_edge = EntityEdge(
|
existing_edge = EntityEdge(
|
||||||
|
|
@ -132,7 +131,7 @@ async def test_get_edge_contradictions_multiple_existing():
|
||||||
|
|
||||||
# Helper function to create more complex test data
|
# Helper function to create more complex test data
|
||||||
def create_complex_test_data():
|
def create_complex_test_data():
|
||||||
now = datetime.now()
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Create nodes
|
# Create nodes
|
||||||
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
|
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
|
||||||
|
|
@ -192,7 +191,7 @@ async def test_invalidate_edges_complex():
|
||||||
name='DISLIKES',
|
name='DISLIKES',
|
||||||
fact='Alice dislikes Bob',
|
fact='Alice dislikes Bob',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -214,7 +213,7 @@ async def test_get_edge_contradictions_temporal_update():
|
||||||
name='LEFT_JOB',
|
name='LEFT_JOB',
|
||||||
fact='Bob no longer works at at Company XYZ',
|
fact='Bob no longer works at at Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -236,7 +235,7 @@ async def test_get_edge_contradictions_no_effect():
|
||||||
name='APPLIED_TO',
|
name='APPLIED_TO',
|
||||||
fact='Charlie applied to Company XYZ',
|
fact='Charlie applied to Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -257,7 +256,7 @@ async def test_invalidate_edges_partial_update():
|
||||||
name='CHANGED_POSITION',
|
name='CHANGED_POSITION',
|
||||||
fact='Bob changed his position at Company XYZ',
|
fact='Bob changed his position at Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||||
|
|
@ -266,7 +265,7 @@ async def test_invalidate_edges_partial_update():
|
||||||
|
|
||||||
|
|
||||||
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
|
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
previous_episodes = [
|
previous_episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
|
|
@ -315,7 +314,7 @@ async def test_extract_edge_dates():
|
||||||
name='LEFT_JOB',
|
name='LEFT_JOB',
|
||||||
fact='Bob no longer works at Company XYZ',
|
fact='Bob no longer works at Company XYZ',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
created_at=datetime.now(UTC),
|
created_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_at, invalid_at = await extract_edge_dates(
|
valid_at, invalid_at = await extract_edge_dates(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue