dedupe updates
commit 948a0057fb
parent a95c046cbb
4 changed files with 2773 additions and 3364 deletions
(Diffs for the first two changed files are suppressed because they are too large. The remaining two file diffs follow; the viewer did not preserve their paths.)
@@ -1,7 +1,10 @@
-import pandas as pd
-from datetime import datetime, timedelta
 import json
-from graphiti_core.nodes import EpisodicNode, EpisodeType, EntityNode
+from datetime import datetime, timedelta
+
+import pandas as pd
+
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 
 
 def create_episodes_from_messages(input_message, input_previous_messages):
     """
@@ -11,28 +14,31 @@ def create_episodes_from_messages(input_message, input_previous_messages):
     current_time = datetime.now()
 
     # Create the current episode
-    role = input_message["role"]
-    content = input_message["content"]
-    message_content = f"{role}: {content}"
+    role = input_message['role']
+    content = input_message['content']
+    message_content = f'{role}: {content}'
     episode = EpisodicNode(
-        name="",
-        group_id="",
+        name='',
+        group_id='',
         source=EpisodeType.message,
         type=EpisodeType.message,
-        source_description="",
+        source_description='',
         content=message_content,
         valid_at=current_time,
     )
 
     # Create previous episodes
     num_previous_messages = len(input_previous_messages)
-    previous_times = [current_time - timedelta(minutes=num_previous_messages - i) for i in range(num_previous_messages)]
+    previous_times = [
+        current_time - timedelta(minutes=num_previous_messages - i)
+        for i in range(num_previous_messages)
+    ]
     previous_episodes = [
         EpisodicNode(
-            name="",
-            group_id="",
+            name='',
+            group_id='',
             source=EpisodeType.message,
-            source_description="",
+            source_description='',
             content=f"{message['role']}: {message['content']}",
             valid_at=previous_time,
         )
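A minimal usage sketch of the helper reformatted above (not part of the commit; the message dicts are invented). The helper back-dates each previous message by one minute per step, so the `valid_at` timestamps come out ordered:

```python
# Hypothetical call to create_episodes_from_messages; the dicts are examples.
input_message = {'role': 'user', 'content': 'What did we decide yesterday?'}
input_previous_messages = [
    {'role': 'user', 'content': 'Let us meet tomorrow.'},
    {'role': 'assistant', 'content': 'Sounds good.'},
]

episode, previous_episodes = create_episodes_from_messages(
    input_message, input_previous_messages
)

# Every previous episode is time-stamped strictly before the current one.
assert all(p.valid_at < episode.valid_at for p in previous_episodes)
```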
@@ -41,10 +47,14 @@ def create_episodes_from_messages(input_message, input_previous_messages):
 
     return episode, previous_episodes
 
 
 async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
     # Import necessary functions
-    from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
     from graphiti_core.utils.maintenance.edge_operations import extract_edges
+    from graphiti_core.utils.maintenance.node_operations import (
+        extract_nodes,
+        resolve_extracted_nodes,
+    )
 
     # Loop through each unique message_index_within_snippet in sorted order
     for message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
@@ -52,16 +62,24 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
 
         #### Process 'extract_nodes' task
         extract_nodes_row = message_df[message_df['task_name'] == 'extract_nodes']
-        assert len(extract_nodes_row) == 1, f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
+        assert (
+            len(extract_nodes_row) == 1
+        ), f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
         input_message = json.loads(extract_nodes_row.iloc[0]['input_message'])
         input_previous_messages = json.loads(extract_nodes_row.iloc[0]['input_previous_messages'])
-        episode, previous_episodes = create_episodes_from_messages(input_message, input_previous_messages)
+        episode, previous_episodes = create_episodes_from_messages(
+            input_message, input_previous_messages
+        )
         extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)
-        snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in extracted_nodes])
+        snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps(
+            [entity_to_dict(node) for node in extracted_nodes]
+        )
 
         #### Process 'dedupe_nodes' task
         dedupe_nodes_row = message_df[message_df['task_name'] == 'dedupe_nodes']
-        assert len(dedupe_nodes_row) == 1, "There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"
+        assert (
+            len(dedupe_nodes_row) == 1
+        ), f"There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"
 
         # Calculate existing nodes list
         existing_nodes = []
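The asserts reshaped above encode an assumption about the snippet layout: each message index carries exactly one row per task name. A self-contained sketch of that shape, with invented column values:

```python
import pandas as pd

# One message's rows: exactly one row per task_name, as the asserts expect.
message_df = pd.DataFrame(
    {
        'task_name': ['extract_nodes', 'dedupe_nodes', 'extract_edges'],
        'input_message': ['{"role": "user", "content": "hi"}'] * 3,
        'input_previous_messages': ['[]'] * 3,
    }
)

extract_nodes_row = message_df[message_df['task_name'] == 'extract_nodes']
assert len(extract_nodes_row) == 1
```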
@@ -71,8 +89,8 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
 
             # Filter for previous messages with 'extract_nodes' task
             prev_message_df = snippet_df[
-                (snippet_df['message_index_within_snippet'] == prev_message_index) &
-                (snippet_df['task_name'] == 'extract_nodes')
+                (snippet_df['message_index_within_snippet'] == prev_message_index)
+                & (snippet_df['task_name'] == 'extract_nodes')
             ]
 
             # Retrieve and deserialize the nodes
@@ -82,12 +100,18 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
             existing_nodes.extend(nodes)
 
         existing_nodes_lists = [existing_nodes for _ in range(len(extracted_nodes))]
-        resolved_nodes, uuid_map = await resolve_extracted_nodes(llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes)
-        snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in resolved_nodes])
+        resolved_nodes, uuid_map = await resolve_extracted_nodes(
+            llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes
+        )
+        snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps(
+            [entity_to_dict(node) for node in resolved_nodes]
+        )
 
         #### Process 'extract_edges' task
         extract_edges_row = message_df[message_df['task_name'] == 'extract_edges']
-        assert len(extract_edges_row) == 1, f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
+        assert (
+            len(extract_edges_row) == 1
+        ), f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
         extracted_edges = await extract_edges(
             llm_client,
             episode,
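The hunk above also shows the candidate shape `resolve_extracted_nodes` consumes: one candidate list per extracted node, here the same pool replicated for every node. A stand-in sketch with strings in place of EntityNode objects:

```python
# Stand-ins for EntityNode objects, to show the shape only.
extracted_nodes = ['node_a', 'node_b', 'node_c']
existing_nodes = ['prev_1', 'prev_2']

# The same candidate pool is paired with every extracted node.
existing_nodes_lists = [existing_nodes for _ in range(len(extracted_nodes))]

assert len(existing_nodes_lists) == len(extracted_nodes)
assert all(candidates is existing_nodes for candidates in existing_nodes_lists)
```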
@@ -95,7 +119,9 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
             previous_episodes,
             group_id='',
         )
-        snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps([entity_to_dict(edge) for edge in extracted_edges])
+        snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps(
+            [entity_to_dict(edge) for edge in extracted_edges]
+        )
 
         ########## TODO: Complete the implementation of the below
 
@@ -131,17 +157,20 @@ async def ingest_and_label_minidataset(llm_client, minidataset_df, output_column
     minidataset_labelled_df = None
     for snippet_index in sorted(minidataset_df['snippet_index'].unique()):
         snippet_df = minidataset_df[minidataset_df['snippet_index'] == snippet_index]
 
         # Pass the output column name to the ingest_and_label_snippet function
-        snippet_df_labelled = await ingest_and_label_snippet(llm_client, snippet_df, output_column_name)
+        snippet_df_labelled = await ingest_and_label_snippet(
+            llm_client, snippet_df, output_column_name
+        )
 
         if minidataset_labelled_df is None:
             minidataset_labelled_df = snippet_df_labelled
         else:
             minidataset_labelled_df = pd.concat([minidataset_labelled_df, snippet_df_labelled])
 
     return minidataset_labelled_df
 
 
 def entity_to_dict(entity):
     """
     Convert an entity object to a dictionary, handling datetime serialization.
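As a side note on the loop above, an equivalent pattern collects the labelled frames in a list and concatenates once, dropping the `None` sentinel. A sketch only, assuming it runs inside the same async function with the same variables in scope:

```python
# Alternative accumulation (not what the commit does): gather the labelled
# snippet frames in a list and concatenate once after the loop, avoiding
# repeated DataFrame copies inside the loop body.
labelled_parts = []
for snippet_index in sorted(minidataset_df['snippet_index'].unique()):
    snippet_df = minidataset_df[minidataset_df['snippet_index'] == snippet_index]
    labelled_parts.append(
        await ingest_and_label_snippet(llm_client, snippet_df, output_column_name)
    )
minidataset_labelled_df = pd.concat(labelled_parts)
```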
@@ -152,6 +181,7 @@ def entity_to_dict(entity):
             entity_dict[key] = value.isoformat()  # Convert datetime to ISO 8601 string
     return entity_dict
 
 
 def dict_to_entity(entity_dict, entity_class):
     """
     Convert a dictionary back to an entity object, handling datetime deserialization.
@@ -163,4 +193,4 @@ def dict_to_entity(entity_dict, entity_class):
         except (ValueError, TypeError):
             # If parsing fails, keep the original value
             pass
     return entity_class(**entity_dict)
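The two helpers in this final hunk form a JSON round trip. A minimal sketch, assuming the module's own imports are in scope and that `dict_to_entity` attempts ISO 8601 parsing on string values, consistent with the except clause shown:

```python
import json
from datetime import datetime

node = EpisodicNode(
    name='',
    group_id='',
    source=EpisodeType.message,
    source_description='',
    content='user: hello',
    valid_at=datetime.now(),
)

# entity_to_dict turns datetimes into ISO 8601 strings, so the dict is JSON-safe.
payload = json.dumps(entity_to_dict(node))

# dict_to_entity parses ISO-looking strings back into datetimes where it can.
restored = dict_to_entity(json.loads(payload), EpisodicNode)
```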
(diff of the next changed file; path not shown)

@@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import os
+import csv  # Add this import at the top of the file
 import json
-from tests.evals.utils import setup_logging, ingest_snippet
+import os
 from datetime import datetime, timedelta
 
 import pytest
 
@@ -24,17 +24,11 @@ from dotenv import load_dotenv
 
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.graphiti import Graphiti
-from graphiti_core.nodes import EntityNode, EpisodicNode
-
-from graphiti_core.utils.maintenance.node_operations import extract_nodes
 from graphiti_core.llm_client import OpenAIClient
 from graphiti_core.llm_client.config import LLMConfig
-from graphiti_core.nodes import EpisodeType
-import csv  # Add this import at the top of the file
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.utils.maintenance.node_operations import extract_nodes
+from tests.evals.utils import ingest_snippet, setup_logging
 
 
 
 
 ############# EVERYTHING BELOW IS OUTDATED
 
 
@@ -63,68 +57,67 @@ async def general_extract_nodes_test(llm_client, data_sample):
     return hypothesis_node_names
 
 
 
 
 def prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx):
-    samples_csv_path = "tests/evals/data/" + data_file_name + ".csv"
+    samples_csv_path = 'tests/evals/data/' + data_file_name + '.csv'
 
     # From CSV path, load everything
     with open(samples_csv_path, 'r') as file:
         csv_reader = csv.DictReader(file)
         lme_samples = list(csv_reader)
 
     data_samples = []
 
     # Loop through each row
     for row in lme_samples:
 
         ### Prepare episode
         current_time = datetime.now()
-        message = json.loads(row["message"])
-        role = message["role"]
-        content = message["content"]
-        message_content = role + ": " + content
+        message = json.loads(row['message'])
+        role = message['role']
+        content = message['content']
+        message_content = role + ': ' + content
         episode = EpisodicNode(
-            name="",
-            group_id="",
+            name='',
+            group_id='',
             source=EpisodeType.message,
             type=EpisodeType.message,
-            source_description="",
+            source_description='',
             content=message_content,
             valid_at=current_time,
         )
 
         ### Prepare previous episodes
-        previous_messages = json.loads(row["previous_messages"])
+        previous_messages = json.loads(row['previous_messages'])
         num_previous_messages = len(previous_messages)
-        previous_times = [current_time - timedelta(minutes=num_previous_messages-i) for i in range(num_previous_messages)]
-        previous_episodes = [EpisodicNode(
-            name="",
-            group_id="",
-            source=EpisodeType.message,
-            source_description="",
-            content=message["role"] + ": " + message["content"],
-            valid_at=previous_time,
-        ) for message, previous_time in zip(previous_messages, previous_times)]
+        previous_times = [
+            current_time - timedelta(minutes=num_previous_messages - i)
+            for i in range(num_previous_messages)
+        ]
+        previous_episodes = [
+            EpisodicNode(
+                name='',
+                group_id='',
+                source=EpisodeType.message,
+                source_description='',
+                content=message['role'] + ': ' + message['content'],
+                valid_at=previous_time,
+            )
+            for message, previous_time in zip(previous_messages, previous_times)
+        ]
 
         ### TODO: Prepare gold answer names
 
         ### Add to data samples list
-        data_samples.append({
-            "episode": episode,
-            "previous_episodes": previous_episodes,
-            "gold_answer_names": [],
-        })
+        data_samples.append(
+            {
+                'episode': episode,
+                'previous_episodes': previous_episodes,
+                'gold_answer_names': [],
+            }
+        )
 
     return data_samples
 
 
 
 
 @pytest.mark.asyncio
 async def test_extract_nodes():
     model_name = 'gpt-4o-mini'
 
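Reading `prepare_data_from_csv` above, each CSV row is expected to carry JSON strings in its `message` and `previous_messages` columns. A self-contained sketch of one such row, with invented values:

```python
import json

row = {
    'message': json.dumps({'role': 'user', 'content': 'What did I say earlier?'}),
    'previous_messages': json.dumps(
        [{'role': 'assistant', 'content': 'You asked about Graphiti.'}]
    ),
}

message = json.loads(row['message'])                      # dict with role/content
previous_messages = json.loads(row['previous_messages'])  # list of such dicts
assert message['role'] == 'user'
```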
@@ -135,17 +128,16 @@ async def test_extract_nodes():
     llm_client = OpenAIClient(config=llm_config)
 
     data_file_name = 'output_short'
-    question_id = "gpt4_2655b836"
+    question_id = 'gpt4_2655b836'
     session_idx = 0
     message_idx = 0
     data_samples = prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx)
 
     for data_sample in data_samples:
         print(f"\n\nEpisode: {data_sample['episode']}")
-        print("*"*50)
+        print('*' * 50)
         print(f"Previous Episodes: {data_sample['previous_episodes']}")
-        print("*"*50)
+        print('*' * 50)
         # print(f"Gold Answer Names: {gold_answer_names}")
 
         await general_extract_nodes_test(llm_client, data_sample)