143 lines
4.6 KiB
Python
143 lines
4.6 KiB
Python
"""
|
|
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
|
|
"""
|
|
|
|
import csv # Add this import at the top of the file
|
|
import json
|
|
import os
|
|
from datetime import datetime, timedelta
|
|
|
|
import pytest
|
|
from dotenv import load_dotenv
|
|
|
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
from graphiti_core.graphiti import Graphiti
|
|
from graphiti_core.llm_client import OpenAIClient
|
|
from graphiti_core.llm_client.config import LLMConfig
|
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
from graphiti_core.utils.maintenance.node_operations import extract_nodes
|
|
from tests.evals.utils import ingest_snippet, setup_logging
|
|
|
|
############# EVERYTHING BELOW IS OUTDATED
|
|
|
|
# Module-level test configuration.

# Runs before anything else in this section; presumably populates os.environ
# (e.g. OPENAI_API_KEY) from a local .env file -- confirm against python-dotenv.
load_dotenv()

# pytest applies a module-level `pytestmark` to every test in this module.
pytestmark = pytest.mark.integration

# Plugin requested for the `@pytest.mark.asyncio` coroutine tests below.
pytest_plugins = ('pytest_asyncio',)

# Logger obtained from the shared eval helper in tests.evals.utils.
logger = setup_logging()
|
|
|
|
|
|
async def general_extract_nodes_test(llm_client, data_sample):
    """Run node extraction for one data sample and return the extracted names.

    Args:
        llm_client: LLM client, forwarded unchanged to ``extract_nodes``.
        data_sample: Mapping with the keys ``'episode'``,
            ``'previous_episodes'`` and ``'gold_answer_names'``.

    Returns:
        The ``name`` of every node returned by ``extract_nodes``, sorted in
        ascending order.

    Raises:
        KeyError: If ``data_sample`` lacks one of the keys listed above.
    """
    episode = data_sample['episode']
    previous_episodes = data_sample['previous_episodes']
    # sorted() builds a new list, so the list stored in the caller's
    # data_sample is never reordered as a side effect of running the test.
    gold_answer_names = sorted(data_sample['gold_answer_names'])

    hypothesis_nodes = await extract_nodes(llm_client, episode, previous_episodes)
    hypothesis_node_names = sorted(node.name for node in hypothesis_nodes)

    # NOTE(review): the comparison against the gold answers is disabled in the
    # original; presumably it waits on gold answers being populated (the CSV
    # loader in this file still returns an empty list for them) -- confirm
    # before re-enabling.
    # assert hypothesis_node_names == gold_answer_names, \
    #     f"""Test Failed. Expected nodes: {gold_answer_names}. Got: {hypothesis_node_names}"""

    return hypothesis_node_names
|
|
|
|
|
|
def prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx):
    """Build extraction data samples from ``tests/evals/data/<data_file_name>.csv``.

    Each CSV row is read through two columns: ``message`` (a JSON object with
    ``role`` and ``content`` keys) and ``previous_messages`` (a JSON list of
    such objects).

    Args:
        data_file_name: CSV file name without the ``.csv`` extension.
        question_id: Currently unused; every row of the file is loaded.
        session_idx: Currently unused.
        message_idx: Currently unused.

    Returns:
        A list with one dict per CSV row, holding ``'episode'`` (an
        ``EpisodicNode`` for the row's message), ``'previous_episodes'``
        (a list of ``EpisodicNode``, one per previous message) and
        ``'gold_answer_names'`` (always an empty list for now).

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        KeyError: If a row or a message lacks an expected column/key.
        json.JSONDecodeError: If a JSON column cannot be parsed.
    """
    samples_csv_path = 'tests/evals/data/' + data_file_name + '.csv'

    # newline='' is what the csv module requires so that quoted fields with
    # embedded line breaks (possible inside the JSON columns) are parsed as a
    # single field; the explicit encoding keeps decoding platform-independent.
    with open(samples_csv_path, 'r', newline='', encoding='utf-8') as file:
        lme_samples = list(csv.DictReader(file))

    data_samples = []

    for row in lme_samples:
        # --- Current episode ---
        current_time = datetime.now()
        message = json.loads(row['message'])
        # NOTE(review): `type=` is passed only for the current episode, not for
        # the previous ones below -- confirm EpisodicNode actually uses it.
        episode = EpisodicNode(
            name='',
            group_id='',
            source=EpisodeType.message,
            type=EpisodeType.message,
            source_description='',
            content=message['role'] + ': ' + message['content'],
            valid_at=current_time,
        )

        # --- Previous episodes ---
        # Timestamps are spaced one minute apart and all precede current_time:
        # the first previous message is the oldest, the last one is exactly
        # one minute before the current episode.
        previous_messages = json.loads(row['previous_messages'])
        num_previous_messages = len(previous_messages)
        previous_episodes = [
            EpisodicNode(
                name='',
                group_id='',
                source=EpisodeType.message,
                source_description='',
                content=previous_message['role'] + ': ' + previous_message['content'],
                valid_at=current_time - timedelta(minutes=num_previous_messages - i),
            )
            for i, previous_message in enumerate(previous_messages)
        ]

        # TODO: Prepare gold answer names (an empty placeholder until then).
        data_samples.append(
            {
                'episode': episode,
                'previous_episodes': previous_episodes,
                'gold_answer_names': [],
            }
        )

    return data_samples
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extract_nodes():
    """Run node extraction over every sample loaded from ``output_short.csv``.

    Builds an ``OpenAIClient`` for ``gpt-4o-mini`` using the API key read from
    the ``OPENAI_API_KEY`` environment variable, prints each sample's episode
    and previous episodes, then runs ``general_extract_nodes_test`` on it.
    Nothing is asserted on the extraction result.
    """
    llm_client = OpenAIClient(
        config=LLMConfig(
            api_key=os.getenv('OPENAI_API_KEY'),
            model='gpt-4o-mini',
        )
    )

    data_samples = prepare_data_from_csv(
        data_file_name='output_short',
        question_id='gpt4_2655b836',
        session_idx=0,
        message_idx=0,
    )

    separator = '*' * 50
    for data_sample in data_samples:
        # Echo the inputs so the extraction run can be inspected by hand.
        print(f"\n\nEpisode: {data_sample['episode']}")
        print(separator)
        print(f"Previous Episodes: {data_sample['previous_episodes']}")
        print(separator)

        await general_extract_nodes_test(llm_client, data_sample)
|