import json
from datetime import datetime, timedelta

import pandas as pd

from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
def create_episodes_from_messages(input_message, input_previous_messages):
    """
    Create an episode and a list of previous episodes from input messages.
    """
    # Current time for the episode
    current_time = datetime.now()

    # Create the current episode
    role = input_message['role']
    content = input_message['content']
    message_content = f'{role}: {content}'
    episode = EpisodicNode(
        name='',
        group_id='',
        source=EpisodeType.message,
        source_description='',
        content=message_content,
        valid_at=current_time,
    )

    # Create previous episodes, backdated one minute apart so the oldest
    # message is furthest in the past and the most recent one sits a minute
    # before the current episode
    num_previous_messages = len(input_previous_messages)
    previous_times = [
        current_time - timedelta(minutes=num_previous_messages - i)
        for i in range(num_previous_messages)
    ]
    previous_episodes = [
        EpisodicNode(
            name='',
            group_id='',
            source=EpisodeType.message,
            source_description='',
            content=f"{message['role']}: {message['content']}",
            valid_at=previous_time,
        )
        for message, previous_time in zip(input_previous_messages, previous_times)
    ]

    return episode, previous_episodes
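
# A minimal usage sketch (hypothetical messages, not taken from the dataset):
#
#   message = {'role': 'user', 'content': 'I just moved to Lisbon.'}
#   history = [{'role': 'assistant', 'content': 'Where are you living these days?'}]
#   episode, previous_episodes = create_episodes_from_messages(message, history)
#
# `episode` wraps the current message; `previous_episodes` carry the history
# with valid_at timestamps that all precede the current episode.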


async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
    """
    Replay each message of a snippet through the Graphiti extraction tasks and
    write each task's JSON-serialized output into `output_column_name`.

    Expects one row per (message_index_within_snippet, task_name) pair, with
    'input_message' and 'input_previous_messages' columns populated on the
    'extract_nodes' rows.
    """
    # Import the maintenance helpers used below
    from graphiti_core.utils.maintenance.edge_operations import extract_edges
    from graphiti_core.utils.maintenance.node_operations import (
        extract_nodes,
        resolve_extracted_nodes,
    )

    # Loop through each unique message_index_within_snippet in sorted order
    for message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
        message_df = snippet_df[snippet_df['message_index_within_snippet'] == message_index]

        #### Process 'extract_nodes' task
        extract_nodes_row = message_df[message_df['task_name'] == 'extract_nodes']
        assert (
            len(extract_nodes_row) == 1
        ), f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
        input_message = json.loads(extract_nodes_row.iloc[0]['input_message'])
        input_previous_messages = json.loads(extract_nodes_row.iloc[0]['input_previous_messages'])
        episode, previous_episodes = create_episodes_from_messages(
            input_message, input_previous_messages
        )
        extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)
        snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps(
            [entity_to_dict(node) for node in extracted_nodes]
        )
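
        # Each task's output is stored as a JSON string in `output_column_name`
        # so that later tasks (and later messages) can reload it with
        # json.loads; see the existing-nodes lookup in the dedupe step below.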

        #### Process 'dedupe_nodes' task
        dedupe_nodes_row = message_df[message_df['task_name'] == 'dedupe_nodes']
        assert (
            len(dedupe_nodes_row) == 1
        ), f"There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"

        # Collect the nodes extracted from all earlier messages in the snippet
        existing_nodes = []
        for prev_message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
            if prev_message_index >= message_index:
                break

            # Filter for previous messages with 'extract_nodes' task
            prev_message_df = snippet_df[
                (snippet_df['message_index_within_snippet'] == prev_message_index)
                & (snippet_df['task_name'] == 'extract_nodes')
            ]

            # Retrieve and deserialize the nodes
            serialized_nodes = prev_message_df.iloc[0][output_column_name]
            node_dicts = json.loads(serialized_nodes)
            nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in node_dicts]
            existing_nodes.extend(nodes)
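
        # resolve_extracted_nodes appears to take one candidate list per
        # extracted node, so the same existing-nodes list is replicated for
        # each. It returns the resolved nodes plus a uuid remapping
        # (old uuid -> canonical uuid), which this labelling pass ignores.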
        existing_nodes_lists = [existing_nodes for _ in range(len(extracted_nodes))]
        resolved_nodes, uuid_map = await resolve_extracted_nodes(
            llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes
        )
        snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps(
            [entity_to_dict(node) for node in resolved_nodes]
        )

        #### Process 'extract_edges' task
        extract_edges_row = message_df[message_df['task_name'] == 'extract_edges']
        assert (
            len(extract_edges_row) == 1
        ), f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
        extracted_edges = await extract_edges(
            llm_client,
            episode,
            extracted_nodes,
            previous_episodes,
            group_id='',
        )
        snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps(
            [entity_to_dict(edge) for edge in extracted_edges]
        )
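
        # Note: edges are extracted over the raw `extracted_nodes`; the
        # `resolved_nodes` from the dedupe step are stored in the output
        # column but not fed back into edge extraction here.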

        ########## TODO: Complete the implementation of the below

        #### Process 'dedupe_edges' task
        # dedupe_edges_row = message_df[message_df['task_name'] == 'dedupe_edges']
        # assert len(dedupe_edges_row) == 1, "There should be exactly one row for 'dedupe_edges'"
        # output = dedupe_extracted_edge(
        #     llm_client,
        #     extracted_edge,
        #     related_edges,
        # )
        # snippet_df.at[dedupe_edges_row.index[0], output_column_name] = output

        #### Process 'extract_edge_dates' task
        # extract_edge_dates_row = message_df[message_df['task_name'] == 'extract_edge_dates']
        # assert len(extract_edge_dates_row) == 1, "There should be exactly one row for 'extract_edge_dates'"
        # output = extract_edge_dates(extract_edge_dates_row.iloc[0]['input_extracted_edge_dates'])
        # snippet_df.at[extract_edge_dates_row.index[0], output_column_name] = output

        #### Process 'edge_invalidation' task
        # edge_invalidation_row = message_df[message_df['task_name'] == 'edge_invalidation']
        # assert len(edge_invalidation_row) == 1, "There should be exactly one row for 'edge_invalidation'"
        # output = edge_invalidation(edge_invalidation_row.iloc[0]['input_edge_invalidation'])
        # snippet_df.at[edge_invalidation_row.index[0], output_column_name] = output

    return snippet_df


async def ingest_and_label_minidataset(llm_client, minidataset_df, output_column_name):
    # Add a new column with the specified name, initialized with empty values
    minidataset_df[output_column_name] = None

    minidataset_labelled_df = None
    for snippet_index in sorted(minidataset_df['snippet_index'].unique()):
        # Copy so the per-snippet writes via .at inside ingest_and_label_snippet
        # do not hit a view of minidataset_df (avoids SettingWithCopyWarning)
        snippet_df = minidataset_df[minidataset_df['snippet_index'] == snippet_index].copy()

        # Pass the output column name to the ingest_and_label_snippet function
        snippet_df_labelled = await ingest_and_label_snippet(
            llm_client, snippet_df, output_column_name
        )

        if minidataset_labelled_df is None:
            minidataset_labelled_df = snippet_df_labelled
        else:
            minidataset_labelled_df = pd.concat([minidataset_labelled_df, snippet_df_labelled])

    return minidataset_labelled_df
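
# A minimal driver sketch (hypothetical file name; `llm_client` can be any
# graphiti_core LLM client):
#
#   import asyncio
#   minidataset_df = pd.read_csv('minidataset.csv')  # hypothetical path
#   labelled_df = asyncio.run(
#       ingest_and_label_minidataset(llm_client, minidataset_df, 'graphiti_output')
#   )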


def entity_to_dict(entity):
    """
    Convert an entity object to a dictionary, handling datetime serialization.
    """
    # Copy vars(entity) so converting datetimes below does not mutate the
    # entity's own __dict__ in place
    entity_dict = dict(vars(entity))
    for key, value in entity_dict.items():
        if isinstance(value, datetime):
            entity_dict[key] = value.isoformat()  # Convert datetime to ISO 8601 string
    return entity_dict
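
# For example, json.dumps(entity_to_dict(node)) serializes an EntityNode with
# its datetime fields (e.g. created_at) rendered as ISO 8601 strings.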


def dict_to_entity(entity_dict, entity_class):
    """
    Convert a dictionary back to an entity object, handling datetime deserialization.
    """
    # Work on a copy so the caller's dict is not modified
    entity_dict = dict(entity_dict)
    for key, value in entity_dict.items():
        try:
            # Attempt to parse strings back to datetime objects
            entity_dict[key] = datetime.fromisoformat(value)
        except (ValueError, TypeError):
            # If parsing fails, keep the original value
            pass
    return entity_class(**entity_dict)
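
# Round-trip sketch:
#
#   dict_to_entity(json.loads(json.dumps(entity_to_dict(node))), EntityNode)
#
# should reconstruct an EntityNode equivalent to `node`, with the ISO 8601
# strings parsed back into datetimes.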