improve deduping issue (#28)

* improve deduping issue

* fix comment

* commit format

* default embeddings

* update
This commit is contained in:
Preston Rasmussen 2024-08-23 12:17:15 -04:00 committed by GitHub
parent 9cc9883e66
commit a1e54881a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 199 additions and 186 deletions

View file

@ -25,29 +25,27 @@ def v1(context: dict[str, Any]) -> list[Message]:
Message( Message(
role='user', role='user',
content=f""" content=f"""
Given the following context, deduplicate edges from a list of new edges given a list of existing edges: Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
Existing Edges: Existing Facts:
{json.dumps(context['existing_edges'], indent=2)} {json.dumps(context['existing_edges'], indent=2)}
New Edges: New Facts:
{json.dumps(context['extracted_edges'], indent=2)} {json.dumps(context['extracted_edges'], indent=2)}
Task: Task:
1. start with the list of edges from New Edges If any fact in New Facts is a duplicate of a fact in Existing Facts,
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing do not return it in the list of unique facts.
edge in the list
3. Respond with the resulting list of edges
Guidelines: Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates, 1. The facts do not have to be completely identical to be duplicates,
duplicate edges may have different names they just need to have similar factual content
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{
"new_edges": [ "unique_facts": [
{{ {{
"fact": "one sentence description of the fact" "uuid": "unique identifier of the fact"
}} }}
] ]
}} }}
@ -107,25 +105,25 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
Message( Message(
role='user', role='user',
content=f""" content=f"""
Given the following context, find all of the duplicates in a list of edges: Given the following context, find all of the duplicates in a list of facts:
Edges: Facts:
{json.dumps(context['edges'], indent=2)} {json.dumps(context['edges'], indent=2)}
Task: Task:
If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges If any fact in Facts is a duplicate of another fact, return a new fact with one of their uuids.
Guidelines: Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates, 1. The facts do not have to be completely identical to be duplicates, they just need to have similar content
edges with the same name may not be duplicates 2. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
facts should be in the response facts should be in the response
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{
"unique_edges": [ "unique_facts": [
{{ {{
"fact": "fact of a unique edge", "uuid": "unique identifier of the fact",
"fact": "fact of a unique edge"
}} }}
] ]
}} }}

View file

@ -3,6 +3,7 @@ import typing
from datetime import datetime from datetime import datetime
from neo4j import AsyncDriver from neo4j import AsyncDriver
from numpy import dot
from pydantic import BaseModel from pydantic import BaseModel
from core.edges import Edge, EntityEdge, EpisodicEdge from core.edges import Edge, EntityEdge, EpisodicEdge
@ -23,7 +24,7 @@ from core.utils.maintenance.node_operations import (
extract_nodes, extract_nodes,
) )
CHUNK_SIZE = 10 CHUNK_SIZE = 15
class BulkEpisode(BaseModel): class BulkEpisode(BaseModel):
@ -137,7 +138,13 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
async def compress_nodes( async def compress_nodes(
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str] llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]: ) -> tuple[list[EntityNode], dict[str, str]]:
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] if len(nodes) == 0:
return nodes, uuid_map
anchor = nodes[0]
nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]) results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
@ -156,7 +163,13 @@ async def compress_nodes(
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]: async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)] if len(edges) == 0:
return edges
anchor = edges[0]
edges.sort(key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or []))
edge_chunks = [edges[i: i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]) results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])

View file

@ -94,27 +94,27 @@ async def dedupe_extracted_edges(
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# Create edge map # Create edge map
edge_map = {} edge_map = {}
for edge in existing_edges:
edge_map[edge.fact] = edge
for edge in extracted_edges: for edge in extracted_edges:
if edge.fact in edge_map: edge_map[edge.uuid] = edge
continue
edge_map[edge.fact] = edge
# Prepare context for LLM # Prepare context for LLM
context = { context = {
'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges], 'extracted_edges': [
'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges], {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
],
'existing_edges': [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
],
} }
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context)) llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
new_edges_data = llm_response.get('new_edges', []) unique_edge_data = llm_response.get('unique_facts', [])
logger.info(f'Extracted new edges: {new_edges_data}') logger.info(f'Extracted unique edges: {unique_edge_data}')
# Get full edge data # Get full edge data
edges = [] edges = []
for edge_data in new_edges_data: for unique_edge in unique_edge_data:
edge = edge_map[edge_data['fact']] edge = edge_map[unique_edge['uuid']]
edges.append(edge) edges.append(edge)
return edges return edges
@ -129,15 +129,15 @@ async def dedupe_edge_list(
# Create edge map # Create edge map
edge_map = {} edge_map = {}
for edge in edges: for edge in edges:
edge_map[edge.fact] = edge edge_map[edge.uuid] = edge
# Prepare context for LLM # Prepare context for LLM
context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]} context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
llm_response = await llm_client.generate_response( llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.edge_list(context) prompt_library.dedupe_edges.edge_list(context)
) )
unique_edges_data = llm_response.get('unique_edges', []) unique_edges_data = llm_response.get('unique_facts', [])
end = time() end = time()
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ') logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
@ -145,7 +145,9 @@ async def dedupe_edge_list(
# Get full edge data # Get full edge data
unique_edges = [] unique_edges = []
for edge_data in unique_edges_data: for edge_data in unique_edges_data:
fact = edge_data['fact'] uuid = edge_data['uuid']
unique_edges.append(edge_map[fact]) edge = edge_map[uuid]
edge.fact = edge_data['fact']
unique_edges.append(edge)
return unique_edges return unique_edges

View file

@ -62,10 +62,10 @@ async def main(use_bulk: bool = True):
episode_type='string', episode_type='string',
reference_time=message.actual_timestamp, reference_time=message.actual_timestamp,
) )
for i, message in enumerate(messages[3:7]) for i, message in enumerate(messages[3:14])
] ]
await client.add_episode_bulk(episodes) await client.add_episode_bulk(episodes)
asyncio.run(main()) asyncio.run(main(True))