improve deduping issue (#28)
* improve deduping issue * fix comment * commit format * default embeddings * update
This commit is contained in:
parent
9cc9883e66
commit
a1e54881a2
4 changed files with 199 additions and 186 deletions
|
|
@ -25,29 +25,27 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
Message(
|
Message(
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
|
Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
|
||||||
|
|
||||||
Existing Edges:
|
Existing Facts:
|
||||||
{json.dumps(context['existing_edges'], indent=2)}
|
{json.dumps(context['existing_edges'], indent=2)}
|
||||||
|
|
||||||
New Edges:
|
New Facts:
|
||||||
{json.dumps(context['extracted_edges'], indent=2)}
|
{json.dumps(context['extracted_edges'], indent=2)}
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
1. start with the list of edges from New Edges
|
If any facts in New Facts is a duplicate of a fact in Existing Facts,
|
||||||
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
|
do not return it in the list of unique facts.
|
||||||
edge in the list
|
|
||||||
3. Respond with the resulting list of edges
|
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Use both the name and fact of edges to determine if they are duplicates,
|
1. The facts do not have to be completely identical to be duplicates,
|
||||||
duplicate edges may have different names
|
they just need to have similar factual content
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
"new_edges": [
|
"unique_facts": [
|
||||||
{{
|
{{
|
||||||
"fact": "one sentence description of the fact"
|
"uuid": "unique identifier of the fact"
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
@ -107,25 +105,25 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
||||||
Message(
|
Message(
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
Given the following context, find all of the duplicates in a list of edges:
|
Given the following context, find all of the duplicates in a list of facts:
|
||||||
|
|
||||||
Edges:
|
Facts:
|
||||||
{json.dumps(context['edges'], indent=2)}
|
{json.dumps(context['edges'], indent=2)}
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
|
If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Use both the name and fact of edges to determine if they are duplicates,
|
1. The facts do not have to be completely identical to be duplicates, they just need to have similar content
|
||||||
edges with the same name may not be duplicates
|
2. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
|
||||||
2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
|
|
||||||
facts should be in the response
|
facts should be in the response
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
"unique_edges": [
|
"unique_facts": [
|
||||||
{{
|
{{
|
||||||
"fact": "fact of a unique edge",
|
"uuid": "unique identifier of the fact",
|
||||||
|
"fact": "fact of a unique edge"
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import typing
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
from numpy import dot
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.edges import Edge, EntityEdge, EpisodicEdge
|
from core.edges import Edge, EntityEdge, EpisodicEdge
|
||||||
|
|
@ -23,7 +24,7 @@ from core.utils.maintenance.node_operations import (
|
||||||
extract_nodes,
|
extract_nodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
CHUNK_SIZE = 10
|
CHUNK_SIZE = 15
|
||||||
|
|
||||||
|
|
||||||
class BulkEpisode(BaseModel):
|
class BulkEpisode(BaseModel):
|
||||||
|
|
@ -137,7 +138,13 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
|
||||||
async def compress_nodes(
|
async def compress_nodes(
|
||||||
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
|
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
if len(nodes) == 0:
|
||||||
|
return nodes, uuid_map
|
||||||
|
|
||||||
|
anchor = nodes[0]
|
||||||
|
nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
|
||||||
|
|
||||||
|
node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
||||||
|
|
||||||
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
|
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
|
||||||
|
|
||||||
|
|
@ -156,7 +163,13 @@ async def compress_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
|
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
|
||||||
edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
|
if len(edges) == 0:
|
||||||
|
return edges
|
||||||
|
|
||||||
|
anchor = edges[0]
|
||||||
|
edges.sort(key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or []))
|
||||||
|
|
||||||
|
edge_chunks = [edges[i: i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
|
||||||
|
|
||||||
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -94,27 +94,27 @@ async def dedupe_extracted_edges(
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# Create edge map
|
# Create edge map
|
||||||
edge_map = {}
|
edge_map = {}
|
||||||
for edge in existing_edges:
|
|
||||||
edge_map[edge.fact] = edge
|
|
||||||
for edge in extracted_edges:
|
for edge in extracted_edges:
|
||||||
if edge.fact in edge_map:
|
edge_map[edge.uuid] = edge
|
||||||
continue
|
|
||||||
edge_map[edge.fact] = edge
|
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {
|
context = {
|
||||||
'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
|
'extracted_edges': [
|
||||||
'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
|
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
|
||||||
|
],
|
||||||
|
'existing_edges': [
|
||||||
|
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
|
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
|
||||||
new_edges_data = llm_response.get('new_edges', [])
|
unique_edge_data = llm_response.get('unique_facts', [])
|
||||||
logger.info(f'Extracted new edges: {new_edges_data}')
|
logger.info(f'Extracted unique edges: {unique_edge_data}')
|
||||||
|
|
||||||
# Get full edge data
|
# Get full edge data
|
||||||
edges = []
|
edges = []
|
||||||
for edge_data in new_edges_data:
|
for unique_edge in unique_edge_data:
|
||||||
edge = edge_map[edge_data['fact']]
|
edge = edge_map[unique_edge['uuid']]
|
||||||
edges.append(edge)
|
edges.append(edge)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
@ -129,15 +129,15 @@ async def dedupe_edge_list(
|
||||||
# Create edge map
|
# Create edge map
|
||||||
edge_map = {}
|
edge_map = {}
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
edge_map[edge.fact] = edge
|
edge_map[edge.uuid] = edge
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]}
|
context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.dedupe_edges.edge_list(context)
|
prompt_library.dedupe_edges.edge_list(context)
|
||||||
)
|
)
|
||||||
unique_edges_data = llm_response.get('unique_edges', [])
|
unique_edges_data = llm_response.get('unique_facts', [])
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
||||||
|
|
@ -145,7 +145,9 @@ async def dedupe_edge_list(
|
||||||
# Get full edge data
|
# Get full edge data
|
||||||
unique_edges = []
|
unique_edges = []
|
||||||
for edge_data in unique_edges_data:
|
for edge_data in unique_edges_data:
|
||||||
fact = edge_data['fact']
|
uuid = edge_data['uuid']
|
||||||
unique_edges.append(edge_map[fact])
|
edge = edge_map[uuid]
|
||||||
|
edge.fact = edge_data['fact']
|
||||||
|
unique_edges.append(edge)
|
||||||
|
|
||||||
return unique_edges
|
return unique_edges
|
||||||
|
|
|
||||||
|
|
@ -62,10 +62,10 @@ async def main(use_bulk: bool = True):
|
||||||
episode_type='string',
|
episode_type='string',
|
||||||
reference_time=message.actual_timestamp,
|
reference_time=message.actual_timestamp,
|
||||||
)
|
)
|
||||||
for i, message in enumerate(messages[3:7])
|
for i, message in enumerate(messages[3:14])
|
||||||
]
|
]
|
||||||
|
|
||||||
await client.add_episode_bulk(episodes)
|
await client.add_episode_bulk(episodes)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main(True))
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue