diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 5d2ad20c..8c8cccbb 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -29,6 +29,7 @@ from graphiti_core.llm_client.utils import generate_embedding
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
 from graphiti_core.search.search_utils import (
+    RELEVANT_SCHEMA_LIMIT,
     get_relevant_edges,
     get_relevant_nodes,
     hybrid_node_search,
@@ -41,6 +42,7 @@ from graphiti_core.utils.bulk_utils import (
     RawEpisode,
     dedupe_edges_bulk,
     dedupe_nodes_bulk,
+    extract_edge_dates_bulk,
     extract_nodes_and_edges_bulk,
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
@@ -319,26 +321,24 @@ class Graphiti:
                 valid_at, invalid_at, _ = await extract_edge_dates(
                     self.llm_client,
                     edge,
-                    episode.valid_at,
                     episode,
                     previous_episodes,
                 )
                 edge.valid_at = valid_at
                 edge.invalid_at = invalid_at
                 if edge.invalid_at:
-                    edge.expired_at = datetime.now()
+                    edge.expired_at = now
             for edge in existing_edges:
                 valid_at, invalid_at, _ = await extract_edge_dates(
                     self.llm_client,
                     edge,
-                    episode.valid_at,
                     episode,
                     previous_episodes,
                 )
                 edge.valid_at = valid_at
                 edge.invalid_at = invalid_at
                 if edge.invalid_at:
-                    edge.expired_at = datetime.now()
+                    edge.expired_at = now
             (
                 old_edges_with_nodes_pending_invalidation,
                 new_edges_with_nodes,
@@ -481,15 +481,18 @@ class Graphiti:
             *[edge.generate_embedding(embedder) for edge in extracted_edges],
         )

-        # Dedupe extracted nodes
-        nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
+        # Dedupe extracted nodes and extract edge dates
+        (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
+            dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
+            extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
+        )

         # save nodes to KG
         await asyncio.gather(*[node.save(self.driver) for node in nodes])

         # re-map edge pointers so that they don't point to discarded dupe nodes
         extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
-            extracted_edges, uuid_map
+            extracted_edges_timestamped, uuid_map
         )
         episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
             episodic_edges, uuid_map
@@ -579,7 +582,9 @@ class Graphiti:
             self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
         )

-    async def get_nodes_by_query(self, query: str, limit: int | None = None) -> list[EntityNode]:
+    async def get_nodes_by_query(
+        self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+    ) -> list[EntityNode]:
         """
         Retrieve nodes from the graph database based on a text query.
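The add_episode_bulk change above fans two independent LLM-bound steps into a single asyncio.gather and unpacks their heterogeneous results in one assignment. A minimal runnable sketch of that pattern; dedupe_nodes and timestamp_edges here are illustrative stand-ins, not the real graphiti calls:

# Sketch of the concurrency pattern: two independent LLM-bound steps run
# under one asyncio.gather; the tuple result of the first task is unpacked
# alongside the list result of the second.
import asyncio


async def dedupe_nodes(nodes: list[str]) -> tuple[list[str], dict[str, str]]:
    await asyncio.sleep(0.01)  # stands in for an LLM round-trip
    return sorted(set(nodes)), {n: n.lower() for n in nodes}


async def timestamp_edges(edges: list[str]) -> list[str]:
    await asyncio.sleep(0.01)  # stands in for an LLM round-trip
    return [f'{e}@2024-08-23' for e in edges]


async def main() -> None:
    (nodes, uuid_map), edges = await asyncio.gather(
        dedupe_nodes(['Alice', 'alice', 'Bob']),
        timestamp_edges(['KNOWS', 'WORKS_FOR']),
    )
    print(nodes, uuid_map, edges)


asyncio.run(main())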
diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py
index f7b45629..11e21e5f 100644
--- a/graphiti_core/prompts/dedupe_nodes.py
+++ b/graphiti_core/prompts/dedupe_nodes.py
@@ -53,7 +53,9 @@ def v1(context: dict[str, Any]) -> list[Message]:
         1. start with the list of nodes from New Nodes
         2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
         node in the list
-        3. Respond with the resulting list of nodes
+        3. When deduplicating nodes, synthesize their summaries into a short new summary that contains the relevant information
+        of the summaries of the new and existing nodes
+        4. Respond with the resulting list of nodes

         Guidelines:
         1. Use both the name and summary of nodes to determine if they are duplicates,
@@ -64,6 +66,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
             "new_nodes": [
                 {{
                     "name": "Unique identifier for the node",
+                    "summary": "Brief summary of the node's role or significance"
                 }}
             ]
         }}
@@ -92,6 +95,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
         If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!

         Task: If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
+        When finding duplicate nodes, synthesize their summaries into a short new summary that contains the
+        relevant information of the summaries of the new and existing nodes.

         Guidelines:
         1. Use both the name and summary of nodes to determine if they are duplicates,
@@ -104,7 +109,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
             "duplicates": [
                 {{
                     "name": "name of the new node",
-                    "duplicate_of": "name of the existing node"
+                    "duplicate_of": "name of the existing node",
+                    "summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
                 }}
             ]
         }}
@@ -130,6 +136,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
         Task:
         1. Group nodes together such that all duplicate nodes are in the same list of names
         2. All duplicate names should be grouped together in the same list
+        3. Also return a new summary that synthesizes the summaries of the grouped nodes into a single short summary

         Guidelines:
         1. Each name from the list of nodes should appear EXACTLY once in your response
@@ -140,6 +147,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
             "nodes": [
                 {{
                     "names": ["myNode", "node that is a duplicate of myNode"],
+                    "summary": "Brief summary of the node summaries that appear in the list of names."
                 }}
             ]
         }}
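The dedupe prompts now ask the model to return a synthesized summary alongside each duplicate grouping. A hedged sketch of how a node_list-style response with the new 'summary' field could be folded back into the node set; Node and apply_dedupe_response are simplified stand-ins, not graphiti's actual types:

# The response shape follows the node_list prompt above: groups of duplicate
# names, each with one synthesized summary adopted by the surviving node.
from dataclasses import dataclass


@dataclass
class Node:
    name: str
    summary: str


def apply_dedupe_response(nodes: list[Node], response: dict) -> list[Node]:
    node_map = {node.name: node for node in nodes}
    unique: list[Node] = []
    for group in response['nodes']:
        keeper = node_map[group['names'][0]]
        # adopt the synthesized summary covering every node in the group
        keeper.summary = group['summary']
        unique.append(keeper)
    return unique


nodes = [Node('myNode', 'first mention'), Node('MyNode', 'second mention')]
response = {'nodes': [{'names': ['myNode', 'MyNode'],
                       'summary': 'Combined summary of both mentions.'}]}
print(apply_dedupe_response(nodes, response))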
diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py
index f995ba2f..cef5e38f 100644
--- a/graphiti_core/prompts/extract_edges.py
+++ b/graphiti_core/prompts/extract_edges.py
@@ -110,10 +110,11 @@ def v2(context: dict[str, Any]) -> list[Message]:

         Guidelines:
         1. Create edges only between the provided nodes.
-        2. Each edge should represent a clear relationship between two nodes.
+        2. Each edge should represent a clear relationship between two DISTINCT nodes.
         3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
         4. Provide a more detailed fact describing the relationship.
         5. Consider temporal aspects of relationships when relevant.
+        6. Avoid using the same node as the source and target of a relationship.

         Respond with a JSON object in the following format:
         {{
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 03ed04a4..d1e6ed0e 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -63,12 +63,12 @@ class SearchResults(BaseModel):


 async def hybrid_search(
-    driver: AsyncDriver,
-    embedder,
-    query: str,
-    timestamp: datetime,
-    config: SearchConfig,
-    center_node_uuid: str | None = None,
+        driver: AsyncDriver,
+        embedder,
+        query: str,
+        timestamp: datetime,
+        config: SearchConfig,
+        center_node_uuid: str | None = None,
 ) -> SearchResults:
     start = time()
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 6eeb5cea..9937a3cd 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -268,13 +268,13 @@ async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
-    limit: int | None = None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.

     This method combines fulltext search and vector similarity search to find
-    relevant nodes in the graph database.
+    relevant nodes in the graph database. It uses an RRF reranker.

     Parameters
     ----------
@@ -307,27 +307,25 @@ async def hybrid_node_search(
     """
     start = time()

-    relevant_nodes: list[EntityNode] = []
-    relevant_node_uuids = set()
-
-    results = await asyncio.gather(
-        *[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
-        *[
-            entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
-            for e in embeddings
-        ],
+    results: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
+            *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
+        )
     )

-    for result in results:
-        for node in result:
-            if node.uuid in relevant_node_uuids:
-                continue
+    node_uuid_map: dict[str, EntityNode] = {
+        node.uuid: node for result in results for node in result
+    }
+    result_uuids = [[node.uuid for node in result] for result in results]

-            relevant_node_uuids.add(node.uuid)
-            relevant_nodes.append(node)
+    ranked_uuids = rrf(result_uuids)
+
+    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]

     end = time()
-    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
+    logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')

     return relevant_nodes
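hybrid_node_search now fuses the per-query rankings with graphiti's rrf helper, which this diff does not show. For reference, a standard reciprocal rank fusion over ranked uuid lists looks roughly like this; k=60 is the conventional damping constant, and the real helper may differ in detail:

# Reciprocal rank fusion: each uuid scores 1/(k + rank + 1) per ranking it
# appears in; uuids near the top of several rankings float to the front.
from collections import defaultdict


def rrf(results: list[list[str]], k: int = 60) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for ranking in results:
        for rank, uuid in enumerate(ranking):
            scores[uuid] += 1.0 / (k + rank + 1)
    return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)


# 'b' and 'c' appear high in both rankings, so they outrank 'a' and 'd'
print(rrf([['a', 'b', 'c'], ['b', 'c', 'd']]))  # ['b', 'c', 'a', 'd']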
""" import asyncio +import logging import typing +from collections import defaultdict from datetime import datetime +from math import ceil from neo4j import AsyncDriver -from numpy import dot +from numpy import dot, sqrt from pydantic import BaseModel from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge @@ -39,8 +42,11 @@ from graphiti_core.utils.maintenance.node_operations import ( dedupe_node_list, extract_nodes, ) +from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates -CHUNK_SIZE = 15 +logger = logging.getLogger(__name__) + +CHUNK_SIZE = 10 class RawEpisode(BaseModel): @@ -52,7 +58,7 @@ class RawEpisode(BaseModel): async def retrieve_previous_episodes_bulk( - driver: AsyncDriver, episodes: list[EpisodicNode] + driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]: previous_episodes_list = await asyncio.gather( *[ @@ -68,7 +74,7 @@ async def retrieve_previous_episodes_bulk( async def extract_nodes_and_edges_bulk( - llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] + llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]: extracted_nodes_bulk = await asyncio.gather( *[ @@ -105,36 +111,67 @@ async def extract_nodes_and_edges_bulk( async def dedupe_nodes_bulk( - driver: AsyncDriver, - llm_client: LLMClient, - extracted_nodes: list[EntityNode], + driver: AsyncDriver, + llm_client: LLMClient, + extracted_nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: # Compress nodes nodes, uuid_map = node_name_match(extracted_nodes) compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map) - existing_nodes = await get_relevant_nodes(compressed_nodes, driver) + node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] - nodes, partial_uuid_map, _ = await dedupe_extracted_nodes( - llm_client, compressed_nodes, existing_nodes + existing_nodes_chunks: list[list[EntityNode]] = list( + await asyncio.gather( + *[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks] + ) ) - compressed_map.update(partial_uuid_map) + results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list( + await asyncio.gather( + *[ + dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i]) + for i, node_chunk in enumerate(node_chunks) + ] + ) + ) - return nodes, compressed_map + final_nodes: list[EntityNode] = [] + for result in results: + final_nodes.extend(result[0]) + partial_uuid_map = result[1] + compressed_map.update(partial_uuid_map) + + return final_nodes, compressed_map async def dedupe_edges_bulk( - driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge] + driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge] ) -> list[EntityEdge]: - # Compress edges + # First compress edges compressed_edges = await compress_edges(llm_client, extracted_edges) - existing_edges = await get_relevant_edges(compressed_edges, driver) + edge_chunks = [ + compressed_edges[i: i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE) + ] - edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges) + relevant_edges_chunks: list[list[EntityEdge]] = list( + await asyncio.gather( + *[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks] + ) + ) + resolved_edge_chunks: list[list[EntityEdge]] = list( + await asyncio.gather( + *[ + 
@@ -152,15 +189,60 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str


 async def compress_nodes(
-    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
+        llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
+    # We first compress the nodes by deduplicating them across each of the episodes added in bulk
     if len(nodes) == 0:
         return nodes, uuid_map

-    anchor = nodes[0]
-    nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
+    # Our approach deduplicates chunks of nodes in parallel.
+    # We want n chunks of size n so that n ** 2 == len(nodes),
+    # with chunk sizes of at least 10 to optimize LLM processing time.
+    chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)

-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+    # First calculate similarity scores between all pairs of nodes
+    similarity_scores: list[tuple[int, int, float]] = [
+        (i, j, dot(n.name_embedding or [], m.name_embedding or []))
+        for i, n in enumerate(nodes)
+        for j, m in enumerate(nodes[:i])
+    ]
+
+    # Sort the node pairs by semantic similarity
+    similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
+
+    # initialize our chunks based on the chunk size
+    node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
+
+    # Draft the most similar nodes into the same chunk
+    while len(similarity_scores) > 0:
+        i, j, _ = similarity_scores.pop()
+        # determine whether either node has already been drafted into a chunk
+        n = nodes[i]
+        m = nodes[j]
+        # make sure the shortest chunks get preference
+        node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
+
+        n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
+        m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
+
+        # both nodes are already in a chunk
+        if n_chunk > -1 and m_chunk > -1:
+            continue
+
+        # n has a chunk and that chunk is not full
+        elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
+            # put m in the same chunk as n
+            node_chunks[n_chunk].append(m)
+
+        # m has a chunk and that chunk is not full
+        elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
+            # put n in the same chunk as m
+            node_chunks[m_chunk].append(n)
+
+        # neither node has a chunk, or their chunks are full
+        else:
+            # add both nodes to the shortest chunk
+            node_chunks[-1].extend([n, m])

     results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
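The compress_nodes rewrite replaces the single-anchor sort with a greedy drafting pass: score every node pair by embedding similarity, then walk the pairs from most to least similar, pulling both nodes into the same chunk while capacity allows. A simplified runnable sketch of that idea over plain float vectors; greedy_chunks and the fixed chunk_size are illustrative, not the code above verbatim:

# Greedy similarity-based chunking: the highest-scoring pairs are drafted
# into shared chunks first, so near-duplicates land in the same LLM call.
from numpy import dot


def greedy_chunks(embeddings: list[list[float]], chunk_size: int = 2) -> list[list[int]]:
    pairs = [
        (i, j, float(dot(embeddings[i], embeddings[j])))
        for i in range(len(embeddings))
        for j in range(i)
    ]
    pairs.sort(key=lambda p: p[2], reverse=True)

    chunks: list[list[int]] = []
    placed: dict[int, int] = {}  # item index -> chunk index
    for i, j, _ in pairs:
        if i in placed and j in placed:
            continue
        if i in placed and len(chunks[placed[i]]) < chunk_size:
            chunks[placed[i]].append(j)
            placed[j] = placed[i]
        elif j in placed and len(chunks[placed[j]]) < chunk_size:
            chunks[placed[j]].append(i)
            placed[i] = placed[j]
        elif i not in placed and j not in placed:
            chunks.append([i, j])
            placed[i] = placed[j] = len(chunks) - 1
    # anything never drafted gets its own chunk
    for idx in range(len(embeddings)):
        if idx not in placed:
            chunks.append([idx])
    return chunks


# the two near-parallel vector pairs end up chunked together
print(greedy_chunks([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]))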
@@ -181,13 +263,21 @@ async def compress_nodes(
 async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
     if len(edges) == 0:
         return edges
+    # We only want to dedupe edges that are between the same pair of nodes,
+    # so we build a map of the edges keyed on their source and target nodes.
+    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
+    for edge in edges:
+        # We drop loop edges
+        if edge.source_node_uuid == edge.target_node_uuid:
+            continue

-    anchor = edges[0]
-    edges.sort(
-        key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
-    )
+        # Keep the order of the two nodes consistent; we want to be direction-agnostic during edge resolution
+        pointers = [edge.source_node_uuid, edge.target_node_uuid]
+        pointers.sort()

-    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+        edge_chunk_map[pointers[0] + pointers[1]].append(edge)
+
+    edge_chunks = [chunk for chunk in edge_chunk_map.values()]

     results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
@@ -225,3 +315,43 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

     return edges
+
+
+async def extract_edge_dates_bulk(
+        llm_client: LLMClient,
+        extracted_edges: list[EntityEdge],
+        episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
+) -> list[EntityEdge]:
+    edges: list[EntityEdge] = []
+    # confirm that all of our edges have at least one episode
+    for edge in extracted_edges:
+        if edge.episodes is not None and len(edge.episodes) > 0:
+            edges.append(edge)
+
+    episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
+        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
+    }
+
+    results = await asyncio.gather(
+        *[
+            extract_edge_dates(
+                llm_client,
+                edge,
+                episode_uuid_map[edge.episodes[0]][0],  # type: ignore
+                episode_uuid_map[edge.episodes[0]][1],  # type: ignore
+            )
+            for edge in edges
+        ]
+    )
+
+    for i, result in enumerate(results):
+        valid_at = result[0]
+        invalid_at = result[1]
+        edge = edges[i]
+
+        edge.valid_at = valid_at
+        edge.invalid_at = invalid_at
+        if edge.invalid_at:
+            edge.expired_at = datetime.now()
+
+    return edges
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 2dfdaccb..a1291237 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -189,6 +189,7 @@ async def dedupe_node_list(
     uuid_map: dict[str, str] = {}
     for node_data in nodes_data:
         node = node_map[node_data['names'][0]]
+        node.summary = node_data['summary']
         unique_nodes.append(node)

         for name in node_data['names'][1:]:
diff --git a/graphiti_core/utils/maintenance/temporal_operations.py b/graphiti_core/utils/maintenance/temporal_operations.py
index 021711cf..8ac2a68b 100644
--- a/graphiti_core/utils/maintenance/temporal_operations.py
+++ b/graphiti_core/utils/maintenance/temporal_operations.py
@@ -147,7 +147,6 @@ def process_edge_invalidation_llm_response(
 async def extract_edge_dates(
     llm_client: LLMClient,
     edge: EntityEdge,
-    reference_time: datetime,
     current_episode: EpisodicNode,
     previous_episodes: List[EpisodicNode],
 ) -> tuple[datetime | None, datetime | None, str]:
     context = {
         'edge_fact': edge.fact,
         'current_episode': current_episode.content,
         'previous_episodes': [ep.content for ep in previous_episodes],
-        'reference_timestamp': reference_time.isoformat(),
+        'reference_timestamp': current_episode.valid_at.isoformat(),
     }
     llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))
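With reference_time removed from extract_edge_dates, the reference timestamp is always the valid_at of the episode an edge was extracted from, which extract_edge_dates_bulk recovers through a uuid-keyed map. A minimal sketch of that lookup; Episode and Edge are simplified stand-ins for graphiti's EpisodicNode and EntityEdge:

# Build the episode lookup once, then pair each edge with its source episode
# via the first uuid in edge.episodes; edges without one are skipped,
# mirroring the filter in extract_edge_dates_bulk above.
from dataclasses import dataclass, field
from datetime import datetime, timezone


@dataclass
class Episode:
    uuid: str
    valid_at: datetime


@dataclass
class Edge:
    fact: str
    episodes: list[str] = field(default_factory=list)


pairs = [
    (Episode('ep-1', datetime(2024, 8, 1, tzinfo=timezone.utc)), []),
    (Episode('ep-2', datetime(2024, 8, 2, tzinfo=timezone.utc)), []),
]
episode_map = {episode.uuid: (episode, previous) for episode, previous in pairs}

edges = [Edge('Alice works for Acme', ['ep-2']), Edge('dangling fact')]
for edge in edges:
    if not edge.episodes:
        continue
    episode, previous = episode_map[edge.episodes[0]]
    # the episode's own valid_at now serves as the reference timestamp
    print(edge.fact, '->', episode.valid_at.isoformat())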