diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index dd71e6e3..60f402a7 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -134,14 +134,14 @@ async def retrieve_episodes( list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes. """ group_id_filter: LiteralString = ( - 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else '' + '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else '' ) - source_filter: LiteralString = 'AND e.source = $source' if source is not None else '' + source_filter: LiteralString = '\nAND e.source = $source' if source is not None else '' query: LiteralString = ( """ - MATCH (e:Episodic) WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) WHERE e.valid_at <= $reference_time + """ + group_id_filter + source_filter + """ @@ -161,7 +161,7 @@ async def retrieve_episodes( result = await driver.execute_query( query, reference_time=reference_time, - source=source, + source=source.name if source is not None else None, num_episodes=last_n, group_ids=group_ids, database_=DEFAULT_DATABASE, diff --git a/tests/evals/eval_cli.py b/tests/evals/eval_cli.py index 58d69e64..39dd6156 100644 --- a/tests/evals/eval_cli.py +++ b/tests/evals/eval_cli.py @@ -10,11 +10,10 @@ async def main(): ) parser.add_argument( - '--multi-session', + '--multi-session-count', type=int, - nargs='+', required=True, - help='List of integers representing multi-session values (e.g., 1 2 3)', + help='Integer representing multi-session count', ) parser.add_argument('--session-length', type=int, required=True, help='Length of each session') parser.add_argument( @@ -27,11 +26,13 @@ async def main(): if args.build_baseline: print('Running build_baseline_graph...') await build_baseline_graph( - multi_session=args.multi_session, session_length=args.session_length + multi_session_count=args.multi_session_count, session_length=args.session_length ) # Always call eval_graph - result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length) + result = await eval_graph( + multi_session_count=args.multi_session_count, session_length=args.session_length + ) print('Result of eval_graph:', result) diff --git a/tests/evals/eval_e2e_graph_building.py b/tests/evals/eval_e2e_graph_building.py index 5e3b2573..fce284ba 100644 --- a/tests/evals/eval_e2e_graph_building.py +++ b/tests/evals/eval_e2e_graph_building.py @@ -21,16 +21,57 @@ import pandas as pd from graphiti_core import Graphiti from graphiti_core.graphiti import AddEpisodeResults +from graphiti_core.helpers import semaphore_gather from graphiti_core.llm_client import LLMConfig, OpenAIClient from graphiti_core.nodes import EpisodeType from graphiti_core.prompts import prompt_library from graphiti_core.prompts.eval import EvalAddEpisodeResults -from graphiti_core.utils.maintenance import clear_data from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER +async def build_subgraph( + graphiti: Graphiti, + user_id: str, + multi_session, + multi_session_dates, + session_length: int, + group_id_suffix: str, +) -> tuple[str, list[AddEpisodeResults], list[str]]: + add_episode_results: list[AddEpisodeResults] = [] + add_episode_context: list[str] = [] + + message_count = 0 + for session_idx, session in enumerate(multi_session): + for _, msg in enumerate(session): + if message_count >= session_length: + continue + message_count += 1 + date = multi_session_dates[session_idx] + ' UTC' + date_format = '%Y/%m/%d (%a) %H:%M UTC' + date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc) + + episode_body = f'{msg["role"]}: {msg["content"]}' + results = await graphiti.add_episode( + name='', + episode_body=episode_body, + reference_time=date_string, + source=EpisodeType.message, + source_description='', + group_id=user_id + '_' + group_id_suffix, + ) + for node in results.nodes: + node.name_embedding = None + for edge in results.edges: + edge.fact_embedding = None + + add_episode_results.append(results) + add_episode_context.append(msg['content']) + + return user_id, add_episode_results, add_episode_context + + async def build_graph( - group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti + group_id_suffix: str, multi_session_count: int, session_length: int, graphiti: Graphiti ) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]: # Get longmemeval dataset lme_dataset_option = ( @@ -40,51 +81,35 @@ async def build_graph( add_episode_results: dict[str, list[AddEpisodeResults]] = {} add_episode_context: dict[str, list[str]] = {} - for multi_session_idx in multi_session: - multi_session = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx] - multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx] + subgraph_results: list[tuple[str, list[AddEpisodeResults], list[str]]] = await semaphore_gather( + *[ + build_subgraph( + graphiti, + user_id='lme_oracle_experiment_user_' + str(multi_session_idx), + multi_session=lme_dataset_df['haystack_sessions'].iloc[multi_session_idx], + multi_session_dates=lme_dataset_df['haystack_dates'].iloc[multi_session_idx], + session_length=session_length, + group_id_suffix=group_id_suffix, + ) + for multi_session_idx in range(multi_session_count) + ] + ) - user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx) - await clear_data(graphiti.driver, [user_id]) + for user_id, episode_results, episode_context in subgraph_results: + add_episode_results[user_id] = episode_results + add_episode_context[user_id] = episode_context - add_episode_results[user_id] = [] - add_episode_context[user_id] = [] - - message_count = 0 - for session_idx, session in enumerate(multi_session): - for _, msg in enumerate(session): - if message_count >= session_length: - continue - message_count += 1 - date = multi_session_dates[session_idx] + ' UTC' - date_format = '%Y/%m/%d (%a) %H:%M UTC' - date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc) - - episode_body = f'{msg["role"]}: {msg["content"]}' - results = await graphiti.add_episode( - name='', - episode_body=episode_body, - reference_time=date_string, - source=EpisodeType.message, - source_description='', - group_id=user_id + '_' + group_id_suffix, - ) - for node in results.nodes: - node.name_embedding = None - for edge in results.edges: - edge.fact_embedding = None - - add_episode_results[user_id].append(results) - add_episode_context[user_id].append(msg['content']) return add_episode_results, add_episode_context -async def build_baseline_graph(multi_session: list[int], session_length: int): - # Use gpt-4o for graph building baseline - llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o')) +async def build_baseline_graph(multi_session_count: int, session_length: int): + # Use gpt-4.1-mini for graph building baseline + llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini')) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) - add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti) + add_episode_results, _ = await build_graph( + 'baseline', multi_session_count, session_length, graphiti + ) filename = 'baseline_graph_results.json' @@ -97,7 +122,7 @@ async def build_baseline_graph(multi_session: list[int], session_length: int): json.dump(serializable_baseline_graph_results, file, indent=4, default=str) -async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float: +async def eval_graph(multi_session_count: int, session_length: int, llm_client=None) -> float: if llm_client is None: llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini')) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) @@ -109,7 +134,7 @@ async def eval_graph(multi_session: list[int], session_length: int, llm_client=N for key, value in baseline_results_raw.items() } add_episode_results, add_episode_context = await build_graph( - 'candidate', multi_session, session_length, graphiti + 'candidate', multi_session_count, session_length, graphiti ) filename = 'candidate_graph_results.json'