Eval updates (#404)

* update eval

* make format

* remove unused imports

* mypy
This commit is contained in:
Preston Rasmussen 2025-04-27 14:27:47 -04:00 committed by GitHub
parent 0b94e0e603
commit 7ee4e38616
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 78 additions and 52 deletions

View file

@ -134,14 +134,14 @@ async def retrieve_episodes(
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes. list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
""" """
group_id_filter: LiteralString = ( group_id_filter: LiteralString = (
'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else '' '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
) )
source_filter: LiteralString = 'AND e.source = $source' if source is not None else '' source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
""" """
+ group_id_filter + group_id_filter
+ source_filter + source_filter
+ """ + """
@ -161,7 +161,7 @@ async def retrieve_episodes(
result = await driver.execute_query( result = await driver.execute_query(
query, query,
reference_time=reference_time, reference_time=reference_time,
source=source, source=source.name if source is not None else None,
num_episodes=last_n, num_episodes=last_n,
group_ids=group_ids, group_ids=group_ids,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,

View file

@ -10,11 +10,10 @@ async def main():
) )
parser.add_argument( parser.add_argument(
'--multi-session', '--multi-session-count',
type=int, type=int,
nargs='+',
required=True, required=True,
help='List of integers representing multi-session values (e.g., 1 2 3)', help='Integer representing multi-session count',
) )
parser.add_argument('--session-length', type=int, required=True, help='Length of each session') parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
parser.add_argument( parser.add_argument(
@ -27,11 +26,13 @@ async def main():
if args.build_baseline: if args.build_baseline:
print('Running build_baseline_graph...') print('Running build_baseline_graph...')
await build_baseline_graph( await build_baseline_graph(
multi_session=args.multi_session, session_length=args.session_length multi_session_count=args.multi_session_count, session_length=args.session_length
) )
# Always call eval_graph # Always call eval_graph
result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length) result = await eval_graph(
multi_session_count=args.multi_session_count, session_length=args.session_length
)
print('Result of eval_graph:', result) print('Result of eval_graph:', result)

View file

@ -21,16 +21,57 @@ import pandas as pd
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.graphiti import AddEpisodeResults from graphiti_core.graphiti import AddEpisodeResults
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMConfig, OpenAIClient from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EpisodeType from graphiti_core.nodes import EpisodeType
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.eval import EvalAddEpisodeResults from graphiti_core.prompts.eval import EvalAddEpisodeResults
from graphiti_core.utils.maintenance import clear_data
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
async def build_subgraph(
    graphiti: Graphiti,
    user_id: str,
    multi_session,
    multi_session_dates,
    session_length: int,
    group_id_suffix: str,
) -> tuple[str, list[AddEpisodeResults], list[str]]:
    """Add one user's messages to the graph as episodes and collect the results.

    Messages are visited session by session, in order, and each one is passed
    to ``graphiti.add_episode`` under the group id
    ``user_id + '_' + group_id_suffix``.

    Args:
        graphiti: Client whose ``add_episode`` coroutine is awaited per message.
        user_id: Identifier returned unchanged and used as the group id prefix.
        multi_session: Iterable of sessions; each session is an iterable of
            messages indexed by ``'role'`` and ``'content'``.
        multi_session_dates: Sequence indexed by session position. Each entry
            is a string that, once ``' UTC'`` is appended, matches
            ``'%Y/%m/%d (%a) %H:%M UTC'``.
        session_length: Maximum number of messages to add. The counter is
            shared across all sessions, so this caps the total for the user
            rather than the length of each individual session.
        group_id_suffix: Suffix appended to ``user_id`` to form the group id.

    Returns:
        A tuple ``(user_id, results, contexts)`` where ``results`` holds the
        ``add_episode`` return values in insertion order and ``contexts``
        holds the matching message ``'content'`` values.

    Raises:
        ValueError: If a session date does not match the expected format.
    """
    # Loop-invariant values, computed once instead of per message.
    date_format = '%Y/%m/%d (%a) %H:%M UTC'
    group_id = user_id + '_' + group_id_suffix

    add_episode_results: list[AddEpisodeResults] = []
    add_episode_context: list[str] = []

    for session_idx, session in enumerate(multi_session):
        for msg in session:
            # Stop as soon as the cap is reached; nothing further can be added,
            # so there is no point walking the remaining messages or sessions.
            if len(add_episode_results) >= session_length:
                return user_id, add_episode_results, add_episode_context

            # The dataset dates carry no zone marker; they are treated as UTC.
            raw_date = multi_session_dates[session_idx] + ' UTC'
            reference_time = datetime.strptime(raw_date, date_format).replace(
                tzinfo=timezone.utc
            )

            results = await graphiti.add_episode(
                name='',
                episode_body=f'{msg["role"]}: {msg["content"]}',
                reference_time=reference_time,
                source=EpisodeType.message,
                source_description='',
                group_id=group_id,
            )

            # Drop embeddings from the returned objects before storing them.
            # NOTE(review): presumably to keep the serialized results small — confirm.
            for node in results.nodes:
                node.name_embedding = None
            for edge in results.edges:
                edge.fact_embedding = None

            add_episode_results.append(results)
            add_episode_context.append(msg['content'])

    return user_id, add_episode_results, add_episode_context
async def build_graph( async def build_graph(
group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti group_id_suffix: str, multi_session_count: int, session_length: int, graphiti: Graphiti
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]: ) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
# Get longmemeval dataset # Get longmemeval dataset
lme_dataset_option = ( lme_dataset_option = (
@ -40,51 +81,35 @@ async def build_graph(
add_episode_results: dict[str, list[AddEpisodeResults]] = {} add_episode_results: dict[str, list[AddEpisodeResults]] = {}
add_episode_context: dict[str, list[str]] = {} add_episode_context: dict[str, list[str]] = {}
for multi_session_idx in multi_session: subgraph_results: list[tuple[str, list[AddEpisodeResults], list[str]]] = await semaphore_gather(
multi_session = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx] *[
multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx] build_subgraph(
graphiti,
user_id='lme_oracle_experiment_user_' + str(multi_session_idx),
multi_session=lme_dataset_df['haystack_sessions'].iloc[multi_session_idx],
multi_session_dates=lme_dataset_df['haystack_dates'].iloc[multi_session_idx],
session_length=session_length,
group_id_suffix=group_id_suffix,
)
for multi_session_idx in range(multi_session_count)
]
)
user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx) for user_id, episode_results, episode_context in subgraph_results:
await clear_data(graphiti.driver, [user_id]) add_episode_results[user_id] = episode_results
add_episode_context[user_id] = episode_context
add_episode_results[user_id] = []
add_episode_context[user_id] = []
message_count = 0
for session_idx, session in enumerate(multi_session):
for _, msg in enumerate(session):
if message_count >= session_length:
continue
message_count += 1
date = multi_session_dates[session_idx] + ' UTC'
date_format = '%Y/%m/%d (%a) %H:%M UTC'
date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)
episode_body = f'{msg["role"]}: {msg["content"]}'
results = await graphiti.add_episode(
name='',
episode_body=episode_body,
reference_time=date_string,
source=EpisodeType.message,
source_description='',
group_id=user_id + '_' + group_id_suffix,
)
for node in results.nodes:
node.name_embedding = None
for edge in results.edges:
edge.fact_embedding = None
add_episode_results[user_id].append(results)
add_episode_context[user_id].append(msg['content'])
return add_episode_results, add_episode_context return add_episode_results, add_episode_context
async def build_baseline_graph(multi_session: list[int], session_length: int): async def build_baseline_graph(multi_session_count: int, session_length: int):
# Use gpt-4o for graph building baseline # Use gpt-4.1-mini for graph building baseline
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o')) llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti) add_episode_results, _ = await build_graph(
'baseline', multi_session_count, session_length, graphiti
)
filename = 'baseline_graph_results.json' filename = 'baseline_graph_results.json'
@ -97,7 +122,7 @@ async def build_baseline_graph(multi_session: list[int], session_length: int):
json.dump(serializable_baseline_graph_results, file, indent=4, default=str) json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float: async def eval_graph(multi_session_count: int, session_length: int, llm_client=None) -> float:
if llm_client is None: if llm_client is None:
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini')) llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
@ -109,7 +134,7 @@ async def eval_graph(multi_session: list[int], session_length: int, llm_client=N
for key, value in baseline_results_raw.items() for key, value in baseline_results_raw.items()
} }
add_episode_results, add_episode_context = await build_graph( add_episode_results, add_episode_context = await build_graph(
'candidate', multi_session, session_length, graphiti 'candidate', multi_session_count, session_length, graphiti
) )
filename = 'candidate_graph_results.json' filename = 'candidate_graph_results.json'