graphiti/tests/evals/eval_extract_nodes.py
prestonrasmussen 948a0057fb dedupe updates
2025-04-07 11:23:20 -04:00

143 lines
4.6 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import csv # Add this import at the top of the file
import json
import os
from datetime import datetime, timedelta
import pytest
from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.llm_client import OpenAIClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.node_operations import extract_nodes
from tests.evals.utils import ingest_snippet, setup_logging
############# EVERYTHING BELOW IS OUTDATED
# Setup
# python-dotenv: load variables from a .env file (if one is found) into the
# process environment. test_extract_nodes reads OPENAI_API_KEY via os.getenv.
load_dotenv()
# Module-level pytest marker: tags every test in this file as `integration`.
pytestmark = pytest.mark.integration
# Ask pytest to load the pytest-asyncio plugin; the @pytest.mark.asyncio test
# in this module depends on it.
pytest_plugins = ('pytest_asyncio',)
# Object returned by tests.evals.utils.setup_logging(); presumably a configured
# logger -- verify against that helper.
logger = setup_logging()
async def general_extract_nodes_test(llm_client, data_sample):
    """Extract nodes for one data sample and return their names in sorted order.

    Args:
        llm_client: Client forwarded unchanged to ``extract_nodes``.
        data_sample: Mapping providing the keys ``'episode'``,
            ``'previous_episodes'`` and ``'gold_answer_names'``.

    Returns:
        The ``name`` attribute of every node returned by ``extract_nodes``,
        sorted ascending.

    Note:
        The gold-answer list held by ``data_sample`` is sorted in place, so the
        caller's list is reordered. The strict equality check between the
        extracted names and the gold answers is currently disabled; nothing is
        asserted here.
    """
    # Pull all three entries up front (in this order) so a missing key fails
    # before any LLM call is made.
    episode, previous_episodes, gold_answer_names = (
        data_sample[key]
        for key in ('episode', 'previous_episodes', 'gold_answer_names')
    )

    extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)

    # Both name lists are ordered by name so they can be compared positionally.
    sorted_names = sorted(node.name for node in extracted_nodes)
    gold_answer_names.sort()

    return sorted_names
def prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx):
    """Build node-extraction samples from ``tests/evals/data/<data_file_name>.csv``.

    Each CSV row becomes one sample. The ``message`` column must hold a JSON
    object with ``role`` and ``content`` keys; the ``previous_messages`` column
    must hold a JSON list of such objects. Previous messages receive
    timestamps one minute apart, all earlier than the current message, with
    earlier list entries getting older timestamps.

    Args:
        data_file_name: CSV file name without the ``.csv`` extension. The path
            is resolved relative to the current working directory.
        question_id: Accepted for interface compatibility; not used.
        session_idx: Accepted for interface compatibility; not used.
        message_idx: Accepted for interface compatibility; not used.

    Returns:
        A list with one dict per CSV row, each containing:
        ``'episode'`` (EpisodicNode for the current message),
        ``'previous_episodes'`` (list of EpisodicNode) and
        ``'gold_answer_names'`` (always an empty list for now).

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        KeyError: If a required column or JSON key is missing.
        json.JSONDecodeError: If a column does not contain valid JSON.
    """
    samples_csv_path = 'tests/evals/data/' + data_file_name + '.csv'

    # newline='' hands newline processing to the csv module, as its
    # documentation requires, so line breaks inside quoted fields are kept
    # intact. The explicit encoding makes parsing independent of the
    # platform's default locale.
    with open(samples_csv_path, newline='', encoding='utf-8') as file:
        rows = list(csv.DictReader(file))

    data_samples = []
    for row in rows:
        current_time = datetime.now()

        # Current episode: "<role>: <content>" of the row's message.
        message = json.loads(row['message'])
        episode = EpisodicNode(
            name='',
            group_id='',
            source=EpisodeType.message,
            # NOTE(review): the previous episodes below are built without
            # `type`; confirm whether EpisodicNode actually uses this argument.
            type=EpisodeType.message,
            source_description='',
            content=message['role'] + ': ' + message['content'],
            valid_at=current_time,
        )

        # Previous episodes: entry `index` of N is stamped (N - index) minutes
        # before the current message, so the last one is one minute earlier.
        previous_messages = json.loads(row['previous_messages'])
        num_previous_messages = len(previous_messages)
        previous_episodes = [
            EpisodicNode(
                name='',
                group_id='',
                source=EpisodeType.message,
                source_description='',
                content=previous_message['role'] + ': ' + previous_message['content'],
                valid_at=current_time - timedelta(minutes=num_previous_messages - index),
            )
            for index, previous_message in enumerate(previous_messages)
        ]

        # TODO: Prepare gold answer names (left empty until then).
        data_samples.append(
            {
                'episode': episode,
                'previous_episodes': previous_episodes,
                'gold_answer_names': [],
            }
        )

    return data_samples
@pytest.mark.asyncio
async def test_extract_nodes():
    """Run node extraction over every sample of the 'output_short' data file.

    For each sample the episode and its previous episodes are printed, then
    ``general_extract_nodes_test`` is awaited. The returned node names are not
    checked, so the test only fails when one of these steps raises.
    """
    llm_client = OpenAIClient(
        config=LLMConfig(
            api_key=os.getenv('OPENAI_API_KEY'),
            model='gpt-4o-mini',
        )
    )

    # question_id / session_idx / message_idx are still supplied even though
    # prepare_data_from_csv does not currently use them.
    data_samples = prepare_data_from_csv(
        data_file_name='output_short',
        question_id='gpt4_2655b836',
        session_idx=0,
        message_idx=0,
    )

    divider = '*' * 50
    for data_sample in data_samples:
        print(f"\n\nEpisode: {data_sample['episode']}")
        print(divider)
        print(f"Previous Episodes: {data_sample['previous_episodes']}")
        print(divider)
        await general_extract_nodes_test(llm_client, data_sample)