from datetime import datetime, timezone
from typing import Tuple

import pandas as pd

from graphiti_core import Graphiti
from graphiti_core.graphiti import AddEpisodeResults
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.maintenance import clear_data
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER


async def build_graph(
    multi_session: list[int], session_length: int, graphiti: Graphiti
) -> Tuple[dict[str, list[AddEpisodeResults]], dict[str, list[list[str]]]]:
    # Get longmemeval dataset
    lme_dataset_option = 'data/longmemeval_oracle.json'  # Can be _oracle, _s, or _m
    lme_dataset_df = pd.read_json(lme_dataset_option)

    add_episode_results: dict[str, list[AddEpisodeResults]] = {}
    add_episode_context: dict[str, list[list[str]]] = {}
    for multi_session_idx in multi_session:
        # Each 'haystack_sessions' entry is a list of chat sessions; 'haystack_dates' lines up by index
        sessions = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
        multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]

        user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
        await clear_data(graphiti.driver, [user_id])

        add_episode_results[user_id] = []
        add_episode_context[user_id] = []

        for session_idx, session in enumerate(sessions):
            # Only ingest the first `session_length` sessions of each multi-session
            if session_idx >= session_length:
                continue
            for msg_idx, msg in enumerate(session):
                # Parse the session timestamp ('%Y/%m/%d (%a) %H:%M') as a UTC datetime
                date = multi_session_dates[session_idx] + ' UTC'
                date_format = '%Y/%m/%d (%a) %H:%M UTC'
                date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)

                episode_body = f'{msg["role"]}: {msg["content"]}'
                results = await graphiti.add_episode(
                    name=msg['name'],
                    episode_body=episode_body,
                    reference_time=date_string,
                    source=EpisodeType.message,
                    source_description='',
                    group_id=user_id,
                )

                add_episode_results[user_id].append(results)
                # Context for this episode: the current message first, then earlier
                # messages from the same session (most recent first)
                add_episode_context[user_id].append(
                    [f'{m["role"]}: {m["content"]}' for m in session[msg_idx::-1]]
                )
    return add_episode_results, add_episode_context


async def build_baseline_graph(multi_session: list[int], session_length: int):
    # Use gpt-4o for graph building baseline
    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)

    add_episode_results, _ = await build_graph(multi_session, session_length, graphiti)
    return add_episode_results


async def eval_graph(multi_session: list[int], session_length: int, llm_client=OpenAIClient()):
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
    baseline_results: dict[str, list[AddEpisodeResults]] = {}
    add_episode_results, add_episode_context = await build_graph(
        multi_session, session_length, graphiti
    )

    for user_id in add_episode_results:
        for baseline_result, add_episode_result, episodes in zip(
            baseline_results[user_id], add_episode_results[user_id], add_episode_context[user_id]
        ):
            context = {
                'baseline': baseline_result,
                'candidate': add_episode_result,
                'message': episodes[0],
                'previous_messages': episodes[1:],
            }
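

# Usage sketch (illustrative indices and session length): build the gpt-4o baseline
# graph end to end. Assumes a reachable Neo4j instance configured by the NEO4J_URI /
# NEO4j_USER / NEO4j_PASSWORD values imported above and an OpenAI API key available
# in the environment.
if __name__ == '__main__':
    import asyncio

    asyncio.run(build_baseline_graph(multi_session=[0, 1], session_length=5))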