improve deduping issue (#28)

* improve deduping issue

* fix comment

* commit format

* default embeddings

* update
Preston Rasmussen 2024-08-23 12:17:15 -04:00 committed by GitHub
parent 9cc9883e66
commit a1e54881a2
4 changed files with 199 additions and 186 deletions


@@ -25,29 +25,27 @@ def v1(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
-        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
+        Given the following context, deduplicate facts from a list of new facts given a list of existing facts:

-        Existing Edges:
+        Existing Facts:
         {json.dumps(context['existing_edges'], indent=2)}

-        New Edges:
+        New Facts:
         {json.dumps(context['extracted_edges'], indent=2)}

         Task:
-        1. start with the list of edges from New Edges
-        2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
-        edge in the list
-        3. Respond with the resulting list of edges
+        If any fact in New Facts is a duplicate of a fact in Existing Facts,
+        do not return it in the list of unique facts.

         Guidelines:
-        1. Use both the name and fact of edges to determine if they are duplicates,
-        duplicate edges may have different names
+        1. The facts do not have to be completely identical to be duplicates,
+        they just need to have similar factual content

         Respond with a JSON object in the following format:
         {{
-            "new_edges": [
+            "unique_facts": [
                 {{
-                    "fact": "one sentence description of the fact"
+                    "uuid": "unique identifier of the fact"
                 }}
             ]
         }}
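
To make the new contract concrete, consider a hypothetical exchange (facts and uuids invented for illustration): if Existing Facts contains "Alice works at Acme" and New Facts contains "Alice is employed by Acme" (uuid "n1", a paraphrased duplicate) and "Alice lives in Paris" (uuid "n2"), the expected response keeps only the non-duplicate's uuid:

    {
        "unique_facts": [
            {"uuid": "n2"}
        ]
    }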
@@ -107,25 +105,25 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
-        Given the following context, find all of the duplicates in a list of edges:
+        Given the following context, find all of the duplicates in a list of facts:

-        Edges:
+        Facts:
         {json.dumps(context['edges'], indent=2)}

         Task:
-        If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
+        If any fact in Facts is a duplicate of another fact, return a new fact with one of their uuids.

         Guidelines:
-        1. Use both the name and fact of edges to determine if they are duplicates,
-        edges with the same name may not be duplicates
-        2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
+        1. The facts do not have to be completely identical to be duplicates, they just need to have similar content
+        2. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
         facts should be in the response

         Respond with a JSON object in the following format:
         {{
-            "unique_edges": [
+            "unique_facts": [
                 {{
-                    "fact": "fact of a unique edge",
+                    "uuid": "unique identifier of the fact",
+                    "fact": "fact of a unique edge"
                 }}
             ]
         }}


@@ -3,6 +3,7 @@ import typing
 from datetime import datetime

 from neo4j import AsyncDriver
+from numpy import dot
 from pydantic import BaseModel

 from core.edges import Edge, EntityEdge, EpisodicEdge

@@ -23,7 +24,7 @@ from core.utils.maintenance.node_operations import (
     extract_nodes,
 )

-CHUNK_SIZE = 10
+CHUNK_SIZE = 15


 class BulkEpisode(BaseModel):

@@ -137,6 +138,12 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
 async def compress_nodes(
     llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
+    if len(nodes) == 0:
+        return nodes, uuid_map
+
+    anchor = nodes[0]
+    nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
+
     node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
     results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])

@@ -156,6 +163,12 @@ async def compress_nodes(
 async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
+    if len(edges) == 0:
+        return edges
+
+    anchor = edges[0]
+    edges.sort(key=lambda edge: dot(anchor.fact_embedding or [], edge.fact_embedding or []))
+
     edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
     results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
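
Both compress helpers now sort by the embedding dot product against an anchor (the first element) before chunking, so near-duplicates tend to land in the same CHUNK_SIZE window and the LLM actually sees them side by side. A minimal standalone sketch of the idea, where the item shape and the helper name are assumptions for illustration rather than code from this repository:

    from numpy import dot

    CHUNK_SIZE = 15

    def similarity_chunks(items: list[dict]) -> list[list[dict]]:
        # Empty input short-circuits, mirroring the guard added above.
        if not items:
            return []
        # Sort by dot-product similarity to an anchor item; for normalized
        # embeddings this approximates cosine similarity, so related items
        # cluster together in the sorted order.
        anchor = items[0]
        items = sorted(items, key=lambda item: dot(anchor['embedding'], item['embedding']))
        # Fixed-size chunks are then deduplicated independently.
        return [items[i : i + CHUNK_SIZE] for i in range(0, len(items), CHUNK_SIZE)]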


@@ -94,27 +94,27 @@ async def dedupe_extracted_edges(
 ) -> list[EntityEdge]:
     # Create edge map
     edge_map = {}
-    for edge in existing_edges:
-        edge_map[edge.fact] = edge
     for edge in extracted_edges:
-        if edge.fact in edge_map:
-            continue
-        edge_map[edge.fact] = edge
+        edge_map[edge.uuid] = edge

     # Prepare context for LLM
     context = {
-        'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
-        'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
+        'extracted_edges': [
+            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
+        ],
+        'existing_edges': [
+            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+        ],
     }

     llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
-    new_edges_data = llm_response.get('new_edges', [])
-    logger.info(f'Extracted new edges: {new_edges_data}')
+    unique_edge_data = llm_response.get('unique_facts', [])
+    logger.info(f'Extracted unique edges: {unique_edge_data}')

     # Get full edge data
     edges = []
-    for edge_data in new_edges_data:
-        edge = edge_map[edge_data['fact']]
+    for unique_edge in unique_edge_data:
+        edge = edge_map[unique_edge['uuid']]
         edges.append(edge)

     return edges

@@ -129,15 +129,15 @@ async def dedupe_edge_list(
     # Create edge map
     edge_map = {}
     for edge in edges:
-        edge_map[edge.fact] = edge
+        edge_map[edge.uuid] = edge

     # Prepare context for LLM
-    context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]}
+    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}

     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_edges.edge_list(context)
     )
-    unique_edges_data = llm_response.get('unique_edges', [])
+    unique_edges_data = llm_response.get('unique_facts', [])
     end = time()
     logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')

@@ -145,7 +145,9 @@ async def dedupe_edge_list(
     # Get full edge data
     unique_edges = []
     for edge_data in unique_edges_data:
-        fact = edge_data['fact']
-        unique_edges.append(edge_map[fact])
+        uuid = edge_data['uuid']
+        edge = edge_map[uuid]
+        edge.fact = edge_data['fact']
+        unique_edges.append(edge)

     return unique_edges
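
Keying edge_map by uuid instead of fact text is what makes the lookup above robust: the model only has to echo an identifier back, whereas an exact-string key breaks as soon as the model paraphrases a fact. A small sketch of the round trip, with a stubbed response standing in for the real LLM call and hypothetical sample edges:

    edges = [
        {'uuid': 'a1', 'fact': 'Alice works at Acme'},
        {'uuid': 'b2', 'fact': 'Alice is employed by Acme'},  # duplicate of a1
    ]
    edge_map = {edge['uuid']: edge for edge in edges}

    # Stubbed model output: one surviving fact, referenced by uuid.
    llm_response = {'unique_facts': [{'uuid': 'a1', 'fact': 'Alice works at Acme'}]}

    unique_edges = []
    for edge_data in llm_response['unique_facts']:
        edge = edge_map[edge_data['uuid']]  # exact-key lookup, no string matching
        edge['fact'] = edge_data['fact']    # keep the wording the model returned
        unique_edges.append(edge)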


@@ -62,10 +62,10 @@ async def main(use_bulk: bool = True):
             episode_type='string',
             reference_time=message.actual_timestamp,
         )
-        for i, message in enumerate(messages[3:7])
+        for i, message in enumerate(messages[3:14])
     ]

     await client.add_episode_bulk(episodes)


-asyncio.run(main())
+asyncio.run(main(True))