Speed up add episode (#77)

* WIP

* updates

* use uuid for node dedupe

* pret-testing

* parallelized node resolution

* working add_episode

* revert to 4o

* format

* mypy update

* update types
This commit is contained in:
Preston Rasmussen 2024-09-03 13:25:52 -04:00 committed by GitHub
parent db12ac548d
commit e9e6039b1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 427 additions and 177 deletions

View file

@ -70,6 +70,7 @@ async def main(use_bulk: bool = True):
reference_time=message.actual_timestamp, reference_time=message.actual_timestamp,
source_description='Podcast Transcript', source_description='Podcast Transcript',
) )
return
episodes: list[RawEpisode] = [ episodes: list[RawEpisode] = [
RawEpisode( RawEpisode(
@ -79,10 +80,10 @@ async def main(use_bulk: bool = True):
source_description='Podcast Transcript', source_description='Podcast Transcript',
reference_time=message.actual_timestamp, reference_time=message.actual_timestamp,
) )
for i, message in enumerate(messages[3:14]) for i, message in enumerate(messages[3:20])
] ]
await client.add_episode_bulk(episodes) await client.add_episode_bulk(episodes)
asyncio.run(main(True)) asyncio.run(main(False))

View file

@ -48,14 +48,17 @@ from graphiti_core.utils.bulk_utils import (
retrieve_previous_episodes_bulk, retrieve_previous_episodes_bulk,
) )
from graphiti_core.utils.maintenance.edge_operations import ( from graphiti_core.utils.maintenance.edge_operations import (
dedupe_extracted_edges,
extract_edges, extract_edges,
resolve_extracted_edges,
) )
from graphiti_core.utils.maintenance.graph_data_operations import ( from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN, EPISODE_WINDOW_LEN,
build_indices_and_constraints, build_indices_and_constraints,
) )
from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import ( from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates, extract_edge_dates,
invalidate_edges, invalidate_edges,
@ -177,9 +180,9 @@ class Graphiti:
await build_indices_and_constraints(self.driver) await build_indices_and_constraints(self.driver)
async def retrieve_episodes( async def retrieve_episodes(
self, self,
reference_time: datetime, reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN, last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
""" """
Retrieve the last n episodic nodes from the graph. Retrieve the last n episodic nodes from the graph.
@ -207,14 +210,14 @@ class Graphiti:
return await retrieve_episodes(self.driver, reference_time, last_n) return await retrieve_episodes(self.driver, reference_time, last_n)
async def add_episode( async def add_episode(
self, self,
name: str, name: str,
episode_body: str, episode_body: str,
source_description: str, source_description: str,
reference_time: datetime, reference_time: datetime,
source: EpisodeType = EpisodeType.message, source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None, success_callback: Callable | None = None,
error_callback: Callable | None = None, error_callback: Callable | None = None,
): ):
""" """
Process an episode and update the graph. Process an episode and update the graph.
@ -265,7 +268,6 @@ class Graphiti:
nodes: list[EntityNode] = [] nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = [] entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.get_embedder() embedder = self.llm_client.get_embedder()
now = datetime.now() now = datetime.now()
@ -280,6 +282,8 @@ class Graphiti:
valid_at=reference_time, valid_at=reference_time,
) )
# Extract entities as nodes
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
@ -288,57 +292,82 @@ class Graphiti:
await asyncio.gather( await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes] *[node.generate_name_embedding(embedder) for node in extracted_nodes]
) )
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
# Resolve extracted nodes with nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
)
)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
nodes.extend(touched_nodes)
mentioned_nodes, _ = await resolve_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes_lists
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)
# Extract facts as edges given entity nodes
extracted_edges = await extract_edges( extracted_edges = await extract_edges(
self.llm_client, episode, touched_nodes, previous_episodes self.llm_client, episode, mentioned_nodes, previous_episodes
) )
# calculate embeddings
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges]) await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
existing_edges = await get_relevant_edges(extracted_edges, self.driver) # Resolve extracted edges with edges already in the graph
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}') existing_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
[edge],
self.driver,
RELEVANT_SCHEMA_LIMIT,
edge.source_node_uuid,
edge.target_node_uuid,
)
for edge in extracted_edges
]
)
)
logger.info(
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
)
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}') logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
deduped_edges = await dedupe_extracted_edges( deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
self.llm_client, self.llm_client, extracted_edges, existing_edges_list
extracted_edges,
existing_edges,
) )
edge_touched_node_uuids = [n.uuid for n in brand_new_nodes] # Extract dates for the newly extracted edges
for edge in deduped_edges: edge_dates = await asyncio.gather(
edge_touched_node_uuids.append(edge.source_node_uuid) *[
edge_touched_node_uuids.append(edge.target_node_uuid) extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
for edge in deduped_edges
]
)
for i, edge in enumerate(deduped_edges):
valid_at = edge_dates[i][0]
invalid_at = edge_dates[i][1]
for edge in deduped_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
edge.valid_at = valid_at edge.valid_at = valid_at
edge.invalid_at = invalid_at edge.invalid_at = invalid_at
if edge.invalid_at: if edge.invalid_at is not None:
edge.expired_at = now
for edge in existing_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = now edge.expired_at = now
entity_edges.extend(deduped_edges)
existing_edges: list[EntityEdge] = [
e for edge_lst in existing_edges_list for e in edge_lst
]
( (
old_edges_with_nodes_pending_invalidation, old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes, new_edges_with_nodes,
@ -361,30 +390,18 @@ class Graphiti:
for deduped_edge in deduped_edges: for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid: if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at deduped_edge.expired_at = edge.expired_at
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}') logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
edges_to_save = existing_edges + deduped_edges entity_edges.extend(existing_edges)
entity_edges.extend(edges_to_save)
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}') logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
episodic_edges.extend( episodic_edges: list[EpisodicEdge] = build_episodic_edges(
build_episodic_edges( mentioned_nodes,
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them episode,
involved_nodes, now,
episode,
now,
)
) )
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
logger.info(f'Built episodic edges: {episodic_edges}') logger.info(f'Built episodic edges: {episodic_edges}')
# Future optimization would be using batch operations to save nodes and edges # Future optimization would be using batch operations to save nodes and edges
@ -395,9 +412,7 @@ class Graphiti:
end = time() end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms') logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
if success_callback: if success_callback:
await success_callback(episode) await success_callback(episode)
except Exception as e: except Exception as e:
@ -407,8 +422,8 @@ class Graphiti:
raise e raise e
async def add_episode_bulk( async def add_episode_bulk(
self, self,
bulk_episodes: list[RawEpisode], bulk_episodes: list[RawEpisode],
): ):
""" """
Process multiple episodes in bulk and update the graph. Process multiple episodes in bulk and update the graph.
@ -572,18 +587,18 @@ class Graphiti:
return edges return edges
async def _search( async def _search(
self, self,
query: str, query: str,
timestamp: datetime, timestamp: datetime,
config: SearchConfig, config: SearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
): ):
return await hybrid_search( return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
) )
async def get_nodes_by_query( async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """
Retrieve nodes from the graph database based on a text query. Retrieve nodes from the graph database based on a text query.

View file

@ -23,12 +23,14 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol): class Prompt(Protocol):
v1: PromptVersion v1: PromptVersion
v2: PromptVersion v2: PromptVersion
v3: PromptVersion
edge_list: PromptVersion edge_list: PromptVersion
class Versions(TypedDict): class Versions(TypedDict):
v1: PromptFunction v1: PromptFunction
v2: PromptFunction v2: PromptFunction
v3: PromptFunction
edge_list: PromptFunction edge_list: PromptFunction
@ -41,17 +43,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
Message( Message(
role='user', role='user',
content=f""" content=f"""
Given the following context, deduplicate facts from a list of new facts given a list of existing facts: Given the following context, deduplicate facts from a list of new facts given a list of existing edges:
Existing Facts: Existing Edges:
{json.dumps(context['existing_edges'], indent=2)} {json.dumps(context['existing_edges'], indent=2)}
New Facts: New Edges:
{json.dumps(context['extracted_edges'], indent=2)} {json.dumps(context['extracted_edges'], indent=2)}
Task: Task:
If any facts in New Facts is a duplicate of a fact in Existing Facts, If any edge in New Edges is a duplicate of an edge in Existing Edges, add their uuids to the output list.
do not return it in the list of unique facts. When finding duplicates edges, synthesize their facts into a short new fact.
Guidelines: Guidelines:
1. identical or near identical facts are duplicates 1. identical or near identical facts are duplicates
@ -60,9 +62,11 @@ def v1(context: dict[str, Any]) -> list[Message]:
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{
"unique_facts": [ "duplicates": [
{{ {{
"uuid": "unique identifier of the fact" "uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
"duplicate_of": "uuid of the existing node",
"fact": "one sentence description of the fact"
}} }}
] ]
}} }}
@ -113,6 +117,40 @@ def v2(context: dict[str, Any]) -> list[Message]:
] ]
def v3(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Edge:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. If the New Edges represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing edge in the response
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
}}
""",
),
]
def edge_list(context: dict[str, Any]) -> list[Message]: def edge_list(context: dict[str, Any]) -> list[Message]:
return [ return [
Message( Message(
@ -151,4 +189,4 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
] ]
versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list} versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'edge_list': edge_list}

View file

@ -23,13 +23,15 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol): class Prompt(Protocol):
v1: PromptVersion v1: PromptVersion
v2: PromptVersion v2: PromptVersion
v3: PromptVersion
node_list: PromptVersion node_list: PromptVersion
class Versions(TypedDict): class Versions(TypedDict):
v1: PromptFunction v1: PromptFunction
v2: PromptFunction v2: PromptFunction
node_list: PromptVersion v3: PromptFunction
node_list: PromptFunction
def v1(context: dict[str, Any]) -> list[Message]: def v1(context: dict[str, Any]) -> list[Message]:
@ -94,22 +96,22 @@ def v2(context: dict[str, Any]) -> list[Message]:
Important: Important:
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!! If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
Task: Task:
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list If any node in New Nodes is a duplicate of a node in Existing Nodes, add their uuids to the output list
When finding duplicates nodes, synthesize their summaries into a short new summary that contains the When finding duplicates nodes, synthesize their summaries into a short new summary that contains the
relevant information of the summaries of the new and existing nodes. relevant information of the summaries of the new and existing nodes.
Guidelines: Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates, 1. Use both the name and summary of nodes to determine if they are duplicates,
duplicate nodes may have different names duplicate nodes may have different names
2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be 2. In the output, uuid should always be the uuid of the New Node that is a duplicate. duplicate_of should be
the name of the Existing Node. the uuid of the Existing Node.
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{
"duplicates": [ "duplicates": [
{{ {{
"name": "name of the new node", "uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
"duplicate_of": "name of the existing node", "duplicate_of": "uuid of the existing node",
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes" "summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
}} }}
] ]
@ -119,6 +121,44 @@ def v2(context: dict[str, Any]) -> list[Message]:
] ]
def v3(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates nodes from node lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Node represents any of the entities in the list of Existing Nodes.
Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}
New Node:
{json.dumps(context['extracted_nodes'], indent=2)}
Task:
1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing node in the response
3. If is_duplicate is true, return a summary that synthesizes the information in the New Node summary and the
summary of the Existing Node it is a duplicate of.
Guidelines:
1. Use both the name and summary of nodes to determine if the entities are duplicates,
duplicate nodes may have different names
Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null",
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing node"
}}
""",
),
]
def node_list(context: dict[str, Any]) -> list[Message]: def node_list(context: dict[str, Any]) -> list[Message]:
return [ return [
Message( Message(
@ -134,19 +174,19 @@ def node_list(context: dict[str, Any]) -> list[Message]:
{json.dumps(context['nodes'], indent=2)} {json.dumps(context['nodes'], indent=2)}
Task: Task:
1. Group nodes together such that all duplicate nodes are in the same list of names 1. Group nodes together such that all duplicate nodes are in the same list of uuids
2. All duplicate names should be grouped together in the same list 2. All duplicate uuids should be grouped together in the same list
3. Also return a new summary that synthesizes the summary into a new short summary 3. Also return a new summary that synthesizes the summary into a new short summary
Guidelines: Guidelines:
1. Each name from the list of nodes should appear EXACTLY once in your response 1. Each uuid from the list of nodes should appear EXACTLY once in your response
2. If a node has no duplicates, it should appear in the response in a list of only one name 2. If a node has no duplicates, it should appear in the response in a list of only one uuid
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{
"nodes": [ "nodes": [
{{ {{
"names": ["myNode", "node that is a duplicate of myNode"], "uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
"summary": "Brief summary of the node summaries that appear in the list of names." "summary": "Brief summary of the node summaries that appear in the list of names."
}} }}
] ]
@ -156,4 +196,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
] ]
versions: Versions = {'v1': v1, 'v2': v2, 'node_list': node_list} versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'node_list': node_list}

View file

@ -55,6 +55,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
1. Focus on entities, concepts, or actors that are central to the current episode. 1. Focus on entities, concepts, or actors that are central to the current episode.
2. Avoid creating nodes for relationships or actions (these will be handled as edges later). 2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
3. Provide a brief but informative summary for each node. 3. Provide a brief but informative summary for each node.
4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{
@ -90,6 +91,7 @@ Guidelines:
3. Provide concise but informative summaries for each extracted node. 3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions. 4. Avoid creating nodes for relationships or actions.
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later). 5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{

View file

@ -83,7 +83,7 @@ async def hybrid_search(
nodes.extend(await get_mentioned_nodes(driver, episodes)) nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods: if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges) text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges)
search_results.append(text_search) search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods: if SearchMethod.cosine_similarity in config.search_methods:
@ -95,7 +95,7 @@ async def hybrid_search(
) )
similarity_search = await edge_similarity_search( similarity_search = await edge_similarity_search(
search_vector, driver, 2 * config.num_edges driver, search_vector, 2 * config.num_edges
) )
search_results.append(similarity_search) search_results.append(similarity_search)

View file

@ -96,14 +96,18 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
async def edge_similarity_search( async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT driver: AsyncDriver,
search_vector: list[float],
limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# vector similarity search over embedded facts # vector similarity search over embedded facts
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS r, score YIELD relationship AS rel, score
MATCH (n)-[r:RELATES_TO]->(m) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
@ -119,6 +123,8 @@ async def edge_similarity_search(
ORDER BY score DESC ORDER BY score DESC
""", """,
search_vector=search_vector, search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
limit=limit, limit=limit,
) )
@ -214,7 +220,11 @@ async def entity_fulltext_search(
async def edge_fulltext_search( async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT driver: AsyncDriver,
query: str,
limit=RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# fulltext search over facts # fulltext search over facts
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@ -222,8 +232,8 @@ async def edge_fulltext_search(
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r]->(m:Entity) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
@ -239,6 +249,8 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit ORDER BY score DESC LIMIT $limit
""", """,
query=fuzzy_query, query=fuzzy_query,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
limit=limit, limit=limit,
) )
@ -369,6 +381,9 @@ async def get_relevant_nodes(
async def get_relevant_edges( async def get_relevant_edges(
edges: list[EntityEdge], edges: list[EntityEdge],
driver: AsyncDriver, driver: AsyncDriver,
limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]: ) -> list[EntityEdge]:
start = time() start = time()
relevant_edges: list[EntityEdge] = [] relevant_edges: list[EntityEdge] = []
@ -376,11 +391,16 @@ async def get_relevant_edges(
results = await asyncio.gather( results = await asyncio.gather(
*[ *[
edge_similarity_search(edge.fact_embedding, driver) edge_similarity_search(
driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
)
for edge in edges for edge in edges
if edge.fact_embedding is not None if edge.fact_embedding is not None
], ],
*[edge_fulltext_search(edge.fact, driver) for edge in edges], *[
edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
for edge in edges
],
) )
for result in results: for result in results:

View file

@ -17,7 +17,6 @@ limitations under the License.
import asyncio import asyncio
import logging import logging
import typing import typing
from collections import defaultdict
from datetime import datetime from datetime import datetime
from math import ceil from math import ceil
@ -43,6 +42,7 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes, extract_nodes,
) )
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
from graphiti_core.utils.utils import chunk_edges_by_nodes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -128,7 +128,7 @@ async def dedupe_nodes_bulk(
) )
) )
results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list( results: list[tuple[list[EntityNode], dict[str, str]]] = list(
await asyncio.gather( await asyncio.gather(
*[ *[
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i]) dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
@ -265,19 +265,7 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
return edges return edges
# We only want to dedupe edges that are between the same pair of nodes # We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes. # We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list) edge_chunks = chunk_edges_by_nodes(edges)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]) results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])

View file

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
@ -109,8 +110,8 @@ async def dedupe_extracted_edges(
existing_edges: list[EntityEdge], existing_edges: list[EntityEdge],
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# Create edge map # Create edge map
edge_map = {} edge_map: dict[str, EntityEdge] = {}
for edge in extracted_edges: for edge in existing_edges:
edge_map[edge.uuid] = edge edge_map[edge.uuid] = edge
# Prepare context for LLM # Prepare context for LLM
@ -124,18 +125,85 @@ async def dedupe_extracted_edges(
} }
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context)) llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
unique_edge_data = llm_response.get('unique_facts', []) duplicate_data = llm_response.get('duplicates', [])
logger.info(f'Extracted unique edges: {unique_edge_data}') logger.info(f'Extracted unique edges: {duplicate_data}')
duplicate_uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid_value = duplicate['duplicate_of']
duplicate_uuid_map[duplicate['uuid']] = uuid_value
# Get full edge data # Get full edge data
edges = [] edges: list[EntityEdge] = []
for unique_edge in unique_edge_data: for edge in extracted_edges:
edge = edge_map[unique_edge['uuid']] if edge.uuid in duplicate_uuid_map:
edges.append(edge) existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
edges.append(existing_edge)
else:
edges.append(edge)
return edges return edges
async def resolve_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
existing_edges_lists: list[list[EntityEdge]],
) -> list[EntityEdge]:
resolved_edges: list[EntityEdge] = list(
await asyncio.gather(
*[
resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
]
)
)
return resolved_edges
async def resolve_extracted_edge(
llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
) -> EntityEdge:
start = time()
# Prepare context for LLM
existing_edges_context = [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
]
extracted_edge_context = {
'uuid': extracted_edge.uuid,
'name': extracted_edge.name,
'fact': extracted_edge.fact,
}
context = {
'existing_edges': existing_edges_context,
'extracted_edges': extracted_edge_context,
}
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v3(context))
is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
edge = extracted_edge
if is_duplicate:
for existing_edge in existing_edges:
if existing_edge.uuid != uuid:
continue
edge = existing_edge
end = time()
logger.info(
f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
)
return edge
async def dedupe_edge_list( async def dedupe_edge_list(
llm_client: LLMClient, llm_client: LLMClient,
edges: list[EntityEdge], edges: list[EntityEdge],

View file

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
async def extract_message_nodes( async def extract_message_nodes(
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode] llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# Prepare context for LLM # Prepare context for LLM
context = { context = {
@ -48,8 +49,8 @@ async def extract_message_nodes(
async def extract_json_nodes( async def extract_json_nodes(
llm_client: LLMClient, llm_client: LLMClient,
episode: EpisodicNode, episode: EpisodicNode,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# Prepare context for LLM # Prepare context for LLM
context = { context = {
@ -66,9 +67,9 @@ async def extract_json_nodes(
async def extract_nodes( async def extract_nodes(
llm_client: LLMClient, llm_client: LLMClient,
episode: EpisodicNode, episode: EpisodicNode,
previous_episodes: list[EpisodicNode], previous_episodes: list[EpisodicNode],
) -> list[EntityNode]: ) -> list[EntityNode]:
start = time() start = time()
extracted_node_data: list[dict[str, Any]] = [] extracted_node_data: list[dict[str, Any]] = []
@ -95,29 +96,24 @@ async def extract_nodes(
async def dedupe_extracted_nodes( async def dedupe_extracted_nodes(
llm_client: LLMClient, llm_client: LLMClient,
extracted_nodes: list[EntityNode], extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode], existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]: ) -> tuple[list[EntityNode], dict[str, str]]:
start = time() start = time()
# build existing node map # build existing node map
node_map: dict[str, EntityNode] = {} node_map: dict[str, EntityNode] = {}
for node in existing_nodes: for node in existing_nodes:
node_map[node.name] = node node_map[node.uuid] = node
# Temp hack
new_nodes_map: dict[str, EntityNode] = {}
for node in extracted_nodes:
new_nodes_map[node.name] = node
# Prepare context for LLM # Prepare context for LLM
existing_nodes_context = [ existing_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in existing_nodes {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
] ]
extracted_nodes_context = [ extracted_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in extracted_nodes {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
] ]
context = { context = {
@ -134,42 +130,104 @@ async def dedupe_extracted_nodes(
uuid_map: dict[str, str] = {} uuid_map: dict[str, str] = {}
for duplicate in duplicate_data: for duplicate in duplicate_data:
uuid = new_nodes_map[duplicate['name']].uuid uuid_value = duplicate['duplicate_of']
uuid_value = node_map[duplicate['duplicate_of']].uuid uuid_map[duplicate['uuid']] = uuid_value
uuid_map[uuid] = uuid_value
nodes: list[EntityNode] = [] nodes: list[EntityNode] = []
brand_new_nodes: list[EntityNode] = []
for node in extracted_nodes: for node in extracted_nodes:
if node.uuid in uuid_map: if node.uuid in uuid_map:
existing_uuid = uuid_map[node.uuid] existing_uuid = uuid_map[node.uuid]
# TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, existing_node = node_map[existing_uuid]
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? nodes.append(existing_node)
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) else:
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None) nodes.append(node)
if existing_node:
nodes.append(existing_node)
continue return nodes, uuid_map
brand_new_nodes.append(node)
nodes.append(node)
return nodes, uuid_map, brand_new_nodes
async def resolve_extracted_nodes(
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
    existing_nodes_lists: list[list[EntityNode]],
) -> tuple[list[EntityNode], dict[str, str]]:
    """Resolve every extracted node against its candidate duplicates, concurrently.

    `extracted_nodes` and `existing_nodes_lists` are parallel lists. Returns
    the resolved nodes in input order, plus a map from each duplicate
    extracted-node uuid to the uuid of the existing node it resolved to.
    """
    results: list[tuple[EntityNode, dict[str, str]]] = list(
        await asyncio.gather(
            *(
                resolve_extracted_node(llm_client, node, candidates)
                for node, candidates in zip(extracted_nodes, existing_nodes_lists)
            )
        )
    )

    resolved_nodes: list[EntityNode] = [node for node, _ in results]

    # Merge the per-node duplicate maps into one.
    uuid_map: dict[str, str] = {}
    for _, partial_map in results:
        uuid_map.update(partial_map)

    return resolved_nodes, uuid_map
async def resolve_extracted_node(
    llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
) -> tuple[EntityNode, dict[str, str]]:
    """Decide whether `extracted_node` duplicates one of `existing_nodes`.

    Returns the node to keep — the matched existing node (with its summary
    refreshed from the LLM response) or the extracted node itself — together
    with a uuid map entry {extracted uuid -> existing uuid} when a duplicate
    was found.
    """
    start = time()

    # Context handed to the dedupe prompt: the candidates plus the new node.
    context = {
        'existing_nodes': [
            {'uuid': node.uuid, 'name': node.name, 'summary': node.summary}
            for node in existing_nodes
        ],
        'extracted_nodes': {
            'uuid': extracted_node.uuid,
            'name': extracted_node.name,
            'summary': extracted_node.summary,
        },
    }

    llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v3(context))

    is_duplicate: bool = llm_response.get('is_duplicate', False)
    uuid: str | None = llm_response.get('uuid', None)
    summary = llm_response.get('summary', '')

    node = extracted_node
    uuid_map: dict[str, str] = {}
    if is_duplicate:
        for candidate in existing_nodes:
            if candidate.uuid == uuid:
                # Keep the existing node, refreshing its summary from the LLM.
                node = candidate
                node.summary = summary
                uuid_map[extracted_node.uuid] = candidate.uuid

    end = time()
    logger.info(
        f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
    )

    return node, uuid_map
async def dedupe_node_list( async def dedupe_node_list(
llm_client: LLMClient, llm_client: LLMClient,
nodes: list[EntityNode], nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]: ) -> tuple[list[EntityNode], dict[str, str]]:
start = time() start = time()
# build node map # build node map
node_map = {} node_map = {}
for node in nodes: for node in nodes:
node_map[node.name] = node node_map[node.uuid] = node
# Prepare context for LLM # Prepare context for LLM
nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes] nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
]
context = { context = {
'nodes': nodes_context, 'nodes': nodes_context,
@ -188,13 +246,12 @@ async def dedupe_node_list(
unique_nodes = [] unique_nodes = []
uuid_map: dict[str, str] = {} uuid_map: dict[str, str] = {}
for node_data in nodes_data: for node_data in nodes_data:
node = node_map[node_data['names'][0]] node = node_map[node_data['uuids'][0]]
node.summary = node_data['summary'] node.summary = node_data['summary']
unique_nodes.append(node) unique_nodes.append(node)
for name in node_data['names'][1:]: for uuid in node_data['uuids'][1:]:
uuid = node_map[name].uuid uuid_value = node_map[node_data['uuids'][0]].uuid
uuid_value = node_map[node_data['names'][0]].uuid
uuid_map[uuid] = uuid_value uuid_map[uuid] = uuid_value
return unique_nodes, uuid_map return unique_nodes, uuid_map

View file

@ -149,7 +149,7 @@ async def extract_edge_dates(
edge: EntityEdge, edge: EntityEdge,
current_episode: EpisodicNode, current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode], previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]: ) -> tuple[datetime | None, datetime | None]:
context = { context = {
'edge_name': edge.name, 'edge_name': edge.name,
'edge_fact': edge.fact, 'edge_fact': edge.fact,
@ -180,4 +180,4 @@ async def extract_edge_dates(
logger.info(f'Edge date extraction explanation: {explanation}') logger.info(f'Edge date extraction explanation: {explanation}')
return valid_at_datetime, invalid_at_datetime, explanation return valid_at_datetime, invalid_at_datetime

View file

@ -15,8 +15,9 @@ limitations under the License.
""" """
import logging import logging
from collections import defaultdict
from graphiti_core.edges import EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,3 +38,23 @@ def build_episodic_edges(
) )
return edges return edges
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
    """Group edges by the unordered pair of nodes they connect.

    We only want to dedupe edges that are between the same pair of nodes, so
    edges are bucketed by their endpoints, direction-agnostically. Self-loop
    edges (source == target) are dropped.
    """
    edge_chunk_map: dict[tuple[str, str], list[EntityEdge]] = defaultdict(list)
    for edge in edges:
        # We drop loop edges
        if edge.source_node_uuid == edge.target_node_uuid:
            continue

        # Sort the endpoint uuids so A->B and B->A land in the same bucket.
        # BUGFIX: key on the tuple rather than string concatenation — joining
        # two strings can collide ('ab' + 'c' == 'a' + 'bc'); a tuple cannot.
        pointers = (edge.source_node_uuid, edge.target_node_uuid)
        key = tuple(sorted(pointers))
        edge_chunk_map[key].append(edge)

    return list(edge_chunk_map.values())