improve deduping issue (#28)

* improve deduping issue

* fix comment

* commit format

* default embeddings

* update
Preston Rasmussen, 2024-08-23 12:17:15 -04:00, committed by GitHub
parent 9cc9883e66
commit a1e54881a2
4 changed files with 199 additions and 186 deletions


@@ -5,66 +5,64 @@ from .models import Message, PromptFunction, PromptVersion

 class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
     edge_list: PromptVersion


 class Versions(TypedDict):
     v1: PromptFunction
     v2: PromptFunction
     edge_list: PromptFunction


 def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
             content='You are a helpful assistant that de-duplicates relationship from edge lists.',
         ),
         Message(
             role='user',
             content=f"""
-        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
+        Given the following context, deduplicate facts from a list of new facts given a list of existing facts:

-        Existing Edges:
+        Existing Facts:
         {json.dumps(context['existing_edges'], indent=2)}

-        New Edges:
+        New Facts:
         {json.dumps(context['extracted_edges'], indent=2)}

         Task:
-        1. start with the list of edges from New Edges
-        2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
-        edge in the list
-        3. Respond with the resulting list of edges
+        If any facts in New Facts is a duplicate of a fact in Existing Facts,
+        do not return it in the list of unique facts.

         Guidelines:
-        1. Use both the name and fact of edges to determine if they are duplicates,
-        duplicate edges may have different names
+        1. The facts do not have to be completely identical to be duplicates,
+        they just need to have similar factual content

         Respond with a JSON object in the following format:
         {{
-            "new_edges": [
+            "unique_facts": [
                 {{
-                    "fact": "one sentence description of the fact"
+                    "uuid": "unique identifier of the fact"
                 }}
             ]
         }}
         """,
         ),
     ]


 def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
             content='You are a helpful assistant that de-duplicates relationship from edge lists.',
         ),
         Message(
             role='user',
             content=f"""
         Given the following context, deduplicate edges from a list of new edges given a list of existing edges:

         Existing Edges:
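The practical effect of the v1 rewrite above: the model no longer echoes rewritten edges back, it returns only the uuids of new facts that are not duplicates, and the caller keeps its own edge objects. A minimal illustration of the new contract (the uuids and facts here are invented for the example):

    # Hypothetical context passed to dedupe_edges.v1 (uuids are made up)
    context = {
        'existing_edges': [
            {'uuid': 'e1', 'name': 'WORKS_AT', 'fact': 'Alice works at Acme.'},
        ],
        'extracted_edges': [
            {'uuid': 'n1', 'name': 'EMPLOYED_BY', 'fact': 'Alice is employed by Acme.'},
            {'uuid': 'n2', 'name': 'LIVES_IN', 'fact': 'Alice lives in Paris.'},
        ],
    }

    # Expected response: n1 duplicates e1, so only n2 comes back
    # {"unique_facts": [{"uuid": "n2"}]}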
@@ -94,44 +92,44 @@ def v2(context: dict[str, Any]) -> list[Message]:
             ]
         }}
         """,
         ),
     ]


 def edge_list(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
             content='You are a helpful assistant that de-duplicates edges from edge lists.',
         ),
         Message(
             role='user',
             content=f"""
-        Given the following context, find all of the duplicates in a list of edges:
+        Given the following context, find all of the duplicates in a list of facts:

-        Edges:
+        Facts:
         {json.dumps(context['edges'], indent=2)}

         Task:
-        If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
+        If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.

         Guidelines:
-        1. Use both the name and fact of edges to determine if they are duplicates,
-        edges with the same name may not be duplicates
-        2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
+        1. The facts do not have to be completely identical to be duplicates, they just need to have similar content
+        2. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
         facts should be in the response

         Respond with a JSON object in the following format:
         {{
-            "unique_edges": [
+            "unique_facts": [
                 {{
-                    "fact": "fact of a unique edge",
+                    "uuid": "unique identifier of the fact",
+                    "fact": "fact of a unique edge"
                 }}
             ]
         }}
         """,
         ),
     ]


 versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
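edge_list's contract changes the same way, except each returned item carries both the surviving uuid and the fact wording the model chose, since the caller (dedupe_edge_list, in a later hunk) writes that wording back onto the edge. A sketch of a response, with invented uuids:

    # Three facts in, where 'a' and 'b' say the same thing:
    # {"unique_facts": [
    #     {"uuid": "a", "fact": "Alice works at Acme."},
    #     {"uuid": "c", "fact": "Alice lives in Paris."}
    # ]}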


@@ -3,6 +3,7 @@ import typing
 from datetime import datetime

 from neo4j import AsyncDriver
+from numpy import dot
 from pydantic import BaseModel

 from core.edges import Edge, EntityEdge, EpisodicEdge
@@ -11,186 +12,198 @@ from core.nodes import EntityNode, EpisodicNode
 from core.search.search_utils import get_relevant_edges, get_relevant_nodes
 from core.utils import retrieve_episodes
 from core.utils.maintenance.edge_operations import (
     build_episodic_edges,
     dedupe_edge_list,
     dedupe_extracted_edges,
     extract_edges,
 )
 from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
 from core.utils.maintenance.node_operations import (
     dedupe_extracted_nodes,
     dedupe_node_list,
     extract_nodes,
 )

-CHUNK_SIZE = 10
+CHUNK_SIZE = 15

 class BulkEpisode(BaseModel):
     name: str
     content: str
     source_description: str
     episode_type: str
     reference_time: datetime


 async def retrieve_previous_episodes_bulk(
     driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await asyncio.gather(
         *[
             retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
             for episode in episodes
         ]
     )
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
         (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
     ]

     return episode_tuples


 async def extract_nodes_and_edges_bulk(
     llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
     extracted_nodes_bulk = await asyncio.gather(
         *[
             extract_nodes(llm_client, episode, previous_episodes)
             for episode, previous_episodes in episode_tuples
         ]
     )

     episodes, previous_episodes_list = (
         [episode[0] for episode in episode_tuples],
         [episode[1] for episode in episode_tuples],
     )

     extracted_edges_bulk = await asyncio.gather(
         *[
             extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
             for i, episode in enumerate(episodes)
         ]
     )

     episodic_edges: list[EpisodicEdge] = []
     for i, episode in enumerate(episodes):
         episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)

     nodes: list[EntityNode] = []
     for extracted_nodes in extracted_nodes_bulk:
         nodes += extracted_nodes

     edges: list[EntityEdge] = []
     for extracted_edges in extracted_edges_bulk:
         edges += extracted_edges

     return nodes, edges, episodic_edges


 async def dedupe_nodes_bulk(
     driver: AsyncDriver,
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     # Compress nodes
     nodes, uuid_map = node_name_match(extracted_nodes)

     compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)

     existing_nodes = await get_relevant_nodes(compressed_nodes, driver)

     nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
         llm_client, compressed_nodes, existing_nodes
     )

     compressed_map.update(partial_uuid_map)

     return nodes, compressed_map


 async def dedupe_edges_bulk(
     driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
     # Compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)

     existing_edges = await get_relevant_edges(compressed_edges, driver)

     edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)

     return edges


 def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     name_map: dict[str, EntityNode] = {}
     for node in nodes:
         if node.name in name_map:
             uuid_map[node.uuid] = name_map[node.name].uuid
             continue

         name_map[node.name] = node

     return [node for node in name_map.values()], uuid_map
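node_name_match is unchanged here, but it is the first compression stage, so its output feeds everything below: exact-name duplicates are collapsed before any LLM call, with the first occurrence winning. A quick sketch with a stand-in class (EntityNode's full definition is out of scope for this page):

    from dataclasses import dataclass

    @dataclass
    class FakeNode:  # stand-in for EntityNode; only the fields node_name_match reads
        uuid: str
        name: str

    nodes = [FakeNode('u1', 'Alice'), FakeNode('u2', 'Bob'), FakeNode('u3', 'Alice')]
    # node_name_match would return:
    #   deduped nodes -> [FakeNode('u1', 'Alice'), FakeNode('u2', 'Bob')]
    #   uuid_map      -> {'u3': 'u1'}  (later duplicate maps to the first occurrence)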

 async def compress_nodes(
     llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+    if len(nodes) == 0:
+        return nodes, uuid_map
+
+    anchor = nodes[0]
+    nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
+
+    node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

     results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])

     extended_map = dict(uuid_map)
     compressed_nodes: list[EntityNode] = []
     for node_chunk, uuid_map_chunk in results:
         compressed_nodes += node_chunk
         extended_map.update(uuid_map_chunk)

     # Check if we have removed all duplicates
     if len(compressed_nodes) == len(nodes):
         compressed_uuid_map = compress_uuid_map(extended_map)
         return compressed_nodes, compressed_uuid_map

     return await compress_nodes(llm_client, compressed_nodes, extended_map)
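The added sort is the substance of this change. Chunking previously split the node list in arrival order, so two duplicates could land in different chunks and never appear in the same dedupe_node_list call; ordering by dot product against an arbitrary anchor embedding pulls semantically similar nodes into the same CHUNK_SIZE window (for unit-length embeddings the dot product is exactly cosine similarity), and the recursion still catches duplicates that straddle a chunk boundary. A toy illustration of the ordering, with invented embeddings:

    from numpy import dot

    embs = {'alice': [1.0, 0.0], 'alice_2': [0.9, 0.1], 'bob': [0.0, 1.0]}
    anchor = embs['alice']
    order = sorted(embs, key=lambda name: dot(anchor, embs[name]))
    # order == ['bob', 'alice_2', 'alice'] -- the two Alice variants end up adjacent,
    # so they fall into the same chunk and get compared directly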

 async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
-    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+    if len(edges) == 0:
+        return edges
+
+    anchor = edges[0]
+    edges.sort(key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or []))
+
+    edge_chunks = [edges[i: i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]

     results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])

     compressed_edges: list[EntityEdge] = []
     for edge_chunk in results:
         compressed_edges += edge_chunk

     # Check if we have removed all duplicates
     if len(compressed_edges) == len(edges):
         return compressed_edges

     return await compress_edges(llm_client, compressed_edges)
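compress_edges gets the identical treatment (note the sort lambda's parameter is named embedding but actually receives an EntityEdge). Both helpers share one fixed-point pattern: re-chunk and dedupe until a pass removes nothing. An iterative sketch of that pattern, not the repo's code:

    def compress_until_stable(items, dedupe_pass):
        # dedupe_pass takes a list and returns a possibly-shorter list
        while True:
            compressed = dedupe_pass(items)
            if len(compressed) == len(items):  # nothing removed: fixed point reached
                return compressed
            items = compressed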

 def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
     # make sure all uuid values aren't mapped to other uuids
     compressed_map = {}
     for key, uuid in uuid_map.items():
         curr_value = uuid
         while curr_value in uuid_map:
             curr_value = uuid_map[curr_value]

         compressed_map[key] = curr_value
     return compressed_map
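compress_uuid_map flattens remap chains so every stale uuid points directly at its final survivor (it assumes the map is acyclic; a cycle would loop forever). For example:

    uuid_map = {'a': 'b', 'b': 'c'}
    # compress_uuid_map(uuid_map) -> {'a': 'c', 'b': 'c'}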

 E = typing.TypeVar('E', bound=Edge)


 def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
     for edge in edges:
         source_uuid = edge.source_node_uuid
         target_uuid = edge.target_node_uuid
         edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

     return edges
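resolve_edge_pointers is what makes the uuid map pay off downstream: any edge still pointing at a merged-away node is rewired to the survivor, and unmapped uuids pass through untouched. A sketch with a stand-in edge class:

    from dataclasses import dataclass

    @dataclass
    class FakeEdge:  # stand-in; only the two fields resolve_edge_pointers touches
        source_node_uuid: str
        target_node_uuid: str

    edge = FakeEdge(source_node_uuid='u3', target_node_uuid='u2')
    # with uuid_map = {'u3': 'u1'} (u3 was merged into u1), after resolve_edge_pointers:
    # edge.source_node_uuid == 'u1', edge.target_node_uuid == 'u2'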


@@ -94,27 +94,27 @@ async def dedupe_extracted_edges(
 ) -> list[EntityEdge]:
     # Create edge map
     edge_map = {}
-    for edge in existing_edges:
-        edge_map[edge.fact] = edge
     for edge in extracted_edges:
-        if edge.fact in edge_map:
-            continue
-        edge_map[edge.fact] = edge
+        edge_map[edge.uuid] = edge

     # Prepare context for LLM
     context = {
-        'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
-        'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
+        'extracted_edges': [
+            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
+        ],
+        'existing_edges': [
+            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+        ],
     }

     llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
-    new_edges_data = llm_response.get('new_edges', [])
-    logger.info(f'Extracted new edges: {new_edges_data}')
+    unique_edge_data = llm_response.get('unique_facts', [])
+    logger.info(f'Extracted unique edges: {unique_edge_data}')

     # Get full edge data
     edges = []
-    for edge_data in new_edges_data:
-        edge = edge_map[edge_data['fact']]
+    for unique_edge in unique_edge_data:
+        edge = edge_map[unique_edge['uuid']]
         edges.append(edge)

     return edges
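Beyond the renames, this hunk fixes two real problems: the old context built 'existing_edges' from extracted_edges, so the model never actually saw the existing edges, and keying edge_map by fact text meant two distinct edges with identical wording shadowed each other. Keying by uuid avoids the collision:

    edges = [('u1', 'Alice works at Acme.'), ('u2', 'Alice works at Acme.')]
    by_fact = {fact: uuid for uuid, fact in edges}  # {'Alice works at Acme.': 'u2'} -- u1 is lost
    by_uuid = {uuid: fact for uuid, fact in edges}  # both entries survive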
@@ -129,15 +129,15 @@ async def dedupe_edge_list(
     # Create edge map
     edge_map = {}
     for edge in edges:
-        edge_map[edge.fact] = edge
+        edge_map[edge.uuid] = edge

     # Prepare context for LLM
-    context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]}
+    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}

     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_edges.edge_list(context)
     )
-    unique_edges_data = llm_response.get('unique_edges', [])
+    unique_edges_data = llm_response.get('unique_facts', [])

     end = time()
     logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')

@@ -145,7 +145,9 @@ async def dedupe_edge_list(
     # Get full edge data
     unique_edges = []
     for edge_data in unique_edges_data:
-        fact = edge_data['fact']
-        unique_edges.append(edge_map[fact])
+        uuid = edge_data['uuid']
+        edge = edge_map[uuid]
+        edge.fact = edge_data['fact']
+        unique_edges.append(edge)

     return unique_edges
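Note the in-place update: the surviving EntityEdge keeps its uuid but adopts the fact wording from the LLM response. In miniature, with a stub object:

    class Stub:  # minimal stand-in with a mutable fact field
        def __init__(self, uuid, fact):
            self.uuid, self.fact = uuid, fact

    edge_map = {'u1': Stub('u1', 'Alice is employed at Acme Corp.')}
    response = {'unique_facts': [{'uuid': 'u1', 'fact': 'Alice works at Acme.'}]}
    for item in response['unique_facts']:
        edge = edge_map[item['uuid']]
        edge.fact = item['fact']  # survivor keeps its uuid, adopts the model's wording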


@@ -62,10 +62,10 @@ async def main(use_bulk: bool = True):
             episode_type='string',
             reference_time=message.actual_timestamp,
         )
-        for i, message in enumerate(messages[3:7])
+        for i, message in enumerate(messages[3:14])
     ]

     await client.add_episode_bulk(episodes)


-asyncio.run(main())
+asyncio.run(main(True))