Speed up add episode (#77)

* WIP

* updates

* use uuid for node dedupe

* pret-testing

* parallelized node resolution

* working add_episode

* revert to 4o

* format

* mypy update

* update types
Preston Rasmussen 2024-09-03 13:25:52 -04:00 committed by GitHub
parent db12ac548d
commit e9e6039b1e
12 changed files with 427 additions and 177 deletions
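A minimal sketch of the "parallelized node resolution" idea behind this change, for orientation: instead of one bulk dedupe pass, each extracted node or edge gets its own candidate lookup and resolution call (get_relevant_nodes, resolve_extracted_node, resolve_extracted_edge in the diff below), and those calls are fanned out with asyncio.gather. The helper names in this sketch are illustrative stand-ins, not the library's API.

    import asyncio

    async def resolve_one(item: str) -> str:
        # Stand-in for a per-item call such as resolve_extracted_node or
        # resolve_extracted_edge: one LLM/database round trip per item.
        await asyncio.sleep(0)  # simulate the awaited I/O
        return item.upper()

    async def resolve_all(items: list[str]) -> list[str]:
        # One task per item; asyncio.gather runs them concurrently and
        # preserves input order, so results line up with the inputs.
        return list(await asyncio.gather(*[resolve_one(item) for item in items]))

    print(asyncio.run(resolve_all(['alice', 'bob'])))  # ['ALICE', 'BOB']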

View file

@@ -70,6 +70,7 @@ async def main(use_bulk: bool = True):
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
)
return
episodes: list[RawEpisode] = [
RawEpisode(
@@ -79,10 +80,10 @@ async def main(use_bulk: bool = True):
source_description='Podcast Transcript',
reference_time=message.actual_timestamp,
)
for i, message in enumerate(messages[3:14])
for i, message in enumerate(messages[3:20])
]
await client.add_episode_bulk(episodes)
asyncio.run(main(True))
asyncio.run(main(False))

View file

@@ -48,14 +48,17 @@ from graphiti_core.utils.bulk_utils import (
retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.maintenance.edge_operations import (
dedupe_extracted_edges,
extract_edges,
resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
)
from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
invalidate_edges,
@@ -177,9 +180,9 @@ class Graphiti:
await build_indices_and_constraints(self.driver)
async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -207,14 +210,14 @@ class Graphiti:
return await retrieve_episodes(self.driver, reference_time, last_n)
async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""
Process an episode and update the graph.
@@ -265,7 +268,6 @@ class Graphiti:
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.get_embedder()
now = datetime.now()
@@ -280,6 +282,8 @@
valid_at=reference_time,
)
# Extract entities as nodes
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
@@ -288,57 +292,82 @@
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
# Resolve extracted nodes with nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
)
)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
nodes.extend(touched_nodes)
mentioned_nodes, _ = await resolve_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes_lists
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)
# Extract facts as edges given entity nodes
extracted_edges = await extract_edges(
self.llm_client, episode, touched_nodes, previous_episodes
self.llm_client, episode, mentioned_nodes, previous_episodes
)
# calculate embeddings
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
existing_edges = await get_relevant_edges(extracted_edges, self.driver)
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
# Resolve extracted edges with edges already in the graph
existing_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
[edge],
self.driver,
RELEVANT_SCHEMA_LIMIT,
edge.source_node_uuid,
edge.target_node_uuid,
)
for edge in extracted_edges
]
)
)
logger.info(
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
)
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
deduped_edges = await dedupe_extracted_edges(
self.llm_client,
extracted_edges,
existing_edges,
deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
self.llm_client, extracted_edges, existing_edges_list
)
edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
for edge in deduped_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
# Extract dates for the newly extracted edges
edge_dates = await asyncio.gather(
*[
extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
for edge in deduped_edges
]
)
for i, edge in enumerate(deduped_edges):
valid_at = edge_dates[i][0]
invalid_at = edge_dates[i][1]
for edge in deduped_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = now
for edge in existing_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
if edge.invalid_at is not None:
edge.expired_at = now
entity_edges.extend(deduped_edges)
existing_edges: list[EntityEdge] = [
e for edge_lst in existing_edges_list for e in edge_lst
]
(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
@@ -361,30 +390,18 @@ class Graphiti:
for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
edges_to_save = existing_edges + deduped_edges
entity_edges.extend(edges_to_save)
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
entity_edges.extend(existing_edges)
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
involved_nodes,
episode,
now,
)
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
mentioned_nodes,
episode,
now,
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
logger.info(f'Built episodic edges: {episodic_edges}')
# Future optimization would be using batch operations to save nodes and edges
@@ -395,9 +412,7 @@
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
if success_callback:
await success_callback(episode)
except Exception as e:
@@ -407,8 +422,8 @@
raise e
async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
self,
bulk_episodes: list[RawEpisode],
):
"""
Process multiple episodes in bulk and update the graph.
@@ -572,18 +587,18 @@ class Graphiti:
return edges
async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)
async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.

View file

@@ -23,12 +23,14 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
v3: PromptVersion
edge_list: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
v3: PromptFunction
edge_list: PromptFunction
@@ -41,17 +43,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
Given the following context, deduplicate facts from a list of new facts given a list of existing edges:
Existing Facts:
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Facts:
New Edges:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
If any facts in New Facts is a duplicate of a fact in Existing Facts,
do not return it in the list of unique facts.
If any edge in New Edges is a duplicate of an edge in Existing Edges, add their uuids to the output list.
When finding duplicates edges, synthesize their facts into a short new fact.
Guidelines:
1. identical or near identical facts are duplicates
@@ -60,9 +62,11 @@ def v1(context: dict[str, Any]) -> list[Message]:
Respond with a JSON object in the following format:
{{
"unique_facts": [
"duplicates": [
{{
"uuid": "unique identifier of the fact"
"uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
"duplicate_of": "uuid of the existing node",
"fact": "one sentence description of the fact"
}}
]
}}
@@ -113,6 +117,40 @@ def v2(context: dict[str, Any]) -> list[Message]:
]
def v3(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Edge:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. If the New Edges represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing edge in the response
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
}}
""",
),
]
def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
@@ -151,4 +189,4 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
]
versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'edge_list': edge_list}

View file

@@ -23,13 +23,15 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
v3: PromptVersion
node_list: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
node_list: PromptVersion
v3: PromptFunction
node_list: PromptFunction
def v1(context: dict[str, Any]) -> list[Message]:
@@ -94,22 +96,22 @@ def v2(context: dict[str, Any]) -> list[Message]:
Important:
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
Task:
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their uuids to the output list
When finding duplicates nodes, synthesize their summaries into a short new summary that contains the
relevant information of the summaries of the new and existing nodes.
Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates,
duplicate nodes may have different names
2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
the name of the Existing Node.
2. In the output, uuid should always be the uuid of the New Node that is a duplicate. duplicate_of should be
the uuid of the Existing Node.
Respond with a JSON object in the following format:
{{
"duplicates": [
{{
"name": "name of the new node",
"duplicate_of": "name of the existing node",
"uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
"duplicate_of": "uuid of the existing node",
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
}}
]
@@ -119,6 +121,44 @@ def v2(context: dict[str, Any]) -> list[Message]:
]
def v3(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates nodes from node lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Node represents any of the entities in the list of Existing Nodes.
Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}
New Node:
{json.dumps(context['extracted_nodes'], indent=2)}
Task:
1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing node in the response
3. If is_duplicate is true, return a summary that synthesizes the information in the New Node summary and the
summary of the Existing Node it is a duplicate of.
Guidelines:
1. Use both the name and summary of nodes to determine if the entities are duplicates,
duplicate nodes may have different names
Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null",
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing node"
}}
""",
),
]
def node_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
@@ -134,19 +174,19 @@ def node_list(context: dict[str, Any]) -> list[Message]:
{json.dumps(context['nodes'], indent=2)}
Task:
1. Group nodes together such that all duplicate nodes are in the same list of names
2. All duplicate names should be grouped together in the same list
1. Group nodes together such that all duplicate nodes are in the same list of uuids
2. All duplicate uuids should be grouped together in the same list
3. Also return a new summary that synthesizes the summary into a new short summary
Guidelines:
1. Each name from the list of nodes should appear EXACTLY once in your response
2. If a node has no duplicates, it should appear in the response in a list of only one name
1. Each uuid from the list of nodes should appear EXACTLY once in your response
2. If a node has no duplicates, it should appear in the response in a list of only one uuid
Respond with a JSON object in the following format:
{{
"nodes": [
{{
"names": ["myNode", "node that is a duplicate of myNode"],
"uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
"summary": "Brief summary of the node summaries that appear in the list of names."
}}
]
@@ -156,4 +196,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
]
versions: Versions = {'v1': v1, 'v2': v2, 'node_list': node_list}
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'node_list': node_list}

View file

@@ -55,6 +55,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
1. Focus on entities, concepts, or actors that are central to the current episode.
2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
3. Provide a brief but informative summary for each node.
4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
Respond with a JSON object in the following format:
{{
@@ -90,6 +91,7 @@ Guidelines:
3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions.
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
Respond with a JSON object in the following format:
{{

View file

@@ -83,7 +83,7 @@ async def hybrid_search(
nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges)
text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges)
search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +95,7 @@
)
similarity_search = await edge_similarity_search(
search_vector, driver, 2 * config.num_edges
driver, search_vector, 2 * config.num_edges
)
search_results.append(similarity_search)

View file

@@ -96,14 +96,18 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
driver: AsyncDriver,
search_vector: list[float],
limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]:
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
@@ -119,6 +123,8 @@ async def edge_similarity_search(
ORDER BY score DESC
""",
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
limit=limit,
)
@@ -214,7 +220,11 @@ async def entity_fulltext_search(
async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
driver: AsyncDriver,
query: str,
limit=RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -222,8 +232,8 @@ async def edge_fulltext_search(
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
@@ -239,6 +249,8 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""",
query=fuzzy_query,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
limit=limit,
)
@@ -369,6 +381,9 @@ async def get_relevant_nodes(
async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]:
start = time()
relevant_edges: list[EntityEdge] = []
@@ -376,11 +391,16 @@ async def get_relevant_edges(
results = await asyncio.gather(
*[
edge_similarity_search(edge.fact_embedding, driver)
edge_similarity_search(
driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
)
for edge in edges
if edge.fact_embedding is not None
],
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
*[
edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
for edge in edges
],
)
for result in results:

View file

@@ -17,7 +17,6 @@ limitations under the License.
import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime
from math import ceil
@@ -43,6 +42,7 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
from graphiti_core.utils.utils import chunk_edges_by_nodes
logger = logging.getLogger(__name__)
@@ -128,7 +128,7 @@ async def dedupe_nodes_bulk(
)
)
results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list(
results: list[tuple[list[EntityNode], dict[str, str]]] = list(
await asyncio.gather(
*[
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
@@ -265,19 +265,7 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
return edges
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
edge_chunks = chunk_edges_by_nodes(edges)
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])

View file

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
from datetime import datetime
from time import time
@@ -109,8 +110,8 @@ async def dedupe_extracted_edges(
existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
# Create edge map
edge_map = {}
for edge in extracted_edges:
edge_map: dict[str, EntityEdge] = {}
for edge in existing_edges:
edge_map[edge.uuid] = edge
# Prepare context for LLM
@@ -124,18 +125,85 @@
}
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
unique_edge_data = llm_response.get('unique_facts', [])
logger.info(f'Extracted unique edges: {unique_edge_data}')
duplicate_data = llm_response.get('duplicates', [])
logger.info(f'Extracted unique edges: {duplicate_data}')
duplicate_uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid_value = duplicate['duplicate_of']
duplicate_uuid_map[duplicate['uuid']] = uuid_value
# Get full edge data
edges = []
for unique_edge in unique_edge_data:
edge = edge_map[unique_edge['uuid']]
edges.append(edge)
edges: list[EntityEdge] = []
for edge in extracted_edges:
if edge.uuid in duplicate_uuid_map:
existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
edges.append(existing_edge)
else:
edges.append(edge)
return edges
async def resolve_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
existing_edges_lists: list[list[EntityEdge]],
) -> list[EntityEdge]:
resolved_edges: list[EntityEdge] = list(
await asyncio.gather(
*[
resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
]
)
)
return resolved_edges
async def resolve_extracted_edge(
llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
) -> EntityEdge:
start = time()
# Prepare context for LLM
existing_edges_context = [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
]
extracted_edge_context = {
'uuid': extracted_edge.uuid,
'name': extracted_edge.name,
'fact': extracted_edge.fact,
}
context = {
'existing_edges': existing_edges_context,
'extracted_edges': extracted_edge_context,
}
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v3(context))
is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
edge = extracted_edge
if is_duplicate:
for existing_edge in existing_edges:
if existing_edge.uuid != uuid:
continue
edge = existing_edge
end = time()
logger.info(
f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
)
return edge
async def dedupe_edge_list(
llm_client: LLMClient,
edges: list[EntityEdge],

View file

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
from datetime import datetime
from time import time
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
async def extract_message_nodes(
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
@@ -48,8 +49,8 @@ async def extract_message_nodes(
async def extract_json_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
llm_client: LLMClient,
episode: EpisodicNode,
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
@@ -66,9 +67,9 @@
async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
start = time()
extracted_node_data: list[dict[str, Any]] = []
@@ -95,29 +96,24 @@
async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()
# build existing node map
node_map: dict[str, EntityNode] = {}
for node in existing_nodes:
node_map[node.name] = node
# Temp hack
new_nodes_map: dict[str, EntityNode] = {}
for node in extracted_nodes:
new_nodes_map[node.name] = node
node_map[node.uuid] = node
# Prepare context for LLM
existing_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in existing_nodes
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
]
extracted_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in extracted_nodes
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
]
context = {
@@ -134,42 +130,104 @@ async def dedupe_extracted_nodes(
uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid = new_nodes_map[duplicate['name']].uuid
uuid_value = node_map[duplicate['duplicate_of']].uuid
uuid_map[uuid] = uuid_value
uuid_value = duplicate['duplicate_of']
uuid_map[duplicate['uuid']] = uuid_value
nodes: list[EntityNode] = []
brand_new_nodes: list[EntityNode] = []
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_uuid = uuid_map[node.uuid]
# TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
if existing_node:
nodes.append(existing_node)
existing_node = node_map[existing_uuid]
nodes.append(existing_node)
else:
nodes.append(node)
continue
brand_new_nodes.append(node)
nodes.append(node)
return nodes, uuid_map
return nodes, uuid_map, brand_new_nodes
async def resolve_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes_lists: list[list[EntityNode]],
) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
results: list[tuple[EntityNode, dict[str, str]]] = list(
await asyncio.gather(
*[
resolve_extracted_node(llm_client, extracted_node, existing_nodes)
for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
]
)
)
for result in results:
uuid_map.update(result[1])
resolved_nodes.append(result[0])
return resolved_nodes, uuid_map
async def resolve_extracted_node(
llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
) -> tuple[EntityNode, dict[str, str]]:
start = time()
# Prepare context for LLM
existing_nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
]
extracted_node_context = {
'uuid': extracted_node.uuid,
'name': extracted_node.name,
'summary': extracted_node.summary,
}
context = {
'existing_nodes': existing_nodes_context,
'extracted_nodes': extracted_node_context,
}
llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v3(context))
is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
summary = llm_response.get('summary', '')
node = extracted_node
uuid_map: dict[str, str] = {}
if is_duplicate:
for existing_node in existing_nodes:
if existing_node.uuid != uuid:
continue
node = existing_node
node.summary = summary
uuid_map[extracted_node.uuid] = existing_node.uuid
end = time()
logger.info(
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
)
return node, uuid_map
async def dedupe_node_list(
llm_client: LLMClient,
nodes: list[EntityNode],
llm_client: LLMClient,
nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()
# build node map
node_map = {}
for node in nodes:
node_map[node.name] = node
node_map[node.uuid] = node
# Prepare context for LLM
nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes]
nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
]
context = {
'nodes': nodes_context,
@@ -188,13 +246,12 @@ async def dedupe_node_list(
unique_nodes = []
uuid_map: dict[str, str] = {}
for node_data in nodes_data:
node = node_map[node_data['names'][0]]
node = node_map[node_data['uuids'][0]]
node.summary = node_data['summary']
unique_nodes.append(node)
for name in node_data['names'][1:]:
uuid = node_map[name].uuid
uuid_value = node_map[node_data['names'][0]].uuid
for uuid in node_data['uuids'][1:]:
uuid_value = node_map[node_data['uuids'][0]].uuid
uuid_map[uuid] = uuid_value
return unique_nodes, uuid_map

View file

@@ -149,7 +149,7 @@ async def extract_edge_dates(
edge: EntityEdge,
current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]:
) -> tuple[datetime | None, datetime | None]:
context = {
'edge_name': edge.name,
'edge_fact': edge.fact,
@@ -180,4 +180,4 @@
logger.info(f'Edge date extraction explanation: {explanation}')
return valid_at_datetime, invalid_at_datetime, explanation
return valid_at_datetime, invalid_at_datetime

View file

@@ -15,8 +15,9 @@ limitations under the License.
"""
import logging
from collections import defaultdict
from graphiti_core.edges import EpisodicEdge
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__)
@@ -37,3 +38,23 @@ def build_episodic_edges(
)
return edges
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
return edge_chunks