dedupe updates
Parent: a95c046cbb
Commit: 948a0057fb
4 changed files with 2773 additions and 3364 deletions
Diffs for two of the changed files are suppressed because they are too large.
@@ -1,7 +1,10 @@
-import pandas as pd
-from datetime import datetime, timedelta
 import json
-from graphiti_core.nodes import EpisodicNode, EpisodeType, EntityNode
+from datetime import datetime, timedelta
+
+import pandas as pd
+
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+
 
 def create_episodes_from_messages(input_message, input_previous_messages):
     """
@@ -11,28 +14,31 @@ def create_episodes_from_messages(input_message, input_previous_messages):
     current_time = datetime.now()
 
     # Create the current episode
-    role = input_message["role"]
-    content = input_message["content"]
-    message_content = f"{role}: {content}"
+    role = input_message['role']
+    content = input_message['content']
+    message_content = f'{role}: {content}'
     episode = EpisodicNode(
-        name="",
-        group_id="",
+        name='',
+        group_id='',
         source=EpisodeType.message,
         type=EpisodeType.message,
-        source_description="",
+        source_description='',
         content=message_content,
         valid_at=current_time,
     )
 
     # Create previous episodes
     num_previous_messages = len(input_previous_messages)
-    previous_times = [current_time - timedelta(minutes=num_previous_messages - i) for i in range(num_previous_messages)]
+    previous_times = [
+        current_time - timedelta(minutes=num_previous_messages - i)
+        for i in range(num_previous_messages)
+    ]
     previous_episodes = [
         EpisodicNode(
-            name="",
-            group_id="",
+            name='',
+            group_id='',
             source=EpisodeType.message,
-            source_description="",
+            source_description='',
             content=f"{message['role']}: {message['content']}",
             valid_at=previous_time,
         )
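A quick illustration of the minute-offset timestamps built above (a sketch, not part of the commit): with three previous messages, they are stamped 3, 2, and 1 minutes before current_time, so earlier messages receive earlier valid_at values.

from datetime import datetime, timedelta

current_time = datetime(2024, 8, 1, 12, 0)
num_previous_messages = 3
previous_times = [
    current_time - timedelta(minutes=num_previous_messages - i)
    for i in range(num_previous_messages)
]
print([t.strftime('%H:%M') for t in previous_times])  # ['11:57', '11:58', '11:59']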
@@ -41,10 +47,14 @@ def create_episodes_from_messages(input_message, input_previous_messages):
 
     return episode, previous_episodes
 
 
 async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
     # Import necessary functions
-    from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
     from graphiti_core.utils.maintenance.edge_operations import extract_edges
+    from graphiti_core.utils.maintenance.node_operations import (
+        extract_nodes,
+        resolve_extracted_nodes,
+    )
 
     # Loop through each unique message_index_within_snippet in sorted order
     for message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
@@ -52,16 +62,24 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
 
         #### Process 'extract_nodes' task
         extract_nodes_row = message_df[message_df['task_name'] == 'extract_nodes']
-        assert len(extract_nodes_row) == 1, f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
+        assert (
+            len(extract_nodes_row) == 1
+        ), f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
         input_message = json.loads(extract_nodes_row.iloc[0]['input_message'])
         input_previous_messages = json.loads(extract_nodes_row.iloc[0]['input_previous_messages'])
-        episode, previous_episodes = create_episodes_from_messages(input_message, input_previous_messages)
+        episode, previous_episodes = create_episodes_from_messages(
+            input_message, input_previous_messages
+        )
         extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)
-        snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in extracted_nodes])
+        snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps(
+            [entity_to_dict(node) for node in extracted_nodes]
+        )
 
         #### Process 'dedupe_nodes' task
         dedupe_nodes_row = message_df[message_df['task_name'] == 'dedupe_nodes']
-        assert len(dedupe_nodes_row) == 1, "There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"
+        assert (
+            len(dedupe_nodes_row) == 1
+        ), f"There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"
 
         # Calculate existing nodes list
         existing_nodes = []
@@ -71,8 +89,8 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
 
             # Filter for previous messages with 'extract_nodes' task
             prev_message_df = snippet_df[
-                (snippet_df['message_index_within_snippet'] == prev_message_index) &
-                (snippet_df['task_name'] == 'extract_nodes')
+                (snippet_df['message_index_within_snippet'] == prev_message_index)
+                & (snippet_df['task_name'] == 'extract_nodes')
             ]
 
             # Retrieve and deserialize the nodes
@@ -82,12 +100,18 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
             existing_nodes.extend(nodes)
 
         existing_nodes_lists = [existing_nodes for _ in range(len(extracted_nodes))]
-        resolved_nodes, uuid_map = await resolve_extracted_nodes(llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes)
-        snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in resolved_nodes])
+        resolved_nodes, uuid_map = await resolve_extracted_nodes(
+            llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes
+        )
+        snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps(
+            [entity_to_dict(node) for node in resolved_nodes]
+        )
 
         #### Process 'extract_edges' task
         extract_edges_row = message_df[message_df['task_name'] == 'extract_edges']
-        assert len(extract_edges_row) == 1, f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
+        assert (
+            len(extract_edges_row) == 1
+        ), f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
         extracted_edges = await extract_edges(
             llm_client,
             episode,
@@ -95,7 +119,9 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
             previous_episodes,
             group_id='',
         )
-        snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps([entity_to_dict(edge) for edge in extracted_edges])
+        snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps(
+            [entity_to_dict(edge) for edge in extracted_edges]
+        )
 
         ########## TODO: Complete the implementation of the below
 
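For orientation, a sketch (not part of the commit, with an invented 'label' column standing in for output_column_name and invented entity fields): each labelled cell holds a JSON string of entity dicts, so a downstream consumer can read one result back out roughly like this.

import json

import pandas as pd

df = pd.DataFrame({'task_name': ['extract_nodes'], 'label': [None]}, dtype=object)
df.at[0, 'label'] = json.dumps([{'name': 'Alice', 'labels': ['Entity']}])  # stored as in the code above

node_dicts = json.loads(df.at[0, 'label'])
print([d['name'] for d in node_dicts])  # ['Alice']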
@@ -131,17 +157,20 @@ async def ingest_and_label_minidataset(llm_client, minidataset_df, output_column
     minidataset_labelled_df = None
     for snippet_index in sorted(minidataset_df['snippet_index'].unique()):
         snippet_df = minidataset_df[minidataset_df['snippet_index'] == snippet_index]
 
 
         # Pass the output column name to the ingest_and_label_snippet function
-        snippet_df_labelled = await ingest_and_label_snippet(llm_client, snippet_df, output_column_name)
+        snippet_df_labelled = await ingest_and_label_snippet(
+            llm_client, snippet_df, output_column_name
+        )
 
         if minidataset_labelled_df is None:
             minidataset_labelled_df = snippet_df_labelled
         else:
             minidataset_labelled_df = pd.concat([minidataset_labelled_df, snippet_df_labelled])
 
 
     return minidataset_labelled_df
 
+
 def entity_to_dict(entity):
     """
     Convert an entity object to a dictionary, handling datetime serialization.
@@ -152,6 +181,7 @@ def entity_to_dict(entity):
             entity_dict[key] = value.isoformat()  # Convert datetime to ISO 8601 string
     return entity_dict
 
+
 def dict_to_entity(entity_dict, entity_class):
     """
     Convert a dictionary back to an entity object, handling datetime deserialization.
@@ -163,4 +193,4 @@ def dict_to_entity(entity_dict, entity_class):
         except (ValueError, TypeError):
             # If parsing fails, keep the original value
             pass
-    return entity_class(**entity_dict)
+    return entity_class(**entity_dict)
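To make the two helpers above concrete, here is a minimal round-trip sketch. The function bodies are inferred from the fragments visible in this diff, and DummyEntity is a stand-in rather than a graphiti_core class; the real code passes EntityNode and EpisodicNode instances.

from dataclasses import dataclass
from datetime import datetime


@dataclass
class DummyEntity:
    # Stand-in entity with one datetime field, not a graphiti_core model.
    name: str
    valid_at: datetime


def entity_to_dict(entity):
    entity_dict = dict(vars(entity))
    for key, value in entity_dict.items():
        if isinstance(value, datetime):
            entity_dict[key] = value.isoformat()  # datetime -> ISO 8601 string
    return entity_dict


def dict_to_entity(entity_dict, entity_class):
    for key, value in entity_dict.items():
        if isinstance(value, str):
            try:
                entity_dict[key] = datetime.fromisoformat(value)  # ISO 8601 string -> datetime
            except (ValueError, TypeError):
                pass  # not a datetime string; keep the original value
    return entity_class(**entity_dict)


node = DummyEntity(name='Alice', valid_at=datetime(2024, 8, 1, 12, 0))
payload = entity_to_dict(node)              # JSON-safe dict
restored = dict_to_entity(payload, DummyEntity)
assert restored == node                      # round trip preserves the datetime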
@@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import os
+import csv  # Add this import at the top of the file
 import json
-from tests.evals.utils import setup_logging, ingest_snippet
+import os
 from datetime import datetime, timedelta
 
 import pytest
@@ -24,17 +24,11 @@ from dotenv import load_dotenv
 
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.graphiti import Graphiti
-from graphiti_core.nodes import EntityNode, EpisodicNode
-
-from graphiti_core.utils.maintenance.node_operations import extract_nodes
 from graphiti_core.llm_client import OpenAIClient
 from graphiti_core.llm_client.config import LLMConfig
-from graphiti_core.nodes import EpisodeType
-
-import csv  # Add this import at the top of the file
-
-
-
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.utils.maintenance.node_operations import extract_nodes
+from tests.evals.utils import ingest_snippet, setup_logging
 
 ############# EVERYTHING BELOW IS OUTDATED
 
@@ -63,68 +57,67 @@ async def general_extract_nodes_test(llm_client, data_sample):
     return hypothesis_node_names
 
 
-
-
-
 def prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx):
 
-    samples_csv_path = "tests/evals/data/" + data_file_name + ".csv"
+    samples_csv_path = 'tests/evals/data/' + data_file_name + '.csv'
 
     # From CSV path, load everything
     with open(samples_csv_path, 'r') as file:
         csv_reader = csv.DictReader(file)
         lme_samples = list(csv_reader)
 
 
     data_samples = []
 
     # Loop through each row
     for row in lme_samples:
 
         ### Prepare episode
         current_time = datetime.now()
-        message = json.loads(row["message"])
-        role = message["role"]
-        content = message["content"]
-        message_content = role + ": " + content
+        message = json.loads(row['message'])
+        role = message['role']
+        content = message['content']
+        message_content = role + ': ' + content
         episode = EpisodicNode(
-            name="",
-            group_id="",
+            name='',
+            group_id='',
             source=EpisodeType.message,
             type=EpisodeType.message,
-            source_description="",
+            source_description='',
             content=message_content,
-            valid_at=current_time,
+            valid_at=current_time,
         )
 
         ### Prepare previous episodes
-        previous_messages = json.loads(row["previous_messages"])
+        previous_messages = json.loads(row['previous_messages'])
         num_previous_messages = len(previous_messages)
-        previous_times = [current_time - timedelta(minutes=num_previous_messages-i) for i in range(num_previous_messages)]
-        previous_episodes = [EpisodicNode(
-            name="",
-            group_id="",
-            source=EpisodeType.message,
-            source_description="",
-            content=message["role"] + ": " + message["content"],
-            valid_at=previous_time,
-        ) for message, previous_time in zip(previous_messages, previous_times)]
+        previous_times = [
+            current_time - timedelta(minutes=num_previous_messages - i)
+            for i in range(num_previous_messages)
+        ]
+        previous_episodes = [
+            EpisodicNode(
+                name='',
+                group_id='',
+                source=EpisodeType.message,
+                source_description='',
+                content=message['role'] + ': ' + message['content'],
+                valid_at=previous_time,
+            )
+            for message, previous_time in zip(previous_messages, previous_times)
+        ]
 
         ### TODO: Prepare gold answer names
 
         ### Add to data samples list
-        data_samples.append({
-            "episode": episode,
-            "previous_episodes": previous_episodes,
-            "gold_answer_names": [],
-        })
+        data_samples.append(
+            {
+                'episode': episode,
+                'previous_episodes': previous_episodes,
+                'gold_answer_names': [],
+            }
+        )
 
     return data_samples
 
 
-
-
-
 @pytest.mark.asyncio
 async def test_extract_nodes():
     model_name = 'gpt-4o-mini'
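As context for prepare_data_from_csv above, a sketch of the row shape it expects (column names inferred from the code, sample values invented): 'message' and 'previous_messages' arrive as JSON-encoded strings and are decoded with json.loads.

import json

row = {
    'message': json.dumps({'role': 'user', 'content': 'I moved to Paris last year.'}),
    'previous_messages': json.dumps(
        [{'role': 'assistant', 'content': 'Where do you live now?'}]
    ),
}

message = json.loads(row['message'])
previous_messages = json.loads(row['previous_messages'])
print(message['role'], len(previous_messages))  # user 1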
@@ -135,17 +128,16 @@ async def test_extract_nodes():
     llm_client = OpenAIClient(config=llm_config)
 
     data_file_name = 'output_short'
-    question_id = "gpt4_2655b836"
+    question_id = 'gpt4_2655b836'
     session_idx = 0
     message_idx = 0
     data_samples = prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx)
 
     for data_sample in data_samples:
         print(f"\n\nEpisode: {data_sample['episode']}")
-        print("*"*50)
+        print('*' * 50)
         print(f"Previous Episodes: {data_sample['previous_episodes']}")
-        print("*"*50)
+        print('*' * 50)
         # print(f"Gold Answer Names: {gold_answer_names}")
 
         await general_extract_nodes_test(llm_client, data_sample)
-