From a1e54881a2d0a29a446170140939a8f3eb151b2b Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Fri, 23 Aug 2024 12:17:15 -0400
Subject: [PATCH] improve deduping issue (#28)

* improve deduping issue

* fix comment

* commit format

* default embeddings

* update
---
 core/prompts/dedupe_edges.py              | 108 +++++-----
 core/utils/bulk_utils.py                  | 239 ++++++++++++----------
 core/utils/maintenance/edge_operations.py |  34 +--
 examples/podcast/podcast_runner.py        |   4 +-
 4 files changed, 199 insertions(+), 186 deletions(-)

diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py
index cb41f201..4827ee43 100644
--- a/core/prompts/dedupe_edges.py
+++ b/core/prompts/dedupe_edges.py
@@ -5,66 +5,64 @@ from .models import Message, PromptFunction, PromptVersion
 
 
 class Prompt(Protocol):
-    v1: PromptVersion
-    v2: PromptVersion
-    edge_list: PromptVersion
+	v1: PromptVersion
+	v2: PromptVersion
+	edge_list: PromptVersion
 
 
 class Versions(TypedDict):
-    v1: PromptFunction
-    v2: PromptFunction
-    edge_list: PromptFunction
+	v1: PromptFunction
+	v2: PromptFunction
+	edge_list: PromptFunction
 
 
 def v1(context: dict[str, Any]) -> list[Message]:
-    return [
-        Message(
-            role='system',
-            content='You are a helpful assistant that de-duplicates relationship from edge lists.',
-        ),
-        Message(
-            role='user',
-            content=f"""
-        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
+	return [
+		Message(
+			role='system',
+			content='You are a helpful assistant that de-duplicates relationship from edge lists.',
+		),
+		Message(
+			role='user',
+			content=f"""
+        Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
 
-        Existing Edges:
+        Existing Facts:
         {json.dumps(context['existing_edges'], indent=2)}
 
-        New Edges:
+        New Facts:
        {json.dumps(context['extracted_edges'], indent=2)}
 
        Task:
-        1. start with the list of edges from New Edges
-        2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
-        edge in the list
-        3. Respond with the resulting list of edges
+        If any facts in New Facts is a duplicate of a fact in Existing Facts,
+        do not return it in the list of unique facts.
 
        Guidelines:
-        1. Use both the name and fact of edges to determine if they are duplicates,
-        duplicate edges may have different names
+        1. The facts do not have to be completely identical to be duplicates,
+        they just need to have similar factual content
 
        Respond with a JSON object in the following format:
        {{
-            "new_edges": [
+            "unique_facts": [
                {{
-                    "fact": "one sentence description of the fact"
+                    "uuid": "unique identifier of the fact"
                }}
            ]
        }}
        """,
-        ),
-    ]
+		),
+	]
 
 
 def v2(context: dict[str, Any]) -> list[Message]:
-    return [
-        Message(
-            role='system',
-            content='You are a helpful assistant that de-duplicates relationship from edge lists.',
-        ),
-        Message(
-            role='user',
-            content=f"""
+	return [
+		Message(
+			role='system',
+			content='You are a helpful assistant that de-duplicates relationship from edge lists.',
+		),
+		Message(
+			role='user',
+			content=f"""
        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
 
        Existing Edges:
@@ -94,44 +92,44 @@ def v2(context: dict[str, Any]) -> list[Message]:
            ]
        }}
        """,
-        ),
-    ]
+		),
+	]
 
 
 def edge_list(context: dict[str, Any]) -> list[Message]:
-    return [
-        Message(
-            role='system',
-            content='You are a helpful assistant that de-duplicates edges from edge lists.',
-        ),
-        Message(
-            role='user',
-            content=f"""
-        Given the following context, find all of the duplicates in a list of edges:
+	return [
+		Message(
+			role='system',
+			content='You are a helpful assistant that de-duplicates edges from edge lists.',
+		),
+		Message(
+			role='user',
+			content=f"""
+        Given the following context, find all of the duplicates in a list of facts:
 
-        Edges:
+        Facts:
        {json.dumps(context['edges'], indent=2)}
 
        Task:
-        If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
+        If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
 
        Guidelines:
-        1. Use both the name and fact of edges to determine if they are duplicates,
-        edges with the same name may not be duplicates
-        2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
+        1. The facts do not have to be completely identical to be duplicates, they just need to have similar content
+        2. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
        facts should be in the response
 
        Respond with a JSON object in the following format:
        {{
-            "unique_edges": [
+            "unique_facts": [
                {{
-                    "fact": "fact of a unique edge",
+                    "uuid": "unique identifier of the fact",
+                    "fact": "fact of a unique edge"
                }}
            ]
        }}
        """,
-        ),
-    ]
+		),
+	]
 
 
 versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
diff --git a/core/utils/bulk_utils.py b/core/utils/bulk_utils.py
index fbffd3e3..197d4aa3 100644
--- a/core/utils/bulk_utils.py
+++ b/core/utils/bulk_utils.py
@@ -3,6 +3,7 @@ import typing
 from datetime import datetime
 
 from neo4j import AsyncDriver
+from numpy import dot
 from pydantic import BaseModel
 
 from core.edges import Edge, EntityEdge, EpisodicEdge
@@ -11,186 +12,198 @@ from core.nodes import EntityNode, EpisodicNode
 from core.search.search_utils import get_relevant_edges, get_relevant_nodes
 from core.utils import retrieve_episodes
 from core.utils.maintenance.edge_operations import (
-    build_episodic_edges,
-    dedupe_edge_list,
-    dedupe_extracted_edges,
-    extract_edges,
+	build_episodic_edges,
+	dedupe_edge_list,
+	dedupe_extracted_edges,
+	extract_edges,
 )
 from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
 from core.utils.maintenance.node_operations import (
-    dedupe_extracted_nodes,
-    dedupe_node_list,
-    extract_nodes,
+	dedupe_extracted_nodes,
+	dedupe_node_list,
+	extract_nodes,
 )
 
-CHUNK_SIZE = 10
+CHUNK_SIZE = 15
 
 
 class BulkEpisode(BaseModel):
-    name: str
-    content: str
-    source_description: str
-    episode_type: str
-    reference_time: datetime
+	name: str
+	content: str
+	source_description: str
+	episode_type: str
+	reference_time: datetime
 
 
 async def retrieve_previous_episodes_bulk(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+	driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
-    previous_episodes_list = await asyncio.gather(
-        *[
-            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
-            for episode in episodes
-        ]
-    )
-    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
-        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
-    ]
+	previous_episodes_list = await asyncio.gather(
+		*[
+			retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
+			for episode in episodes
+		]
+	)
+	episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
+		(episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
+	]
 
-    return episode_tuples
+	return episode_tuples
 
 
 async def extract_nodes_and_edges_bulk(
-    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+	llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
-    extracted_nodes_bulk = await asyncio.gather(
-        *[
-            extract_nodes(llm_client, episode, previous_episodes)
-            for episode, previous_episodes in episode_tuples
-        ]
-    )
+	extracted_nodes_bulk = await asyncio.gather(
+		*[
+			extract_nodes(llm_client, episode, previous_episodes)
+			for episode, previous_episodes in episode_tuples
+		]
+	)
 
-    episodes, previous_episodes_list = (
-        [episode[0] for episode in episode_tuples],
-        [episode[1] for episode in episode_tuples],
-    )
+	episodes, previous_episodes_list = (
+		[episode[0] for episode in episode_tuples],
+		[episode[1] for episode in episode_tuples],
+	)
 
-    extracted_edges_bulk = await asyncio.gather(
-        *[
-            extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
-            for i, episode in enumerate(episodes)
-        ]
-    )
+	extracted_edges_bulk = await asyncio.gather(
+		*[
+			extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
+			for i, episode in enumerate(episodes)
+		]
+	)
 
-    episodic_edges: list[EpisodicEdge] = []
-    for i, episode in enumerate(episodes):
-        episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
+	episodic_edges: list[EpisodicEdge] = []
+	for i, episode in enumerate(episodes):
+		episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
 
-    nodes: list[EntityNode] = []
-    for extracted_nodes in extracted_nodes_bulk:
-        nodes += extracted_nodes
+	nodes: list[EntityNode] = []
+	for extracted_nodes in extracted_nodes_bulk:
+		nodes += extracted_nodes
 
-    edges: list[EntityEdge] = []
-    for extracted_edges in extracted_edges_bulk:
-        edges += extracted_edges
+	edges: list[EntityEdge] = []
+	for extracted_edges in extracted_edges_bulk:
+		edges += extracted_edges
 
-    return nodes, edges, episodic_edges
+	return nodes, edges, episodic_edges
 
 
 async def dedupe_nodes_bulk(
-    driver: AsyncDriver,
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
+	driver: AsyncDriver,
+	llm_client: LLMClient,
+	extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
-    # Compress nodes
-    nodes, uuid_map = node_name_match(extracted_nodes)
+	# Compress nodes
+	nodes, uuid_map = node_name_match(extracted_nodes)
 
-    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
+	compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
 
-    existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
+	existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
 
-    nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
-        llm_client, compressed_nodes, existing_nodes
-    )
+	nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
+		llm_client, compressed_nodes, existing_nodes
+	)
 
-    compressed_map.update(partial_uuid_map)
+	compressed_map.update(partial_uuid_map)
 
-    return nodes, compressed_map
+	return nodes, compressed_map
 
 
 async def dedupe_edges_bulk(
-    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+	driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
-    # Compress edges
-    compressed_edges = await compress_edges(llm_client, extracted_edges)
+	# Compress edges
+	compressed_edges = await compress_edges(llm_client, extracted_edges)
 
-    existing_edges = await get_relevant_edges(compressed_edges, driver)
+	existing_edges = await get_relevant_edges(compressed_edges, driver)
 
-    edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
+	edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
 
-    return edges
+	return edges
 
 
 def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
-    uuid_map: dict[str, str] = {}
-    name_map: dict[str, EntityNode] = {}
-    for node in nodes:
-        if node.name in name_map:
-            uuid_map[node.uuid] = name_map[node.name].uuid
-            continue
+	uuid_map: dict[str, str] = {}
+	name_map: dict[str, EntityNode] = {}
+	for node in nodes:
+		if node.name in name_map:
+			uuid_map[node.uuid] = name_map[node.name].uuid
+			continue
 
-        name_map[node.name] = node
+		name_map[node.name] = node
 
-    return [node for node in name_map.values()], uuid_map
+	return [node for node in name_map.values()], uuid_map
 
 
 async def compress_nodes(
-    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
+	llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+	if len(nodes) == 0:
+		return nodes, uuid_map
 
-    results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
+	anchor = nodes[0]
+	nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
 
-    extended_map = dict(uuid_map)
-    compressed_nodes: list[EntityNode] = []
-    for node_chunk, uuid_map_chunk in results:
-        compressed_nodes += node_chunk
-        extended_map.update(uuid_map_chunk)
+	node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
-    # Check if we have removed all duplicates
-    if len(compressed_nodes) == len(nodes):
-        compressed_uuid_map = compress_uuid_map(extended_map)
-        return compressed_nodes, compressed_uuid_map
+	results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
 
-    return await compress_nodes(llm_client, compressed_nodes, extended_map)
+	extended_map = dict(uuid_map)
+	compressed_nodes: list[EntityNode] = []
+	for node_chunk, uuid_map_chunk in results:
+		compressed_nodes += node_chunk
+		extended_map.update(uuid_map_chunk)
+
+	# Check if we have removed all duplicates
+	if len(compressed_nodes) == len(nodes):
+		compressed_uuid_map = compress_uuid_map(extended_map)
+		return compressed_nodes, compressed_uuid_map
+
+	return await compress_nodes(llm_client, compressed_nodes, extended_map)
 
 
 async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
-    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+	if len(edges) == 0:
+		return edges
 
-    results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
+	anchor = edges[0]
+	edges.sort(key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or []))
 
-    compressed_edges: list[EntityEdge] = []
-    for edge_chunk in results:
-        compressed_edges += edge_chunk
+	edge_chunks = [edges[i: i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
 
-    # Check if we have removed all duplicates
-    if len(compressed_edges) == len(edges):
-        return compressed_edges
+	results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
 
-    return await compress_edges(llm_client, compressed_edges)
+	compressed_edges: list[EntityEdge] = []
+	for edge_chunk in results:
+		compressed_edges += edge_chunk
+
+	# Check if we have removed all duplicates
+	if len(compressed_edges) == len(edges):
+		return compressed_edges
+
+	return await compress_edges(llm_client, compressed_edges)
 
 
 def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
-    # make sure all uuid values aren't mapped to other uuids
-    compressed_map = {}
-    for key, uuid in uuid_map.items():
-        curr_value = uuid
-        while curr_value in uuid_map:
-            curr_value = uuid_map[curr_value]
+	# make sure all uuid values aren't mapped to other uuids
+	compressed_map = {}
+	for key, uuid in uuid_map.items():
+		curr_value = uuid
+		while curr_value in uuid_map:
+			curr_value = uuid_map[curr_value]
 
-        compressed_map[key] = curr_value
-    return compressed_map
+		compressed_map[key] = curr_value
+	return compressed_map
 
 
 E = typing.TypeVar('E', bound=Edge)
 
 
 def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
-    for edge in edges:
-        source_uuid = edge.source_node_uuid
-        target_uuid = edge.target_node_uuid
-        edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
-        edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
+	for edge in edges:
+		source_uuid = edge.source_node_uuid
+		target_uuid = edge.target_node_uuid
+		edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
+		edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
 
-    return edges
+	return edges
diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py
index 922ba1a6..bdcfef67 100644
--- a/core/utils/maintenance/edge_operations.py
+++ b/core/utils/maintenance/edge_operations.py
@@ -94,27 +94,27 @@ async def dedupe_extracted_edges(
 ) -> list[EntityEdge]:
 	# Create edge map
 	edge_map = {}
-	for edge in existing_edges:
-		edge_map[edge.fact] = edge
 	for edge in extracted_edges:
-		if edge.fact in edge_map:
-			continue
-		edge_map[edge.fact] = edge
+		edge_map[edge.uuid] = edge
 
 	# Prepare context for LLM
 	context = {
-		'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
-		'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
+		'extracted_edges': [
+			{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
+		],
+		'existing_edges': [
+			{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+		],
 	}
 
 	llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
-	new_edges_data = llm_response.get('new_edges', [])
-	logger.info(f'Extracted new edges: {new_edges_data}')
+	unique_edge_data = llm_response.get('unique_facts', [])
+	logger.info(f'Extracted unique edges: {unique_edge_data}')
 
 	# Get full edge data
 	edges = []
-	for edge_data in new_edges_data:
-		edge = edge_map[edge_data['fact']]
+	for unique_edge in unique_edge_data:
+		edge = edge_map[unique_edge['uuid']]
 		edges.append(edge)
 
 	return edges
@@ -129,15 +129,15 @@ async def dedupe_edge_list(
 	# Create edge map
 	edge_map = {}
 	for edge in edges:
-		edge_map[edge.fact] = edge
+		edge_map[edge.uuid] = edge
 
 	# Prepare context for LLM
-	context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]}
+	context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
 
 	llm_response = await llm_client.generate_response(
 		prompt_library.dedupe_edges.edge_list(context)
 	)
-	unique_edges_data = llm_response.get('unique_edges', [])
+	unique_edges_data = llm_response.get('unique_facts', [])
 
 	end = time()
 	logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
@@ -145,7 +145,9 @@ async def dedupe_edge_list(
 	# Get full edge data
 	unique_edges = []
 	for edge_data in unique_edges_data:
-		fact = edge_data['fact']
-		unique_edges.append(edge_map[fact])
+		uuid = edge_data['uuid']
+		edge = edge_map[uuid]
+		edge.fact = edge_data['fact']
+		unique_edges.append(edge)
 
 	return unique_edges
diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index cdb9fc9b..e8f99996 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -62,10 +62,10 @@ async def main(use_bulk: bool = True):
 			episode_type='string',
 			reference_time=message.actual_timestamp,
 		)
-		for i, message in enumerate(messages[3:7])
+		for i, message in enumerate(messages[3:14])
 	]
 
 	await client.add_episode_bulk(episodes)
 
 
-asyncio.run(main())
+asyncio.run(main(True))
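
For readers who want to see the new bulk dedup flow without tracing the diff, here is a minimal, self-contained sketch. It is not code from this patch: `FakeEdge` and `dedupe_chunk` are hypothetical stand-ins for graphiti's `EntityEdge` and the LLM-backed `dedupe_edge_list`, while the anchor-similarity sort, chunked fixpoint loop, and transitive `compress_uuid_map` approximate the logic added to `core/utils/bulk_utils.py` (the real `compress_edges` is async and runs the chunk calls concurrently with `asyncio.gather`).

```python
# Standalone sketch, not part of the patch. The LLM dedup call is replaced by
# an exact-match merge so the control flow can be run locally.
from dataclasses import dataclass
from numpy import dot

CHUNK_SIZE = 15  # mirrors the constant bumped from 10 to 15 in this PR


@dataclass
class FakeEdge:  # hypothetical stand-in for EntityEdge
    uuid: str
    fact: str
    fact_embedding: list[float]


def dedupe_chunk(chunk: list[FakeEdge]) -> list[FakeEdge]:
    # Stand-in for dedupe_edge_list: keep one edge per identical fact string.
    seen: dict[str, FakeEdge] = {}
    for edge in chunk:
        seen.setdefault(edge.fact, edge)
    return list(seen.values())


def compress_edges(edges: list[FakeEdge]) -> list[FakeEdge]:
    if len(edges) == 0:
        return edges
    # Sort by similarity to an anchor so near-duplicates tend to share a chunk.
    anchor = edges[0]
    edges.sort(key=lambda edge: dot(anchor.fact_embedding or [], edge.fact_embedding or []))
    chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
    compressed = [edge for chunk in chunks for edge in dedupe_chunk(chunk)]
    # Fixpoint: stop once a pass removes nothing, otherwise compress again.
    if len(compressed) == len(edges):
        return compressed
    return compress_edges(compressed)


def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
    # Follow chains (a -> b -> c) so every duplicate maps to a canonical uuid.
    compressed = {}
    for key, value in uuid_map.items():
        while value in uuid_map:
            value = uuid_map[value]
        compressed[key] = value
    return compressed


if __name__ == '__main__':
    edges = [
        FakeEdge('e1', 'Alice works at Acme', [1.0, 0.0]),
        FakeEdge('e2', 'Alice works at Acme', [0.9, 0.1]),
        FakeEdge('e3', 'Bob lives in Paris', [0.0, 1.0]),
    ]
    print([edge.uuid for edge in compress_edges(edges)])  # two edges remain, one per distinct fact
    print(compress_uuid_map({'a': 'b', 'b': 'c', 'd': 'c'}))  # {'a': 'c', 'b': 'c', 'd': 'c'}
```

The sort is the interesting design choice: chunked dedup can only merge duplicates that the LLM sees in the same prompt, so ordering by embedding similarity to an anchor raises the chance that copies of the same fact land in one chunk, and the recursive pass catches any that were split across chunk boundaries on the first round.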