Eval updates (#404)
* update eval
* make format
* remove unused imports
* mypy
This commit is contained in:
parent
0b94e0e603
commit
7ee4e38616
3 changed files with 78 additions and 52 deletions
|
|
@ -134,14 +134,14 @@ async def retrieve_episodes(
|
||||||
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
||||||
"""
|
"""
|
||||||
group_id_filter: LiteralString = (
|
group_id_filter: LiteralString = (
|
||||||
'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
|
'\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
|
||||||
)
|
)
|
||||||
source_filter: LiteralString = 'AND e.source = $source' if source is not None else ''
|
source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
||||||
"""
|
"""
|
||||||
+ group_id_filter
|
+ group_id_filter
|
||||||
+ source_filter
|
+ source_filter
|
||||||
+ """
|
+ """
|
||||||
|
|
@ -161,7 +161,7 @@ async def retrieve_episodes(
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
reference_time=reference_time,
|
reference_time=reference_time,
|
||||||
source=source,
|
source=source.name if source is not None else None,
|
||||||
num_episodes=last_n,
|
num_episodes=last_n,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,10 @@ async def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--multi-session',
|
'--multi-session-count',
|
||||||
type=int,
|
type=int,
|
||||||
nargs='+',
|
|
||||||
required=True,
|
required=True,
|
||||||
help='List of integers representing multi-session values (e.g., 1 2 3)',
|
help='Integer representing multi-session count',
|
||||||
)
|
)
|
||||||
parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
|
parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -27,11 +26,13 @@ async def main():
|
||||||
if args.build_baseline:
|
if args.build_baseline:
|
||||||
print('Running build_baseline_graph...')
|
print('Running build_baseline_graph...')
|
||||||
await build_baseline_graph(
|
await build_baseline_graph(
|
||||||
multi_session=args.multi_session, session_length=args.session_length
|
multi_session_count=args.multi_session_count, session_length=args.session_length
|
||||||
)
|
)
|
||||||
|
|
||||||
# Always call eval_graph
|
# Always call eval_graph
|
||||||
result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length)
|
result = await eval_graph(
|
||||||
|
multi_session_count=args.multi_session_count, session_length=args.session_length
|
||||||
|
)
|
||||||
print('Result of eval_graph:', result)
|
print('Result of eval_graph:', result)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,16 +21,57 @@ import pandas as pd
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
from graphiti_core.graphiti import AddEpisodeResults
|
from graphiti_core.graphiti import AddEpisodeResults
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||||
from graphiti_core.nodes import EpisodeType
|
from graphiti_core.nodes import EpisodeType
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.eval import EvalAddEpisodeResults
|
from graphiti_core.prompts.eval import EvalAddEpisodeResults
|
||||||
from graphiti_core.utils.maintenance import clear_data
|
|
||||||
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
|
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
|
||||||
|
|
||||||
|
|
||||||
|
async def build_subgraph(
|
||||||
|
graphiti: Graphiti,
|
||||||
|
user_id: str,
|
||||||
|
multi_session,
|
||||||
|
multi_session_dates,
|
||||||
|
session_length: int,
|
||||||
|
group_id_suffix: str,
|
||||||
|
) -> tuple[str, list[AddEpisodeResults], list[str]]:
|
||||||
|
add_episode_results: list[AddEpisodeResults] = []
|
||||||
|
add_episode_context: list[str] = []
|
||||||
|
|
||||||
|
message_count = 0
|
||||||
|
for session_idx, session in enumerate(multi_session):
|
||||||
|
for _, msg in enumerate(session):
|
||||||
|
if message_count >= session_length:
|
||||||
|
continue
|
||||||
|
message_count += 1
|
||||||
|
date = multi_session_dates[session_idx] + ' UTC'
|
||||||
|
date_format = '%Y/%m/%d (%a) %H:%M UTC'
|
||||||
|
date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
episode_body = f'{msg["role"]}: {msg["content"]}'
|
||||||
|
results = await graphiti.add_episode(
|
||||||
|
name='',
|
||||||
|
episode_body=episode_body,
|
||||||
|
reference_time=date_string,
|
||||||
|
source=EpisodeType.message,
|
||||||
|
source_description='',
|
||||||
|
group_id=user_id + '_' + group_id_suffix,
|
||||||
|
)
|
||||||
|
for node in results.nodes:
|
||||||
|
node.name_embedding = None
|
||||||
|
for edge in results.edges:
|
||||||
|
edge.fact_embedding = None
|
||||||
|
|
||||||
|
add_episode_results.append(results)
|
||||||
|
add_episode_context.append(msg['content'])
|
||||||
|
|
||||||
|
return user_id, add_episode_results, add_episode_context
|
||||||
|
|
||||||
|
|
||||||
async def build_graph(
|
async def build_graph(
|
||||||
group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti
|
group_id_suffix: str, multi_session_count: int, session_length: int, graphiti: Graphiti
|
||||||
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
|
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
|
||||||
# Get longmemeval dataset
|
# Get longmemeval dataset
|
||||||
lme_dataset_option = (
|
lme_dataset_option = (
|
||||||
|
|
@ -40,51 +81,35 @@ async def build_graph(
|
||||||
|
|
||||||
add_episode_results: dict[str, list[AddEpisodeResults]] = {}
|
add_episode_results: dict[str, list[AddEpisodeResults]] = {}
|
||||||
add_episode_context: dict[str, list[str]] = {}
|
add_episode_context: dict[str, list[str]] = {}
|
||||||
for multi_session_idx in multi_session:
|
subgraph_results: list[tuple[str, list[AddEpisodeResults], list[str]]] = await semaphore_gather(
|
||||||
multi_session = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
|
*[
|
||||||
multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]
|
build_subgraph(
|
||||||
|
graphiti,
|
||||||
|
user_id='lme_oracle_experiment_user_' + str(multi_session_idx),
|
||||||
|
multi_session=lme_dataset_df['haystack_sessions'].iloc[multi_session_idx],
|
||||||
|
multi_session_dates=lme_dataset_df['haystack_dates'].iloc[multi_session_idx],
|
||||||
|
session_length=session_length,
|
||||||
|
group_id_suffix=group_id_suffix,
|
||||||
|
)
|
||||||
|
for multi_session_idx in range(multi_session_count)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
|
for user_id, episode_results, episode_context in subgraph_results:
|
||||||
await clear_data(graphiti.driver, [user_id])
|
add_episode_results[user_id] = episode_results
|
||||||
|
add_episode_context[user_id] = episode_context
|
||||||
|
|
||||||
add_episode_results[user_id] = []
|
|
||||||
add_episode_context[user_id] = []
|
|
||||||
|
|
||||||
message_count = 0
|
|
||||||
for session_idx, session in enumerate(multi_session):
|
|
||||||
for _, msg in enumerate(session):
|
|
||||||
if message_count >= session_length:
|
|
||||||
continue
|
|
||||||
message_count += 1
|
|
||||||
date = multi_session_dates[session_idx] + ' UTC'
|
|
||||||
date_format = '%Y/%m/%d (%a) %H:%M UTC'
|
|
||||||
date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)
|
|
||||||
|
|
||||||
episode_body = f'{msg["role"]}: {msg["content"]}'
|
|
||||||
results = await graphiti.add_episode(
|
|
||||||
name='',
|
|
||||||
episode_body=episode_body,
|
|
||||||
reference_time=date_string,
|
|
||||||
source=EpisodeType.message,
|
|
||||||
source_description='',
|
|
||||||
group_id=user_id + '_' + group_id_suffix,
|
|
||||||
)
|
|
||||||
for node in results.nodes:
|
|
||||||
node.name_embedding = None
|
|
||||||
for edge in results.edges:
|
|
||||||
edge.fact_embedding = None
|
|
||||||
|
|
||||||
add_episode_results[user_id].append(results)
|
|
||||||
add_episode_context[user_id].append(msg['content'])
|
|
||||||
return add_episode_results, add_episode_context
|
return add_episode_results, add_episode_context
|
||||||
|
|
||||||
|
|
||||||
async def build_baseline_graph(multi_session: list[int], session_length: int):
|
async def build_baseline_graph(multi_session_count: int, session_length: int):
|
||||||
# Use gpt-4o for graph building baseline
|
# Use gpt-4.1-mini for graph building baseline
|
||||||
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))
|
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
|
||||||
|
|
||||||
add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti)
|
add_episode_results, _ = await build_graph(
|
||||||
|
'baseline', multi_session_count, session_length, graphiti
|
||||||
|
)
|
||||||
|
|
||||||
filename = 'baseline_graph_results.json'
|
filename = 'baseline_graph_results.json'
|
||||||
|
|
||||||
|
|
@ -97,7 +122,7 @@ async def build_baseline_graph(multi_session: list[int], session_length: int):
|
||||||
json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
|
json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
|
||||||
|
|
||||||
|
|
||||||
async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float:
|
async def eval_graph(multi_session_count: int, session_length: int, llm_client=None) -> float:
|
||||||
if llm_client is None:
|
if llm_client is None:
|
||||||
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
|
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
|
||||||
|
|
@ -109,7 +134,7 @@ async def eval_graph(multi_session: list[int], session_length: int, llm_client=N
|
||||||
for key, value in baseline_results_raw.items()
|
for key, value in baseline_results_raw.items()
|
||||||
}
|
}
|
||||||
add_episode_results, add_episode_context = await build_graph(
|
add_episode_results, add_episode_context = await build_graph(
|
||||||
'candidate', multi_session, session_length, graphiti
|
'candidate', multi_session_count, session_length, graphiti
|
||||||
)
|
)
|
||||||
|
|
||||||
filename = 'candidate_graph_results.json'
|
filename = 'candidate_graph_results.json'
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue