Eval updates (#404)
* update eval
* make format
* remove unused imports
* mypy
parent 0b94e0e603
commit 7ee4e38616
3 changed files with 78 additions and 52 deletions
@@ -134,14 +134,14 @@ async def retrieve_episodes(
         list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
     """
     group_id_filter: LiteralString = (
-        'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
+        '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
     )
-    source_filter: LiteralString = 'AND e.source = $source' if source is not None else ''
+    source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
 
     query: LiteralString = (
         """
-            MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
-            """
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
+        """
         + group_id_filter
         + source_filter
         + """
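The two filter fragments above are concatenated onto the base Cypher string, and the leading '\n' keeps each optional clause on its own line of the rendered query. A minimal standalone sketch of that assembly (the example values and the abbreviated RETURN clause are assumptions, not part of the diff):

# Illustrative only: shows how the optional filters compose into the final Cypher text.
group_ids = ['user_1_baseline']  # assumed example value
source = 'message'               # assumed example value

group_id_filter = '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
source_filter = '\nAND e.source = $source' if source is not None else ''

query = (
    'MATCH (e:Episodic) WHERE e.valid_at <= $reference_time'
    + group_id_filter
    + source_filter
    + '\nRETURN e'  # the real query returns episode fields; abbreviated here
)
print(query)
# MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
# AND e.group_id IN $group_ids
# AND e.source = $source
# RETURN e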
@@ -161,7 +161,7 @@ async def retrieve_episodes(
     result = await driver.execute_query(
         query,
         reference_time=reference_time,
-        source=source,
+        source=source.name if source is not None else None,
         num_episodes=last_n,
         group_ids=group_ids,
         database_=DEFAULT_DATABASE,
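Here `source` is an `EpisodeType` enum member, and the Neo4j driver cannot serialize an arbitrary Python enum as a query parameter, so the parameter is now the member's `.name` string rather than the enum object. A hedged sketch with a simplified stand-in enum (member names are assumed):

# Simplified stand-in for graphiti_core.nodes.EpisodeType; illustration only.
from enum import Enum


class EpisodeType(Enum):
    message = 'message'
    json = 'json'
    text = 'text'


source = EpisodeType.message
# Mirrors the changed line: pass a plain string (or None) as the query parameter.
params = {'source': source.name if source is not None else None}
print(params)  # {'source': 'message'}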
@@ -10,11 +10,10 @@ async def main():
     )
 
     parser.add_argument(
-        '--multi-session',
+        '--multi-session-count',
         type=int,
-        nargs='+',
         required=True,
-        help='List of integers representing multi-session values (e.g., 1 2 3)',
+        help='Integer representing multi-session count',
     )
     parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
     parser.add_argument(
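The CLI now takes a single integer count instead of a list of session indices. A standalone sketch of the new flags (the surrounding script wiring is assumed):

# Only the flag definitions come from the diff; the rest is scaffolding for the example.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--multi-session-count',
    type=int,
    required=True,
    help='Integer representing multi-session count',
)
parser.add_argument('--session-length', type=int, required=True, help='Length of each session')

# e.g. `--multi-session-count 3 --session-length 10`
args = parser.parse_args(['--multi-session-count', '3', '--session-length', '10'])
print(args.multi_session_count, args.session_length)  # 3 10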
@@ -27,11 +26,13 @@ async def main():
     if args.build_baseline:
         print('Running build_baseline_graph...')
         await build_baseline_graph(
-            multi_session=args.multi_session, session_length=args.session_length
+            multi_session_count=args.multi_session_count, session_length=args.session_length
         )
 
     # Always call eval_graph
-    result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length)
+    result = await eval_graph(
+        multi_session_count=args.multi_session_count, session_length=args.session_length
+    )
     print('Result of eval_graph:', result)
 
 
@@ -21,16 +21,57 @@ import pandas as pd
 
 from graphiti_core import Graphiti
 from graphiti_core.graphiti import AddEpisodeResults
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMConfig, OpenAIClient
 from graphiti_core.nodes import EpisodeType
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.eval import EvalAddEpisodeResults
-from graphiti_core.utils.maintenance import clear_data
 from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
 
 
+async def build_subgraph(
+    graphiti: Graphiti,
+    user_id: str,
+    multi_session,
+    multi_session_dates,
+    session_length: int,
+    group_id_suffix: str,
+) -> tuple[str, list[AddEpisodeResults], list[str]]:
+    add_episode_results: list[AddEpisodeResults] = []
+    add_episode_context: list[str] = []
+
+    message_count = 0
+    for session_idx, session in enumerate(multi_session):
+        for _, msg in enumerate(session):
+            if message_count >= session_length:
+                continue
+            message_count += 1
+            date = multi_session_dates[session_idx] + ' UTC'
+            date_format = '%Y/%m/%d (%a) %H:%M UTC'
+            date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)
+
+            episode_body = f'{msg["role"]}: {msg["content"]}'
+            results = await graphiti.add_episode(
+                name='',
+                episode_body=episode_body,
+                reference_time=date_string,
+                source=EpisodeType.message,
+                source_description='',
+                group_id=user_id + '_' + group_id_suffix,
+            )
+            for node in results.nodes:
+                node.name_embedding = None
+            for edge in results.edges:
+                edge.fact_embedding = None
+
+            add_episode_results.append(results)
+            add_episode_context.append(msg['content'])
+
+    return user_id, add_episode_results, add_episode_context
+
+
 async def build_graph(
-    group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti
+    group_id_suffix: str, multi_session_count: int, session_length: int, graphiti: Graphiti
 ) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
     # Get longmemeval dataset
     lme_dataset_option = (
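In the new build_subgraph helper, message_count is shared across all of a user's sessions, so at most session_length messages are ingested per user and any further messages are skipped. A toy sketch of that capping logic (the data is invented):

# Toy data standing in for LongMemEval haystack sessions; only the cap logic mirrors the diff.
multi_session = [['m1', 'm2'], ['m3', 'm4', 'm5']]
session_length = 3

ingested = []
message_count = 0
for session in multi_session:
    for msg in session:
        if message_count >= session_length:
            continue  # skip once the cap is reached, as in build_subgraph
        message_count += 1
        ingested.append(msg)

print(ingested)  # ['m1', 'm2', 'm3']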
@@ -40,51 +81,35 @@ async def build_graph(
 
     add_episode_results: dict[str, list[AddEpisodeResults]] = {}
     add_episode_context: dict[str, list[str]] = {}
-    for multi_session_idx in multi_session:
-        multi_session = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
-        multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]
-
-        user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
-        await clear_data(graphiti.driver, [user_id])
-
-        add_episode_results[user_id] = []
-        add_episode_context[user_id] = []
-
-        message_count = 0
-        for session_idx, session in enumerate(multi_session):
-            for _, msg in enumerate(session):
-                if message_count >= session_length:
-                    continue
-                message_count += 1
-                date = multi_session_dates[session_idx] + ' UTC'
-                date_format = '%Y/%m/%d (%a) %H:%M UTC'
-                date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)
-
-                episode_body = f'{msg["role"]}: {msg["content"]}'
-                results = await graphiti.add_episode(
-                    name='',
-                    episode_body=episode_body,
-                    reference_time=date_string,
-                    source=EpisodeType.message,
-                    source_description='',
-                    group_id=user_id + '_' + group_id_suffix,
-                )
-                for node in results.nodes:
-                    node.name_embedding = None
-                for edge in results.edges:
-                    edge.fact_embedding = None
-
-                add_episode_results[user_id].append(results)
-                add_episode_context[user_id].append(msg['content'])
+    subgraph_results: list[tuple[str, list[AddEpisodeResults], list[str]]] = await semaphore_gather(
+        *[
+            build_subgraph(
+                graphiti,
+                user_id='lme_oracle_experiment_user_' + str(multi_session_idx),
+                multi_session=lme_dataset_df['haystack_sessions'].iloc[multi_session_idx],
+                multi_session_dates=lme_dataset_df['haystack_dates'].iloc[multi_session_idx],
+                session_length=session_length,
+                group_id_suffix=group_id_suffix,
+            )
+            for multi_session_idx in range(multi_session_count)
+        ]
+    )
+
+    for user_id, episode_results, episode_context in subgraph_results:
+        add_episode_results[user_id] = episode_results
+        add_episode_context[user_id] = episode_context
+
     return add_episode_results, add_episode_context
 
 
-async def build_baseline_graph(multi_session: list[int], session_length: int):
-    # Use gpt-4o for graph building baseline
-    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))
+async def build_baseline_graph(multi_session_count: int, session_length: int):
+    # Use gpt-4.1-mini for graph building baseline
+    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
 
-    add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti)
+    add_episode_results, _ = await build_graph(
+        'baseline', multi_session_count, session_length, graphiti
+    )
 
     filename = 'baseline_graph_results.json'
 
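build_graph now fans the per-user work out concurrently through semaphore_gather instead of looping sequentially. The sketch below illustrates the general bounded-concurrency pattern only; it is not graphiti_core.helpers.semaphore_gather itself, and the limit value is an assumption:

# Generic bounded-gather pattern; illustration only, not graphiti's implementation.
import asyncio


async def bounded_gather(*coros, limit: int = 10):
    semaphore = asyncio.Semaphore(limit)

    async def _run(coro):
        async with semaphore:  # at most `limit` coroutines are in flight at once
            return await coro

    return await asyncio.gather(*(_run(c) for c in coros))


async def demo():
    async def work(i: int) -> int:
        await asyncio.sleep(0.01)
        return i

    results = await bounded_gather(*[work(i) for i in range(5)], limit=2)
    print(results)  # [0, 1, 2, 3, 4] -- gather preserves input order


asyncio.run(demo())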
@@ -97,7 +122,7 @@ async def build_baseline_graph(multi_session: list[int], session_length: int):
         json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
 
 
-async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float:
+async def eval_graph(multi_session_count: int, session_length: int, llm_client=None) -> float:
     if llm_client is None:
         llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
@@ -109,7 +134,7 @@ async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float:
         for key, value in baseline_results_raw.items()
     }
     add_episode_results, add_episode_context = await build_graph(
-        'candidate', multi_session, session_length, graphiti
+        'candidate', multi_session_count, session_length, graphiti
     )
 
     filename = 'candidate_graph_results.json'