add bulk temporal extraction and improve bulk quality and performance (#67)
* parallelize edge deduping more * parallelize node insertion more * improve bulk behavior performance * dedupe nodes actually works * add a reranker to search * bulk dedupe episodes only across the same nodes * add temporal extraction bulk function * cleaned up bulk * default to 4o * format * mypy * mypy * mypy ignore
This commit is contained in:
parent
aac06d9d24
commit
35a4e5172b
8 changed files with 203 additions and 61 deletions
|
|
@ -29,6 +29,7 @@ from graphiti_core.llm_client.utils import generate_embedding
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
|
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
get_relevant_edges,
|
get_relevant_edges,
|
||||||
get_relevant_nodes,
|
get_relevant_nodes,
|
||||||
hybrid_node_search,
|
hybrid_node_search,
|
||||||
|
|
@ -41,6 +42,7 @@ from graphiti_core.utils.bulk_utils import (
|
||||||
RawEpisode,
|
RawEpisode,
|
||||||
dedupe_edges_bulk,
|
dedupe_edges_bulk,
|
||||||
dedupe_nodes_bulk,
|
dedupe_nodes_bulk,
|
||||||
|
extract_edge_dates_bulk,
|
||||||
extract_nodes_and_edges_bulk,
|
extract_nodes_and_edges_bulk,
|
||||||
resolve_edge_pointers,
|
resolve_edge_pointers,
|
||||||
retrieve_previous_episodes_bulk,
|
retrieve_previous_episodes_bulk,
|
||||||
|
|
@ -319,26 +321,24 @@ class Graphiti:
|
||||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||||
self.llm_client,
|
self.llm_client,
|
||||||
edge,
|
edge,
|
||||||
episode.valid_at,
|
|
||||||
episode,
|
episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
)
|
)
|
||||||
edge.valid_at = valid_at
|
edge.valid_at = valid_at
|
||||||
edge.invalid_at = invalid_at
|
edge.invalid_at = invalid_at
|
||||||
if edge.invalid_at:
|
if edge.invalid_at:
|
||||||
edge.expired_at = datetime.now()
|
edge.expired_at = now
|
||||||
for edge in existing_edges:
|
for edge in existing_edges:
|
||||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||||
self.llm_client,
|
self.llm_client,
|
||||||
edge,
|
edge,
|
||||||
episode.valid_at,
|
|
||||||
episode,
|
episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
)
|
)
|
||||||
edge.valid_at = valid_at
|
edge.valid_at = valid_at
|
||||||
edge.invalid_at = invalid_at
|
edge.invalid_at = invalid_at
|
||||||
if edge.invalid_at:
|
if edge.invalid_at:
|
||||||
edge.expired_at = datetime.now()
|
edge.expired_at = now
|
||||||
(
|
(
|
||||||
old_edges_with_nodes_pending_invalidation,
|
old_edges_with_nodes_pending_invalidation,
|
||||||
new_edges_with_nodes,
|
new_edges_with_nodes,
|
||||||
|
|
@ -481,15 +481,18 @@ class Graphiti:
|
||||||
*[edge.generate_embedding(embedder) for edge in extracted_edges],
|
*[edge.generate_embedding(embedder) for edge in extracted_edges],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dedupe extracted nodes
|
# Dedupe extracted nodes, compress extracted edges
|
||||||
nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
|
(nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
|
||||||
|
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
||||||
|
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
||||||
|
)
|
||||||
|
|
||||||
# save nodes to KG
|
# save nodes to KG
|
||||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||||
|
|
||||||
# re-map edge pointers so that they don't point to discard dupe nodes
|
# re-map edge pointers so that they don't point to discard dupe nodes
|
||||||
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
||||||
extracted_edges, uuid_map
|
extracted_edges_timestamped, uuid_map
|
||||||
)
|
)
|
||||||
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
|
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
|
||||||
episodic_edges, uuid_map
|
episodic_edges, uuid_map
|
||||||
|
|
@ -579,7 +582,9 @@ class Graphiti:
|
||||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_nodes_by_query(self, query: str, limit: int | None = None) -> list[EntityNode]:
|
async def get_nodes_by_query(
|
||||||
|
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
|
||||||
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve nodes from the graph database based on a text query.
|
Retrieve nodes from the graph database based on a text query.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,9 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
1. start with the list of nodes from New Nodes
|
1. start with the list of nodes from New Nodes
|
||||||
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
|
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
|
||||||
node in the list
|
node in the list
|
||||||
3. Respond with the resulting list of nodes
|
3. when deduplicating nodes, synthesize their summaries into a short new summary that contains the relevant information
|
||||||
|
of the summaries of the new and existing nodes
|
||||||
|
4. Respond with the resulting list of nodes
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Use both the name and summary of nodes to determine if they are duplicates,
|
1. Use both the name and summary of nodes to determine if they are duplicates,
|
||||||
|
|
@ -64,6 +66,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
"new_nodes": [
|
"new_nodes": [
|
||||||
{{
|
{{
|
||||||
"name": "Unique identifier for the node",
|
"name": "Unique identifier for the node",
|
||||||
|
"summary": "Brief summary of the node's role or significance"
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
@ -92,6 +95,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
|
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
|
||||||
Task:
|
Task:
|
||||||
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
|
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
|
||||||
|
When finding duplicates nodes, synthesize their summaries into a short new summary that contains the
|
||||||
|
relevant information of the summaries of the new and existing nodes.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Use both the name and summary of nodes to determine if they are duplicates,
|
1. Use both the name and summary of nodes to determine if they are duplicates,
|
||||||
|
|
@ -104,7 +109,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
"duplicates": [
|
"duplicates": [
|
||||||
{{
|
{{
|
||||||
"name": "name of the new node",
|
"name": "name of the new node",
|
||||||
"duplicate_of": "name of the existing node"
|
"duplicate_of": "name of the existing node",
|
||||||
|
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
@ -130,6 +136,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
Task:
|
Task:
|
||||||
1. Group nodes together such that all duplicate nodes are in the same list of names
|
1. Group nodes together such that all duplicate nodes are in the same list of names
|
||||||
2. All duplicate names should be grouped together in the same list
|
2. All duplicate names should be grouped together in the same list
|
||||||
|
3. Also return a new summary that synthesizes the summary into a new short summary
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Each name from the list of nodes should appear EXACTLY once in your response
|
1. Each name from the list of nodes should appear EXACTLY once in your response
|
||||||
|
|
@ -140,6 +147,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
"nodes": [
|
"nodes": [
|
||||||
{{
|
{{
|
||||||
"names": ["myNode", "node that is a duplicate of myNode"],
|
"names": ["myNode", "node that is a duplicate of myNode"],
|
||||||
|
"summary": "Brief summary of the node summaries that appear in the list of names."
|
||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
|
||||||
|
|
@ -110,10 +110,11 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Create edges only between the provided nodes.
|
1. Create edges only between the provided nodes.
|
||||||
2. Each edge should represent a clear relationship between two nodes.
|
2. Each edge should represent a clear relationship between two DISTINCT nodes.
|
||||||
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
|
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
|
||||||
4. Provide a more detailed fact describing the relationship.
|
4. Provide a more detailed fact describing the relationship.
|
||||||
5. Consider temporal aspects of relationships when relevant.
|
5. Consider temporal aspects of relationships when relevant.
|
||||||
|
6. Avoid using the same node as the source and target of a relationship
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
|
|
|
||||||
|
|
@ -63,12 +63,12 @@ class SearchResults(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_search(
|
async def hybrid_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder,
|
embedder,
|
||||||
query: str,
|
query: str,
|
||||||
timestamp: datetime,
|
timestamp: datetime,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -268,13 +268,13 @@ async def hybrid_node_search(
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
limit: int | None = None,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search for nodes using both text queries and embeddings.
|
Perform a hybrid search for nodes using both text queries and embeddings.
|
||||||
|
|
||||||
This method combines fulltext search and vector similarity search to find
|
This method combines fulltext search and vector similarity search to find
|
||||||
relevant nodes in the graph database.
|
relevant nodes in the graph database. It uses an rrf reranker.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
@ -307,27 +307,25 @@ async def hybrid_node_search(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start = time()
|
start = time()
|
||||||
relevant_nodes: list[EntityNode] = []
|
|
||||||
relevant_node_uuids = set()
|
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results: list[list[EntityNode]] = list(
|
||||||
*[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
|
await asyncio.gather(
|
||||||
*[
|
*[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
|
||||||
entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
|
*[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
|
||||||
for e in embeddings
|
)
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for result in results:
|
node_uuid_map: dict[str, EntityNode] = {
|
||||||
for node in result:
|
node.uuid: node for result in results for node in result
|
||||||
if node.uuid in relevant_node_uuids:
|
}
|
||||||
continue
|
result_uuids = [[node.uuid for node in result] for result in results]
|
||||||
|
|
||||||
relevant_node_uuids.add(node.uuid)
|
ranked_uuids = rrf(result_uuids)
|
||||||
relevant_nodes.append(node)
|
|
||||||
|
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
|
logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
|
||||||
return relevant_nodes
|
return relevant_nodes
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,14 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from math import ceil
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from numpy import dot
|
from numpy import dot, sqrt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||||
|
|
@ -39,8 +42,11 @@ from graphiti_core.utils.maintenance.node_operations import (
|
||||||
dedupe_node_list,
|
dedupe_node_list,
|
||||||
extract_nodes,
|
extract_nodes,
|
||||||
)
|
)
|
||||||
|
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
|
||||||
|
|
||||||
CHUNK_SIZE = 15
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CHUNK_SIZE = 10
|
||||||
|
|
||||||
|
|
||||||
class RawEpisode(BaseModel):
|
class RawEpisode(BaseModel):
|
||||||
|
|
@ -52,7 +58,7 @@ class RawEpisode(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_previous_episodes_bulk(
|
async def retrieve_previous_episodes_bulk(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
||||||
previous_episodes_list = await asyncio.gather(
|
previous_episodes_list = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -68,7 +74,7 @@ async def retrieve_previous_episodes_bulk(
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
||||||
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
||||||
extracted_nodes_bulk = await asyncio.gather(
|
extracted_nodes_bulk = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -105,36 +111,67 @@ async def extract_nodes_and_edges_bulk(
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_nodes_bulk(
|
async def dedupe_nodes_bulk(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
extracted_nodes: list[EntityNode],
|
extracted_nodes: list[EntityNode],
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
# Compress nodes
|
# Compress nodes
|
||||||
nodes, uuid_map = node_name_match(extracted_nodes)
|
nodes, uuid_map = node_name_match(extracted_nodes)
|
||||||
|
|
||||||
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
|
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
|
||||||
|
|
||||||
existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
|
node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
||||||
|
|
||||||
nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
|
existing_nodes_chunks: list[list[EntityNode]] = list(
|
||||||
llm_client, compressed_nodes, existing_nodes
|
await asyncio.gather(
|
||||||
|
*[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
compressed_map.update(partial_uuid_map)
|
results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
|
||||||
|
for i, node_chunk in enumerate(node_chunks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return nodes, compressed_map
|
final_nodes: list[EntityNode] = []
|
||||||
|
for result in results:
|
||||||
|
final_nodes.extend(result[0])
|
||||||
|
partial_uuid_map = result[1]
|
||||||
|
compressed_map.update(partial_uuid_map)
|
||||||
|
|
||||||
|
return final_nodes, compressed_map
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_edges_bulk(
|
async def dedupe_edges_bulk(
|
||||||
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
|
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# Compress edges
|
# First compress edges
|
||||||
compressed_edges = await compress_edges(llm_client, extracted_edges)
|
compressed_edges = await compress_edges(llm_client, extracted_edges)
|
||||||
|
|
||||||
existing_edges = await get_relevant_edges(compressed_edges, driver)
|
edge_chunks = [
|
||||||
|
compressed_edges[i: i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
|
||||||
|
]
|
||||||
|
|
||||||
edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
|
relevant_edges_chunks: list[list[EntityEdge]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_edge_chunks: list[list[EntityEdge]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
|
||||||
|
for i, edge_chunk in enumerate(edge_chunks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -152,15 +189,60 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
|
||||||
|
|
||||||
|
|
||||||
async def compress_nodes(
|
async def compress_nodes(
|
||||||
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
|
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
|
# We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
|
||||||
if len(nodes) == 0:
|
if len(nodes) == 0:
|
||||||
return nodes, uuid_map
|
return nodes, uuid_map
|
||||||
|
|
||||||
anchor = nodes[0]
|
# Our approach involves us deduplicating chunks of nodes in parallel.
|
||||||
nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
|
# We want n chunks of size n so that n ** 2 == len(nodes).
|
||||||
|
# We want chunk sizes to be at least 10 for optimizing LLM processing time
|
||||||
|
chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
|
||||||
|
|
||||||
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
# First calculate similarity scores between nodes
|
||||||
|
similarity_scores: list[tuple[int, int, float]] = [
|
||||||
|
(i, j, dot(n.name_embedding or [], m.name_embedding or []))
|
||||||
|
for i, n in enumerate(nodes)
|
||||||
|
for j, m in enumerate(nodes[:i])
|
||||||
|
]
|
||||||
|
|
||||||
|
# We now sort by semantic similarity
|
||||||
|
similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
|
||||||
|
|
||||||
|
# initialize our chunks based on chunk size
|
||||||
|
node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
|
||||||
|
|
||||||
|
# Draft the most similar nodes into the same chunk
|
||||||
|
while len(similarity_scores) > 0:
|
||||||
|
i, j, _ = similarity_scores.pop()
|
||||||
|
# determine if any of the nodes have already been drafted into a chunk
|
||||||
|
n = nodes[i]
|
||||||
|
m = nodes[j]
|
||||||
|
# make sure the shortest chunks get preference
|
||||||
|
node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
|
||||||
|
|
||||||
|
n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
|
||||||
|
m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
|
||||||
|
|
||||||
|
# both nodes already in a chunk
|
||||||
|
if n_chunk > -1 and m_chunk > -1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# n has a chunk and that chunk is not full
|
||||||
|
elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
|
||||||
|
# put m in the same chunk as n
|
||||||
|
node_chunks[n_chunk].append(m)
|
||||||
|
|
||||||
|
# m has a chunk and that chunk is not full
|
||||||
|
elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
|
||||||
|
# put n in the same chunk as m
|
||||||
|
node_chunks[m_chunk].append(n)
|
||||||
|
|
||||||
|
# neither node has a chunk or the chunk is full
|
||||||
|
else:
|
||||||
|
# add both nodes to the shortest chunk
|
||||||
|
node_chunks[-1].extend([n, m])
|
||||||
|
|
||||||
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
|
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
|
||||||
|
|
||||||
|
|
@ -181,13 +263,21 @@ async def compress_nodes(
|
||||||
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
|
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
|
||||||
if len(edges) == 0:
|
if len(edges) == 0:
|
||||||
return edges
|
return edges
|
||||||
|
# We only want to dedupe edges that are between the same pair of nodes
|
||||||
|
# We build a map of the edges based on their source and target nodes.
|
||||||
|
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
|
||||||
|
for edge in edges:
|
||||||
|
# We drop loop edges
|
||||||
|
if edge.source_node_uuid == edge.target_node_uuid:
|
||||||
|
continue
|
||||||
|
|
||||||
anchor = edges[0]
|
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
|
||||||
edges.sort(
|
pointers = [edge.source_node_uuid, edge.target_node_uuid]
|
||||||
key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
|
pointers.sort()
|
||||||
)
|
|
||||||
|
|
||||||
edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
|
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
|
||||||
|
|
||||||
|
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
|
||||||
|
|
||||||
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
||||||
|
|
||||||
|
|
@ -225,3 +315,43 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
|
||||||
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
|
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_edge_dates_bulk(
|
||||||
|
llm_client: LLMClient,
|
||||||
|
extracted_edges: list[EntityEdge],
|
||||||
|
episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
||||||
|
) -> list[EntityEdge]:
|
||||||
|
edges: list[EntityEdge] = []
|
||||||
|
# confirm that all of our edges have at least one episode
|
||||||
|
for edge in extracted_edges:
|
||||||
|
if edge.episodes is not None and len(edge.episodes) > 0:
|
||||||
|
edges.append(edge)
|
||||||
|
|
||||||
|
episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
|
||||||
|
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
|
||||||
|
}
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
extract_edge_dates(
|
||||||
|
llm_client,
|
||||||
|
edge,
|
||||||
|
episode_uuid_map[edge.episodes[0]][0], # type: ignore
|
||||||
|
episode_uuid_map[edge.episodes[0]][1], # type: ignore
|
||||||
|
)
|
||||||
|
for edge in edges
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
valid_at = result[0]
|
||||||
|
invalid_at = result[1]
|
||||||
|
edge = edges[i]
|
||||||
|
|
||||||
|
edge.valid_at = valid_at
|
||||||
|
edge.invalid_at = invalid_at
|
||||||
|
if edge.invalid_at:
|
||||||
|
edge.expired_at = datetime.now()
|
||||||
|
|
||||||
|
return edges
|
||||||
|
|
|
||||||
|
|
@ -189,6 +189,7 @@ async def dedupe_node_list(
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
for node_data in nodes_data:
|
for node_data in nodes_data:
|
||||||
node = node_map[node_data['names'][0]]
|
node = node_map[node_data['names'][0]]
|
||||||
|
node.summary = node_data['summary']
|
||||||
unique_nodes.append(node)
|
unique_nodes.append(node)
|
||||||
|
|
||||||
for name in node_data['names'][1:]:
|
for name in node_data['names'][1:]:
|
||||||
|
|
|
||||||
|
|
@ -147,7 +147,6 @@ def process_edge_invalidation_llm_response(
|
||||||
async def extract_edge_dates(
|
async def extract_edge_dates(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
edge: EntityEdge,
|
edge: EntityEdge,
|
||||||
reference_time: datetime,
|
|
||||||
current_episode: EpisodicNode,
|
current_episode: EpisodicNode,
|
||||||
previous_episodes: List[EpisodicNode],
|
previous_episodes: List[EpisodicNode],
|
||||||
) -> tuple[datetime | None, datetime | None, str]:
|
) -> tuple[datetime | None, datetime | None, str]:
|
||||||
|
|
@ -156,7 +155,7 @@ async def extract_edge_dates(
|
||||||
'edge_fact': edge.fact,
|
'edge_fact': edge.fact,
|
||||||
'current_episode': current_episode.content,
|
'current_episode': current_episode.content,
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
'reference_timestamp': reference_time.isoformat(),
|
'reference_timestamp': current_episode.valid_at.isoformat(),
|
||||||
}
|
}
|
||||||
llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))
|
llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue