diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index cb314427..f100926e 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -70,6 +70,7 @@ async def main(use_bulk: bool = True): reference_time=message.actual_timestamp, source_description='Podcast Transcript', ) + return episodes: list[RawEpisode] = [ RawEpisode( @@ -79,10 +80,10 @@ async def main(use_bulk: bool = True): source_description='Podcast Transcript', reference_time=message.actual_timestamp, ) - for i, message in enumerate(messages[3:14]) + for i, message in enumerate(messages[3:20]) ] await client.add_episode_bulk(episodes) -asyncio.run(main(True)) +asyncio.run(main(False)) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 8c8cccbb..2a23d79c 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -48,14 +48,17 @@ from graphiti_core.utils.bulk_utils import ( retrieve_previous_episodes_bulk, ) from graphiti_core.utils.maintenance.edge_operations import ( - dedupe_extracted_edges, extract_edges, + resolve_extracted_edges, ) from graphiti_core.utils.maintenance.graph_data_operations import ( EPISODE_WINDOW_LEN, build_indices_and_constraints, ) -from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes +from graphiti_core.utils.maintenance.node_operations import ( + extract_nodes, + resolve_extracted_nodes, +) from graphiti_core.utils.maintenance.temporal_operations import ( extract_edge_dates, invalidate_edges, @@ -177,9 +180,9 @@ class Graphiti: await build_indices_and_constraints(self.driver) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -207,14 +210,14 @@ class Graphiti: return await retrieve_episodes(self.driver, reference_time, last_n) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - success_callback: Callable | None = None, - error_callback: Callable | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + success_callback: Callable | None = None, + error_callback: Callable | None = None, ): """ Process an episode and update the graph. 
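For orientation, the reworked add_episode signature above is what the runner example drives; a minimal usage sketch follows (the Graphiti constructor arguments, import paths, and episode text are assumptions for illustration, not part of this patch):

import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType


async def demo() -> None:
    # Assumed local Neo4j credentials; adjust for your environment.
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    await client.add_episode(
        name='Message 3',
        episode_body='Alice: I started a new job at the hospital last week.',
        source_description='Podcast Transcript',
        reference_time=datetime.now(timezone.utc),
        source=EpisodeType.message,
    )


asyncio.run(demo())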
@@ -265,7 +268,6 @@ class Graphiti: nodes: list[EntityNode] = [] entity_edges: list[EntityEdge] = [] - episodic_edges: list[EpisodicEdge] = [] embedder = self.llm_client.get_embedder() now = datetime.now() @@ -280,6 +282,8 @@ class Graphiti: valid_at=reference_time, ) + # Extract entities as nodes + extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') @@ -288,57 +292,82 @@ class Graphiti: await asyncio.gather( *[node.generate_name_embedding(embedder) for node in extracted_nodes] ) - existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver) + + # Resolve extracted nodes with nodes already in the graph + existing_nodes_lists: list[list[EntityNode]] = list( + await asyncio.gather( + *[get_relevant_nodes([node], self.driver) for node in extracted_nodes] + ) + ) + logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') - touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes( - self.llm_client, extracted_nodes, existing_nodes - ) - logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}') - nodes.extend(touched_nodes) + mentioned_nodes, _ = await resolve_extracted_nodes( + self.llm_client, extracted_nodes, existing_nodes_lists + ) + logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}') + nodes.extend(mentioned_nodes) + + # Extract facts as edges given entity nodes extracted_edges = await extract_edges( - self.llm_client, episode, touched_nodes, previous_episodes + self.llm_client, episode, mentioned_nodes, previous_episodes ) + # calculate embeddings await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges]) - existing_edges = await get_relevant_edges(extracted_edges, self.driver) - logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}') + # Resolve extracted edges with edges already in the graph + existing_edges_list: list[list[EntityEdge]] = list( + await asyncio.gather( + *[ + get_relevant_edges( + [edge], + self.driver, + RELEVANT_SCHEMA_LIMIT, + edge.source_node_uuid, + edge.target_node_uuid, + ) + for edge in extracted_edges + ] + ) + ) + logger.info( + f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}' + ) logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}') - deduped_edges = await dedupe_extracted_edges( - self.llm_client, - extracted_edges, - existing_edges, + deduped_edges: list[EntityEdge] = await resolve_extracted_edges( + self.llm_client, extracted_edges, existing_edges_list ) - edge_touched_node_uuids = [n.uuid for n in brand_new_nodes] - for edge in deduped_edges: - edge_touched_node_uuids.append(edge.source_node_uuid) - edge_touched_node_uuids.append(edge.target_node_uuid) + # Extract dates for the newly extracted edges + edge_dates = await asyncio.gather( + *[ + extract_edge_dates( + self.llm_client, + edge, + episode, + previous_episodes, + ) + for edge in deduped_edges + ] + ) + + for i, edge in enumerate(deduped_edges): + valid_at = edge_dates[i][0] + invalid_at = edge_dates[i][1] - for edge in deduped_edges: - valid_at, invalid_at, _ = await extract_edge_dates( - self.llm_client, - edge, - episode, - previous_episodes, - ) edge.valid_at = valid_at edge.invalid_at = invalid_at - if edge.invalid_at: - edge.expired_at = now - for edge in existing_edges: - valid_at, invalid_at, _ = await extract_edge_dates( - self.llm_client, - edge, - 
episode, - previous_episodes, - ) - edge.valid_at = valid_at - edge.invalid_at = invalid_at - if edge.invalid_at: + if edge.invalid_at is not None: edge.expired_at = now + + entity_edges.extend(deduped_edges) + + existing_edges: list[EntityEdge] = [ + e for edge_lst in existing_edges_list for e in edge_lst + ] + ( old_edges_with_nodes_pending_invalidation, new_edges_with_nodes, @@ -361,30 +390,18 @@ class Graphiti: for deduped_edge in deduped_edges: if deduped_edge.uuid == edge.uuid: deduped_edge.expired_at = edge.expired_at - edge_touched_node_uuids.append(edge.source_node_uuid) - edge_touched_node_uuids.append(edge.target_node_uuid) logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}') - edges_to_save = existing_edges + deduped_edges - - entity_edges.extend(edges_to_save) - - edge_touched_node_uuids = list(set(edge_touched_node_uuids)) - involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids] - - logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}') + entity_edges.extend(existing_edges) logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}') - episodic_edges.extend( - build_episodic_edges( - # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them - involved_nodes, - episode, - now, - ) + episodic_edges: list[EpisodicEdge] = build_episodic_edges( + mentioned_nodes, + episode, + now, ) - # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built + logger.info(f'Built episodic edges: {episodic_edges}') # Future optimization would be using batch operations to save nodes and edges @@ -395,9 +412,7 @@ class Graphiti: end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') - # for node in nodes: - # if isinstance(node, EntityNode): - # await node.update_summary(self.driver) + if success_callback: await success_callback(episode) except Exception as e: @@ -407,8 +422,8 @@ class Graphiti: raise e async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], + self, + bulk_episodes: list[RawEpisode], ): """ Process multiple episodes in bulk and update the graph. @@ -572,18 +587,18 @@ class Graphiti: return edges async def _search( - self, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ): return await hybrid_search( self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid ) async def get_nodes_by_query( - self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT + self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: """ Retrieve nodes from the graph database based on a text query. 
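The add_episode rework above swaps the single bulk dedupe pass for per-item resolution: one get_relevant_nodes / get_relevant_edges lookup per extracted item, gathered concurrently, then each item resolved against only its own candidate list. A self-contained sketch of that fan-out/zip pattern, using hypothetical fetch_candidates and resolve_one stand-ins rather than the real driver- and LLM-backed functions:

import asyncio


async def fetch_candidates(name: str) -> list[str]:
    # Stand-in for get_relevant_nodes([node], driver): candidates relevant to one extracted item.
    existing = ['Alice Smith', 'Bob Jones', 'Acme Hospital']
    return [e for e in existing if e[0].lower() == name[0].lower()]


async def resolve_one(name: str, candidates: list[str]) -> str:
    # Stand-in for resolve_extracted_node: reuse an existing match, otherwise keep the new item.
    return candidates[0] if candidates else name


async def resolve_all(extracted: list[str]) -> list[str]:
    # One candidate lookup per extracted item, run concurrently and kept in order.
    candidate_lists = list(await asyncio.gather(*[fetch_candidates(n) for n in extracted]))
    # Resolve each item against its own candidate list, also concurrently.
    return list(
        await asyncio.gather(
            *[resolve_one(n, c) for n, c in zip(extracted, candidate_lists)]
        )
    )


print(asyncio.run(resolve_all(['alice', 'Dana'])))  # ['Alice Smith', 'Dana']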
diff --git a/graphiti_core/prompts/dedupe_edges.py b/graphiti_core/prompts/dedupe_edges.py
index 7fd774d8..9902bf75 100644
--- a/graphiti_core/prompts/dedupe_edges.py
+++ b/graphiti_core/prompts/dedupe_edges.py
@@ -23,12 +23,14 @@ from .models import Message, PromptFunction, PromptVersion
 class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
+    v3: PromptVersion
     edge_list: PromptVersion


 class Versions(TypedDict):
     v1: PromptFunction
     v2: PromptFunction
+    v3: PromptFunction
     edge_list: PromptFunction


@@ -41,17 +43,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
-        Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
+        Given the following context, deduplicate facts from a list of new facts given a list of existing edges:

-        Existing Facts:
+        Existing Edges:
         {json.dumps(context['existing_edges'], indent=2)}

-        New Facts:
+        New Edges:
         {json.dumps(context['extracted_edges'], indent=2)}

         Task:
-        If any facts in New Facts is a duplicate of a fact in Existing Facts,
-        do not return it in the list of unique facts.
+        If any edge in New Edges is a duplicate of an edge in Existing Edges, add their uuids to the output list.
+        When finding duplicate edges, synthesize their facts into a short new fact.

         Guidelines:
         1. identical or near identical facts are duplicates
@@ -60,9 +62,11 @@ def v1(context: dict[str, Any]) -> list[Message]:

         Respond with a JSON object in the following format:
         {{
-            "unique_facts": [
+            "duplicates": [
                 {{
-                    "uuid": "unique identifier of the fact"
+                    "uuid": "uuid of the new edge like 5d643020624c42fa9de13f97b1b3fa39",
+                    "duplicate_of": "uuid of the existing edge",
+                    "fact": "one sentence description of the fact"
                 }}
             ]
         }}
@@ -113,6 +117,40 @@ def v2(context: dict[str, Any]) -> list[Message]:
     ]
+
+def v3(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that de-duplicates edges from edge lists.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
+
+        Existing Edges:
+        {json.dumps(context['existing_edges'], indent=2)}
+
+        New Edge:
+        {json.dumps(context['extracted_edges'], indent=2)}
+        Task:
+        1. If the New Edge represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
+        response. Otherwise, return 'is_duplicate: false'
+        2. If is_duplicate is true, also return the uuid of the existing edge in the response
+
+        Guidelines:
+        1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
+ + Respond with a JSON object in the following format: + {{ + "is_duplicate": true or false, + "uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null, + }} + """, + ), + ] + + def edge_list(context: dict[str, Any]) -> list[Message]: return [ Message( @@ -151,4 +189,4 @@ def edge_list(context: dict[str, Any]) -> list[Message]: ] -versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list} +versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'edge_list': edge_list} diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py index 11e21e5f..1e7a4066 100644 --- a/graphiti_core/prompts/dedupe_nodes.py +++ b/graphiti_core/prompts/dedupe_nodes.py @@ -23,13 +23,15 @@ from .models import Message, PromptFunction, PromptVersion class Prompt(Protocol): v1: PromptVersion v2: PromptVersion + v3: PromptVersion node_list: PromptVersion class Versions(TypedDict): v1: PromptFunction v2: PromptFunction - node_list: PromptVersion + v3: PromptFunction + node_list: PromptFunction def v1(context: dict[str, Any]) -> list[Message]: @@ -94,22 +96,22 @@ def v2(context: dict[str, Any]) -> list[Message]: Important: If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!! Task: - If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list + If any node in New Nodes is a duplicate of a node in Existing Nodes, add their uuids to the output list When finding duplicates nodes, synthesize their summaries into a short new summary that contains the relevant information of the summaries of the new and existing nodes. Guidelines: 1. Use both the name and summary of nodes to determine if they are duplicates, duplicate nodes may have different names - 2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be - the name of the Existing Node. + 2. In the output, uuid should always be the uuid of the New Node that is a duplicate. duplicate_of should be + the uuid of the Existing Node. Respond with a JSON object in the following format: {{ "duplicates": [ {{ - "name": "name of the new node", - "duplicate_of": "name of the existing node", + "uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39", + "duplicate_of": "uuid of the existing node", "summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes" }} ] @@ -119,6 +121,44 @@ def v2(context: dict[str, Any]) -> list[Message]: ] +def v3(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are a helpful assistant that de-duplicates nodes from node lists.', + ), + Message( + role='user', + content=f""" + Given the following context, determine whether the New Node represents any of the entities in the list of Existing Nodes. + + Existing Nodes: + {json.dumps(context['existing_nodes'], indent=2)} + + New Node: + {json.dumps(context['extracted_nodes'], indent=2)} + Task: + 1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the + response. Otherwise, return 'is_duplicate: false' + 2. If is_duplicate is true, also return the uuid of the existing node in the response + 3. If is_duplicate is true, return a summary that synthesizes the information in the New Node summary and the + summary of the Existing Node it is a duplicate of. + + Guidelines: + 1. 
Use both the name and summary of nodes to determine if the entities are duplicates, + duplicate nodes may have different names + + Respond with a JSON object in the following format: + {{ + "is_duplicate": true or false, + "uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null", + "summary": "Brief summary of the node's role or significance. Takes information from the new and existing node" + }} + """, + ), + ] + + def node_list(context: dict[str, Any]) -> list[Message]: return [ Message( @@ -134,19 +174,19 @@ def node_list(context: dict[str, Any]) -> list[Message]: {json.dumps(context['nodes'], indent=2)} Task: - 1. Group nodes together such that all duplicate nodes are in the same list of names - 2. All duplicate names should be grouped together in the same list + 1. Group nodes together such that all duplicate nodes are in the same list of uuids + 2. All duplicate uuids should be grouped together in the same list 3. Also return a new summary that synthesizes the summary into a new short summary Guidelines: - 1. Each name from the list of nodes should appear EXACTLY once in your response - 2. If a node has no duplicates, it should appear in the response in a list of only one name + 1. Each uuid from the list of nodes should appear EXACTLY once in your response + 2. If a node has no duplicates, it should appear in the response in a list of only one uuid Respond with a JSON object in the following format: {{ "nodes": [ {{ - "names": ["myNode", "node that is a duplicate of myNode"], + "uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"], "summary": "Brief summary of the node summaries that appear in the list of names." }} ] @@ -156,4 +196,4 @@ def node_list(context: dict[str, Any]) -> list[Message]: ] -versions: Versions = {'v1': v1, 'v2': v2, 'node_list': node_list} +versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'node_list': node_list} diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index 1063d5fc..abd47d6b 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -55,6 +55,7 @@ def v1(context: dict[str, Any]) -> list[Message]: 1. Focus on entities, concepts, or actors that are central to the current episode. 2. Avoid creating nodes for relationships or actions (these will be handled as edges later). 3. Provide a brief but informative summary for each node. + 4. Be as explicit as possible in your node names, using full names and avoiding abbreviations. Respond with a JSON object in the following format: {{ @@ -90,6 +91,7 @@ Guidelines: 3. Provide concise but informative summaries for each extracted node. 4. Avoid creating nodes for relationships or actions. 5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later). +6. Be as explicit as possible in your node names, using full names and avoiding abbreviations. 
Respond with a JSON object in the following format: {{ diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index d1e6ed0e..f27249c4 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -83,7 +83,7 @@ async def hybrid_search( nodes.extend(await get_mentioned_nodes(driver, episodes)) if SearchMethod.bm25 in config.search_methods: - text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges) + text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges) search_results.append(text_search) if SearchMethod.cosine_similarity in config.search_methods: @@ -95,7 +95,7 @@ async def hybrid_search( ) similarity_search = await edge_similarity_search( - search_vector, driver, 2 * config.num_edges + driver, search_vector, 2 * config.num_edges ) search_results.append(similarity_search) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index ec15966f..0cfdb601 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -96,14 +96,18 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): async def edge_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + driver: AsyncDriver, + search_vector: list[float], + limit: int = RELEVANT_SCHEMA_LIMIT, + source_node_uuid: str = '*', + target_node_uuid: str = '*', ) -> list[EntityEdge]: # vector similarity search over embedded facts records, _, _ = await driver.execute_query( """ CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS r, score - MATCH (n)-[r:RELATES_TO]->(m) + YIELD relationship AS rel, score + MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) RETURN r.uuid AS uuid, n.uuid AS source_node_uuid, @@ -119,6 +123,8 @@ async def edge_similarity_search( ORDER BY score DESC """, search_vector=search_vector, + source_uuid=source_node_uuid, + target_uuid=target_node_uuid, limit=limit, ) @@ -214,7 +220,11 @@ async def entity_fulltext_search( async def edge_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + driver: AsyncDriver, + query: str, + limit=RELEVANT_SCHEMA_LIMIT, + source_node_uuid: str = '*', + target_node_uuid: str = '*', ) -> list[EntityEdge]: # fulltext search over facts fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -222,8 +232,8 @@ async def edge_fulltext_search( records, _, _ = await driver.execute_query( """ CALL db.index.fulltext.queryRelationships("name_and_fact", $query) - YIELD relationship AS r, score - MATCH (n:Entity)-[r]->(m:Entity) + YIELD relationship AS rel, score + MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) RETURN r.uuid AS uuid, n.uuid AS source_node_uuid, @@ -239,6 +249,8 @@ async def edge_fulltext_search( ORDER BY score DESC LIMIT $limit """, query=fuzzy_query, + source_uuid=source_node_uuid, + target_uuid=target_node_uuid, limit=limit, ) @@ -369,6 +381,9 @@ async def get_relevant_nodes( async def get_relevant_edges( edges: list[EntityEdge], driver: AsyncDriver, + limit: int = RELEVANT_SCHEMA_LIMIT, + source_node_uuid: str = '*', + target_node_uuid: str = '*', ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -376,11 +391,16 @@ async def get_relevant_edges( results = await asyncio.gather( *[ - edge_similarity_search(edge.fact_embedding, driver) + edge_similarity_search( + driver, edge.fact_embedding, limit, 
source_node_uuid, target_node_uuid + ) for edge in edges if edge.fact_embedding is not None ], - *[edge_fulltext_search(edge.fact, driver) for edge in edges], + *[ + edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid) + for edge in edges + ], ) for result in results: diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 1a2496d7..50c702aa 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -17,7 +17,6 @@ limitations under the License. import asyncio import logging import typing -from collections import defaultdict from datetime import datetime from math import ceil @@ -43,6 +42,7 @@ from graphiti_core.utils.maintenance.node_operations import ( extract_nodes, ) from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates +from graphiti_core.utils.utils import chunk_edges_by_nodes logger = logging.getLogger(__name__) @@ -128,7 +128,7 @@ async def dedupe_nodes_bulk( ) ) - results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list( + results: list[tuple[list[EntityNode], dict[str, str]]] = list( await asyncio.gather( *[ dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i]) @@ -265,19 +265,7 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list return edges # We only want to dedupe edges that are between the same pair of nodes # We build a map of the edges based on their source and target nodes. - edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list) - for edge in edges: - # We drop loop edges - if edge.source_node_uuid == edge.target_node_uuid: - continue - - # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution - pointers = [edge.source_node_uuid, edge.target_node_uuid] - pointers.sort() - - edge_chunk_map[pointers[0] + pointers[1]].append(edge) - - edge_chunks = [chunk for chunk in edge_chunk_map.values()] + edge_chunks = chunk_edges_by_nodes(edges) results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 848f6250..00cb852f 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +import asyncio import logging from datetime import datetime from time import time @@ -109,8 +110,8 @@ async def dedupe_extracted_edges( existing_edges: list[EntityEdge], ) -> list[EntityEdge]: # Create edge map - edge_map = {} - for edge in extracted_edges: + edge_map: dict[str, EntityEdge] = {} + for edge in existing_edges: edge_map[edge.uuid] = edge # Prepare context for LLM @@ -124,18 +125,85 @@ async def dedupe_extracted_edges( } llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context)) - unique_edge_data = llm_response.get('unique_facts', []) - logger.info(f'Extracted unique edges: {unique_edge_data}') + duplicate_data = llm_response.get('duplicates', []) + logger.info(f'Extracted unique edges: {duplicate_data}') + + duplicate_uuid_map: dict[str, str] = {} + for duplicate in duplicate_data: + uuid_value = duplicate['duplicate_of'] + duplicate_uuid_map[duplicate['uuid']] = uuid_value # Get full edge data - edges = [] - for unique_edge in unique_edge_data: - edge = edge_map[unique_edge['uuid']] - edges.append(edge) + edges: list[EntityEdge] = [] + for edge in extracted_edges: + if edge.uuid in duplicate_uuid_map: + existing_uuid = duplicate_uuid_map[edge.uuid] + existing_edge = edge_map[existing_uuid] + edges.append(existing_edge) + else: + edges.append(edge) return edges +async def resolve_extracted_edges( + llm_client: LLMClient, + extracted_edges: list[EntityEdge], + existing_edges_lists: list[list[EntityEdge]], +) -> list[EntityEdge]: + resolved_edges: list[EntityEdge] = list( + await asyncio.gather( + *[ + resolve_extracted_edge(llm_client, extracted_edge, existing_edges) + for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists) + ] + ) + ) + + return resolved_edges + + +async def resolve_extracted_edge( + llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge] +) -> EntityEdge: + start = time() + + # Prepare context for LLM + existing_edges_context = [ + {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges + ] + + extracted_edge_context = { + 'uuid': extracted_edge.uuid, + 'name': extracted_edge.name, + 'fact': extracted_edge.fact, + } + + context = { + 'existing_edges': existing_edges_context, + 'extracted_edges': extracted_edge_context, + } + + llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v3(context)) + + is_duplicate: bool = llm_response.get('is_duplicate', False) + uuid: str | None = llm_response.get('uuid', None) + + edge = extracted_edge + if is_duplicate: + for existing_edge in existing_edges: + if existing_edge.uuid != uuid: + continue + edge = existing_edge + + end = time() + logger.info( + f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms' + ) + + return edge + + async def dedupe_edge_list( llm_client: LLMClient, edges: list[EntityEdge], diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index a1291237..d1e8f43b 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +import asyncio import logging from datetime import datetime from time import time @@ -27,7 +28,7 @@ logger = logging.getLogger(__name__) async def extract_message_nodes( - llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode] + llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode] ) -> list[dict[str, Any]]: # Prepare context for LLM context = { @@ -48,8 +49,8 @@ async def extract_message_nodes( async def extract_json_nodes( - llm_client: LLMClient, - episode: EpisodicNode, + llm_client: LLMClient, + episode: EpisodicNode, ) -> list[dict[str, Any]]: # Prepare context for LLM context = { @@ -66,9 +67,9 @@ async def extract_json_nodes( async def extract_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> list[EntityNode]: start = time() extracted_node_data: list[dict[str, Any]] = [] @@ -95,29 +96,24 @@ async def extract_nodes( async def dedupe_extracted_nodes( - llm_client: LLMClient, - extracted_nodes: list[EntityNode], - existing_nodes: list[EntityNode], -) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]: + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + existing_nodes: list[EntityNode], +) -> tuple[list[EntityNode], dict[str, str]]: start = time() # build existing node map node_map: dict[str, EntityNode] = {} for node in existing_nodes: - node_map[node.name] = node - - # Temp hack - new_nodes_map: dict[str, EntityNode] = {} - for node in extracted_nodes: - new_nodes_map[node.name] = node + node_map[node.uuid] = node # Prepare context for LLM existing_nodes_context = [ - {'name': node.name, 'summary': node.summary} for node in existing_nodes + {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes ] extracted_nodes_context = [ - {'name': node.name, 'summary': node.summary} for node in extracted_nodes + {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes ] context = { @@ -134,42 +130,104 @@ async def dedupe_extracted_nodes( uuid_map: dict[str, str] = {} for duplicate in duplicate_data: - uuid = new_nodes_map[duplicate['name']].uuid - uuid_value = node_map[duplicate['duplicate_of']].uuid - uuid_map[uuid] = uuid_value + uuid_value = duplicate['duplicate_of'] + uuid_map[duplicate['uuid']] = uuid_value nodes: list[EntityNode] = [] - brand_new_nodes: list[EntityNode] = [] for node in extracted_nodes: if node.uuid in uuid_map: existing_uuid = uuid_map[node.uuid] - # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, - # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? 
- # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) - existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None) - if existing_node: - nodes.append(existing_node) + existing_node = node_map[existing_uuid] + nodes.append(existing_node) + else: + nodes.append(node) - continue - brand_new_nodes.append(node) - nodes.append(node) + return nodes, uuid_map - return nodes, uuid_map, brand_new_nodes + +async def resolve_extracted_nodes( + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + existing_nodes_lists: list[list[EntityNode]], +) -> tuple[list[EntityNode], dict[str, str]]: + uuid_map: dict[str, str] = {} + resolved_nodes: list[EntityNode] = [] + results: list[tuple[EntityNode, dict[str, str]]] = list( + await asyncio.gather( + *[ + resolve_extracted_node(llm_client, extracted_node, existing_nodes) + for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists) + ] + ) + ) + + for result in results: + uuid_map.update(result[1]) + resolved_nodes.append(result[0]) + + return resolved_nodes, uuid_map + + +async def resolve_extracted_node( + llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode] +) -> tuple[EntityNode, dict[str, str]]: + start = time() + + # Prepare context for LLM + existing_nodes_context = [ + {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes + ] + + extracted_node_context = { + 'uuid': extracted_node.uuid, + 'name': extracted_node.name, + 'summary': extracted_node.summary, + } + + context = { + 'existing_nodes': existing_nodes_context, + 'extracted_nodes': extracted_node_context, + } + + llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v3(context)) + + is_duplicate: bool = llm_response.get('is_duplicate', False) + uuid: str | None = llm_response.get('uuid', None) + summary = llm_response.get('summary', '') + + node = extracted_node + uuid_map: dict[str, str] = {} + if is_duplicate: + for existing_node in existing_nodes: + if existing_node.uuid != uuid: + continue + node = existing_node + node.summary = summary + uuid_map[extracted_node.uuid] = existing_node.uuid + + end = time() + logger.info( + f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms' + ) + + return node, uuid_map async def dedupe_node_list( - llm_client: LLMClient, - nodes: list[EntityNode], + llm_client: LLMClient, + nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: start = time() # build node map node_map = {} for node in nodes: - node_map[node.name] = node + node_map[node.uuid] = node # Prepare context for LLM - nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes] + nodes_context = [ + {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes + ] context = { 'nodes': nodes_context, @@ -188,13 +246,12 @@ async def dedupe_node_list( unique_nodes = [] uuid_map: dict[str, str] = {} for node_data in nodes_data: - node = node_map[node_data['names'][0]] + node = node_map[node_data['uuids'][0]] node.summary = node_data['summary'] unique_nodes.append(node) - for name in node_data['names'][1:]: - uuid = node_map[name].uuid - uuid_value = node_map[node_data['names'][0]].uuid + for uuid in node_data['uuids'][1:]: + uuid_value = node_map[node_data['uuids'][0]].uuid uuid_map[uuid] = uuid_value return unique_nodes, uuid_map diff --git a/graphiti_core/utils/maintenance/temporal_operations.py 
b/graphiti_core/utils/maintenance/temporal_operations.py index 8ac2a68b..ef168209 100644 --- a/graphiti_core/utils/maintenance/temporal_operations.py +++ b/graphiti_core/utils/maintenance/temporal_operations.py @@ -149,7 +149,7 @@ async def extract_edge_dates( edge: EntityEdge, current_episode: EpisodicNode, previous_episodes: List[EpisodicNode], -) -> tuple[datetime | None, datetime | None, str]: +) -> tuple[datetime | None, datetime | None]: context = { 'edge_name': edge.name, 'edge_fact': edge.fact, @@ -180,4 +180,4 @@ async def extract_edge_dates( logger.info(f'Edge date extraction explanation: {explanation}') - return valid_at_datetime, invalid_at_datetime, explanation + return valid_at_datetime, invalid_at_datetime diff --git a/graphiti_core/utils/utils.py b/graphiti_core/utils/utils.py index 4dbdf84f..97821279 100644 --- a/graphiti_core/utils/utils.py +++ b/graphiti_core/utils/utils.py @@ -15,8 +15,9 @@ limitations under the License. """ import logging +from collections import defaultdict -from graphiti_core.edges import EpisodicEdge +from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.nodes import EntityNode, EpisodicNode logger = logging.getLogger(__name__) @@ -37,3 +38,23 @@ def build_episodic_edges( ) return edges + + +def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]: + # We only want to dedupe edges that are between the same pair of nodes + # We build a map of the edges based on their source and target nodes. + edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list) + for edge in edges: + # We drop loop edges + if edge.source_node_uuid == edge.target_node_uuid: + continue + + # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution + pointers = [edge.source_node_uuid, edge.target_node_uuid] + pointers.sort() + + edge_chunk_map[pointers[0] + pointers[1]].append(edge) + + edge_chunks = [chunk for chunk in edge_chunk_map.values()] + + return edge_chunks
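To make the grouping behaviour of chunk_edges_by_nodes above concrete: edges between the same pair of nodes land in one chunk regardless of direction, and self-loop edges are dropped. A small sketch with a stand-in dataclass (the real EntityEdge has more required fields):

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class StubEdge:
    source_node_uuid: str
    target_node_uuid: str
    fact: str


def chunk_stub_edges(edges: list[StubEdge]) -> list[list[StubEdge]]:
    # Mirrors chunk_edges_by_nodes: direction-agnostic key, loop edges skipped.
    chunks: dict[str, list[StubEdge]] = defaultdict(list)
    for edge in edges:
        if edge.source_node_uuid == edge.target_node_uuid:
            continue
        a, b = sorted([edge.source_node_uuid, edge.target_node_uuid])
        chunks[a + b].append(edge)
    return list(chunks.values())


edges = [
    StubEdge('a', 'b', 'a knows b'),
    StubEdge('b', 'a', 'b knows a'),  # same node pair, reversed direction -> same chunk
    StubEdge('a', 'a', 'self loop'),  # dropped
    StubEdge('a', 'c', 'a visited c'),
]
print([[e.fact for e in chunk] for chunk in chunk_stub_edges(edges)])
# [['a knows b', 'b knows a'], ['a visited c']]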