Eval updates (#404)

* update eval

* make format

* remove unused imports

* mypy
This commit is contained in:
Preston Rasmussen 2025-04-27 14:27:47 -04:00 committed by GitHub
parent 0b94e0e603
commit 7ee4e38616
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 78 additions and 52 deletions

View file

@ -134,9 +134,9 @@ async def retrieve_episodes(
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes. list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
""" """
group_id_filter: LiteralString = ( group_id_filter: LiteralString = (
'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else '' '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
) )
source_filter: LiteralString = 'AND e.source = $source' if source is not None else '' source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
query: LiteralString = ( query: LiteralString = (
""" """
@ -161,7 +161,7 @@ async def retrieve_episodes(
result = await driver.execute_query( result = await driver.execute_query(
query, query,
reference_time=reference_time, reference_time=reference_time,
source=source, source=source.name if source is not None else None,
num_episodes=last_n, num_episodes=last_n,
group_ids=group_ids, group_ids=group_ids,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,

View file

@ -10,11 +10,10 @@ async def main():
) )
parser.add_argument( parser.add_argument(
'--multi-session', '--multi-session-count',
type=int, type=int,
nargs='+',
required=True, required=True,
help='List of integers representing multi-session values (e.g., 1 2 3)', help='Integer representing multi-session count',
) )
parser.add_argument('--session-length', type=int, required=True, help='Length of each session') parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
parser.add_argument( parser.add_argument(
@ -27,11 +26,13 @@ async def main():
if args.build_baseline: if args.build_baseline:
print('Running build_baseline_graph...') print('Running build_baseline_graph...')
await build_baseline_graph( await build_baseline_graph(
multi_session=args.multi_session, session_length=args.session_length multi_session_count=args.multi_session_count, session_length=args.session_length
) )
# Always call eval_graph # Always call eval_graph
result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length) result = await eval_graph(
multi_session_count=args.multi_session_count, session_length=args.session_length
)
print('Result of eval_graph:', result) print('Result of eval_graph:', result)

View file

@ -21,34 +21,24 @@ import pandas as pd
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.graphiti import AddEpisodeResults from graphiti_core.graphiti import AddEpisodeResults
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMConfig, OpenAIClient from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EpisodeType from graphiti_core.nodes import EpisodeType
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.eval import EvalAddEpisodeResults from graphiti_core.prompts.eval import EvalAddEpisodeResults
from graphiti_core.utils.maintenance import clear_data
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
async def build_graph( async def build_subgraph(
group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti graphiti: Graphiti,
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]: user_id: str,
# Get longmemeval dataset multi_session,
lme_dataset_option = ( multi_session_dates,
'data/longmemeval_data/longmemeval_oracle.json' # Can be _oracle, _s, or _m session_length: int,
) group_id_suffix: str,
lme_dataset_df = pd.read_json(lme_dataset_option) ) -> tuple[str, list[AddEpisodeResults], list[str]]:
add_episode_results: list[AddEpisodeResults] = []
add_episode_results: dict[str, list[AddEpisodeResults]] = {} add_episode_context: list[str] = []
add_episode_context: dict[str, list[str]] = {}
for multi_session_idx in multi_session:
multi_session = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]
user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
await clear_data(graphiti.driver, [user_id])
add_episode_results[user_id] = []
add_episode_context[user_id] = []
message_count = 0 message_count = 0
for session_idx, session in enumerate(multi_session): for session_idx, session in enumerate(multi_session):
@ -74,17 +64,52 @@ async def build_graph(
for edge in results.edges: for edge in results.edges:
edge.fact_embedding = None edge.fact_embedding = None
add_episode_results[user_id].append(results) add_episode_results.append(results)
add_episode_context[user_id].append(msg['content']) add_episode_context.append(msg['content'])
return user_id, add_episode_results, add_episode_context
async def build_graph(
group_id_suffix: str, multi_session_count: int, session_length: int, graphiti: Graphiti
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
# Get longmemeval dataset
lme_dataset_option = (
'data/longmemeval_data/longmemeval_oracle.json' # Can be _oracle, _s, or _m
)
lme_dataset_df = pd.read_json(lme_dataset_option)
add_episode_results: dict[str, list[AddEpisodeResults]] = {}
add_episode_context: dict[str, list[str]] = {}
subgraph_results: list[tuple[str, list[AddEpisodeResults], list[str]]] = await semaphore_gather(
*[
build_subgraph(
graphiti,
user_id='lme_oracle_experiment_user_' + str(multi_session_idx),
multi_session=lme_dataset_df['haystack_sessions'].iloc[multi_session_idx],
multi_session_dates=lme_dataset_df['haystack_dates'].iloc[multi_session_idx],
session_length=session_length,
group_id_suffix=group_id_suffix,
)
for multi_session_idx in range(multi_session_count)
]
)
for user_id, episode_results, episode_context in subgraph_results:
add_episode_results[user_id] = episode_results
add_episode_context[user_id] = episode_context
return add_episode_results, add_episode_context return add_episode_results, add_episode_context
async def build_baseline_graph(multi_session: list[int], session_length: int): async def build_baseline_graph(multi_session_count: int, session_length: int):
# Use gpt-4o for graph building baseline # Use gpt-4.1-mini for graph building baseline
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o')) llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti) add_episode_results, _ = await build_graph(
'baseline', multi_session_count, session_length, graphiti
)
filename = 'baseline_graph_results.json' filename = 'baseline_graph_results.json'
@ -97,7 +122,7 @@ async def build_baseline_graph(multi_session: list[int], session_length: int):
json.dump(serializable_baseline_graph_results, file, indent=4, default=str) json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float: async def eval_graph(multi_session_count: int, session_length: int, llm_client=None) -> float:
if llm_client is None: if llm_client is None:
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini')) llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
@ -109,7 +134,7 @@ async def eval_graph(multi_session: list[int], session_length: int, llm_client=N
for key, value in baseline_results_raw.items() for key, value in baseline_results_raw.items()
} }
add_episode_results, add_episode_context = await build_graph( add_episode_results, add_episode_context = await build_graph(
'candidate', multi_session, session_length, graphiti 'candidate', multi_session_count, session_length, graphiti
) )
filename = 'candidate_graph_results.json' filename = 'candidate_graph_results.json'