dedupe updates

This commit is contained in:
prestonrasmussen 2025-03-21 12:10:56 -04:00
parent a95c046cbb
commit 948a0057fb
4 changed files with 2773 additions and 3364 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -1,7 +1,10 @@
import json
from datetime import datetime, timedelta

import pandas as pd

from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
def create_episodes_from_messages(input_message, input_previous_messages):
"""
@@ -11,28 +14,31 @@ def create_episodes_from_messages(input_message, input_previous_messages):
    current_time = datetime.now()

    # Create the current episode
    role = input_message['role']
    content = input_message['content']
    message_content = f'{role}: {content}'
    episode = EpisodicNode(
        name='',
        group_id='',
        source=EpisodeType.message,
        type=EpisodeType.message,
        source_description='',
        content=message_content,
        valid_at=current_time,
    )
    # Create previous episodes
    num_previous_messages = len(input_previous_messages)
    previous_times = [
        current_time - timedelta(minutes=num_previous_messages - i)
        for i in range(num_previous_messages)
    ]
    previous_episodes = [
        EpisodicNode(
            name='',
            group_id='',
            source=EpisodeType.message,
            source_description='',
            content=f"{message['role']}: {message['content']}",
            valid_at=previous_time,
        )
@@ -41,10 +47,14 @@ def create_episodes_from_messages(input_message, input_previous_messages):
    return episode, previous_episodes


async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
    # Import necessary functions
    from graphiti_core.utils.maintenance.edge_operations import extract_edges
    from graphiti_core.utils.maintenance.node_operations import (
        extract_nodes,
        resolve_extracted_nodes,
    )

    # Loop through each unique message_index_within_snippet in sorted order
    for message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
@@ -52,16 +62,24 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
        #### Process 'extract_nodes' task
        extract_nodes_row = message_df[message_df['task_name'] == 'extract_nodes']
        assert (
            len(extract_nodes_row) == 1
        ), f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
        input_message = json.loads(extract_nodes_row.iloc[0]['input_message'])
        input_previous_messages = json.loads(extract_nodes_row.iloc[0]['input_previous_messages'])
        episode, previous_episodes = create_episodes_from_messages(
            input_message, input_previous_messages
        )
        extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)
        snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps(
            [entity_to_dict(node) for node in extracted_nodes]
        )
        #### Process 'dedupe_nodes' task
        dedupe_nodes_row = message_df[message_df['task_name'] == 'dedupe_nodes']
        assert (
            len(dedupe_nodes_row) == 1
        ), f"There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"

        # Calculate existing nodes list
        existing_nodes = []
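        # Candidate duplicates are all nodes extracted from earlier messages in this
        # snippet; resolve_extracted_nodes later receives that same candidate list once
        # per newly extracted node.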
@@ -71,8 +89,8 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
            # Filter for previous messages with 'extract_nodes' task
            prev_message_df = snippet_df[
                (snippet_df['message_index_within_snippet'] == prev_message_index)
                & (snippet_df['task_name'] == 'extract_nodes')
            ]
            # Retrieve and deserialize the nodes
@@ -82,12 +100,18 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
            existing_nodes.extend(nodes)

        existing_nodes_lists = [existing_nodes for _ in range(len(extracted_nodes))]

        resolved_nodes, uuid_map = await resolve_extracted_nodes(
            llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes
        )
        snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps(
            [entity_to_dict(node) for node in resolved_nodes]
        )
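        # uuid_map (provisional uuid -> canonical uuid) is not used in this labelling
        # pass; a full ingestion run would use it to re-point edges at the surviving
        # deduplicated nodes.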
        #### Process 'extract_edges' task
        extract_edges_row = message_df[message_df['task_name'] == 'extract_edges']
        assert (
            len(extract_edges_row) == 1
        ), f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
        extracted_edges = await extract_edges(
            llm_client,
            episode,
@@ -95,7 +119,9 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
            previous_episodes,
            group_id='',
        )
        snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps(
            [entity_to_dict(edge) for edge in extracted_edges]
        )
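        # At this point each task row for the current message holds a JSON-serialized
        # list (extracted nodes, resolved nodes, or edges) in the output column;
        # dict_to_entity below can rehydrate those records when they are read back.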
########## TODO: Complete the implementation of the below
@@ -131,17 +157,20 @@ async def ingest_and_label_minidataset(llm_client, minidataset_df, output_column
    minidataset_labelled_df = None
    for snippet_index in sorted(minidataset_df['snippet_index'].unique()):
        snippet_df = minidataset_df[minidataset_df['snippet_index'] == snippet_index]

        # Pass the output column name to the ingest_and_label_snippet function
        snippet_df_labelled = await ingest_and_label_snippet(
            llm_client, snippet_df, output_column_name
        )

        if minidataset_labelled_df is None:
            minidataset_labelled_df = snippet_df_labelled
        else:
            minidataset_labelled_df = pd.concat([minidataset_labelled_df, snippet_df_labelled])

    return minidataset_labelled_df
def entity_to_dict(entity):
    """
    Convert an entity object to a dictionary, handling datetime serialization.
@@ -152,6 +181,7 @@ def entity_to_dict(entity):
            entity_dict[key] = value.isoformat()  # Convert datetime to ISO 8601 string
    return entity_dict
def dict_to_entity(entity_dict, entity_class):
    """
    Convert a dictionary back to an entity object, handling datetime deserialization.
@@ -163,4 +193,4 @@ def dict_to_entity(entity_dict, entity_class):
        except (ValueError, TypeError):
            # If parsing fails, keep the original value
            pass
    return entity_class(**entity_dict)
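

# ---------------------------------------------------------------------------
# Illustrative driver (an assumption, not part of this commit): one way the
# labelling helpers above might be wired together.  The CSV paths and the
# 'baseline_output' column name are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import asyncio
    import os

    from graphiti_core.llm_client import OpenAIClient
    from graphiti_core.llm_client.config import LLMConfig

    async def _label_minidataset():
        # Build an LLM client from the environment, load the raw minidataset,
        # label every snippet, and write the result back out.
        llm_client = OpenAIClient(config=LLMConfig(api_key=os.environ['OPENAI_API_KEY']))
        minidataset_df = pd.read_csv('tests/evals/data/minidataset.csv')  # hypothetical path
        labelled_df = await ingest_and_label_minidataset(
            llm_client, minidataset_df, 'baseline_output'
        )
        labelled_df.to_csv('tests/evals/data/minidataset_labelled.csv', index=False)

    asyncio.run(_label_minidataset())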


@@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import csv  # Add this import at the top of the file
import json
import os
from datetime import datetime, timedelta

import pytest
@@ -24,17 +24,11 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.llm_client import OpenAIClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.node_operations import extract_nodes
from tests.evals.utils import ingest_snippet, setup_logging
############# EVERYTHING BELOW IS OUTDATED
@@ -63,68 +57,67 @@ async def general_extract_nodes_test(llm_client, data_sample):
    return hypothesis_node_names


def prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx):
    samples_csv_path = 'tests/evals/data/' + data_file_name + '.csv'
    # From CSV path, load everything
    with open(samples_csv_path, 'r') as file:
        csv_reader = csv.DictReader(file)
        lme_samples = list(csv_reader)

    data_samples = []

    # Loop through each row
    for row in lme_samples:
        ### Prepare episode
        current_time = datetime.now()
        message = json.loads(row['message'])
        role = message['role']
        content = message['content']
        message_content = role + ': ' + content
        episode = EpisodicNode(
            name='',
            group_id='',
            source=EpisodeType.message,
            type=EpisodeType.message,
            source_description='',
            content=message_content,
            valid_at=current_time,
        )
        ### Prepare previous episodes
        previous_messages = json.loads(row['previous_messages'])
        num_previous_messages = len(previous_messages)
        previous_times = [
            current_time - timedelta(minutes=num_previous_messages - i)
            for i in range(num_previous_messages)
        ]
        previous_episodes = [
            EpisodicNode(
                name='',
                group_id='',
                source=EpisodeType.message,
                source_description='',
                content=message['role'] + ': ' + message['content'],
                valid_at=previous_time,
            )
            for message, previous_time in zip(previous_messages, previous_times)
        ]
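        # Each prior message is backdated one minute apart, so the most recent previous
        # message sits exactly one minute before the current episode's timestamp.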
        ### TODO: Prepare gold answer names

        ### Add to data samples list
        data_samples.append(
            {
                'episode': episode,
                'previous_episodes': previous_episodes,
                'gold_answer_names': [],
            }
        )

    return data_samples
@pytest.mark.asyncio
async def test_extract_nodes():
    model_name = 'gpt-4o-mini'
@@ -135,17 +128,16 @@ async def test_extract_nodes():
    llm_client = OpenAIClient(config=llm_config)

    data_file_name = 'output_short'
    question_id = 'gpt4_2655b836'
    session_idx = 0
    message_idx = 0

    data_samples = prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx)

    for data_sample in data_samples:
        print(f"\n\nEpisode: {data_sample['episode']}")
        print('*' * 50)
        print(f"Previous Episodes: {data_sample['previous_episodes']}")
        print('*' * 50)
        # print(f"Gold Answer Names: {gold_answer_names}")
        await general_extract_nodes_test(llm_client, data_sample)
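
# Note (assumption, not shown in this diff): the test above is marked
# @pytest.mark.asyncio, so running it looks something like
#   pytest -k test_extract_nodes
# with pytest-asyncio installed and an OpenAI API key available in the environment.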