From ebe5b1e0d2a98e269419e07adb05da54440d309a Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 24 Jun 2025 22:16:06 +0800 Subject: [PATCH 1/5] Bump api version to 0175 --- lightrag/api/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py index 7a4a498a..9a96af72 100644 --- a/lightrag/api/__init__.py +++ b/lightrag/api/__init__.py @@ -1 +1 @@ -__api_version__ = "0174" +__api_version__ = "0175" From da46b341dc1b2c6c578439374ed45a30bea493db Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 25 Jun 2025 12:37:57 +0800 Subject: [PATCH 2/5] feat: Optimize document deletion performance - To enhance performance during document deletion, new batch-get methods, `get_nodes_by_chunk_ids` and `get_edges_by_chunk_ids`, have been added to the graph storage layer (`BaseGraphStorage` and its implementations). The [`adelete_by_doc_id`](lightrag/lightrag.py:1681) function now leverages these methods to avoid unnecessary iteration over the entire knowledge graph, significantly improving efficiency. - Graph storage updated: Networkx, Neo4j, Postgres AGE --- lightrag/base.py | 62 ++++++++++++++++++++ lightrag/constants.py | 3 + lightrag/kg/neo4j_impl.py | 42 ++++++++++++++ lightrag/kg/networkx_impl.py | 28 +++++++++ lightrag/kg/postgres_impl.py | 107 +++++++++++++++++++++++++++++++---- lightrag/lightrag.py | 88 ++++++++++++---------------- lightrag/operate.py | 3 +- lightrag/prompt.py | 1 - 8 files changed, 271 insertions(+), 63 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 84fc7564..add2318e 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -14,6 +14,7 @@ from typing import ( ) from .utils import EmbeddingFunc from .types import KnowledgeGraph +from .constants import GRAPH_FIELD_SEP # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance @@ -456,6 +457,67 @@ class BaseGraphStorage(StorageNameSpace, ABC): result[node_id] = edges if edges is not None else [] return result + @abstractmethod + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids (list[str]): A list of chunk IDs to find associated nodes for. + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ + # Default implementation iterates through all nodes, which is inefficient. + # This method should be overridden by subclasses for better performance. + all_nodes = [] + all_labels = await self.get_all_labels() + for label in all_labels: + node = await self.get_node(label) + if node and "source_id" in node: + source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP)) + if not source_ids.isdisjoint(chunk_ids): + all_nodes.append(node) + return all_nodes + + @abstractmethod + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids (list[str]): A list of chunk IDs to find associated edges for. + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ + # Default implementation iterates through all nodes and their edges, which is inefficient. + # This method should be overridden by subclasses for better performance. 
+ all_edges = [] + all_labels = await self.get_all_labels() + processed_edges = set() + + for label in all_labels: + edges = await self.get_node_edges(label) + if edges: + for src_id, tgt_id in edges: + # Avoid processing the same edge twice in an undirected graph + edge_tuple = tuple(sorted((src_id, tgt_id))) + if edge_tuple in processed_edges: + continue + processed_edges.add(edge_tuple) + + edge = await self.get_edge(src_id, tgt_id) + if edge and "source_id" in edge: + source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP)) + if not source_ids.isdisjoint(chunk_ids): + # Add source and target to the edge dict for easier processing later + edge_with_nodes = edge.copy() + edge_with_nodes["source"] = src_id + edge_with_nodes["target"] = tgt_id + all_edges.append(edge_with_nodes) + return all_edges + @abstractmethod async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """Insert a new node or update an existing node in the graph. diff --git a/lightrag/constants.py index 787e1c49..f8345994 100644 --- a/lightrag/constants.py +++ b/lightrag/constants.py @@ -12,6 +12,9 @@ DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6 DEFAULT_WOKERS = 2 DEFAULT_TIMEOUT = 150 +# Separator for graph fields +GRAPH_FIELD_SEP = "<SEP>" + # Logging configuration defaults DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups diff --git a/lightrag/kg/neo4j_impl.py index 7fe3da15..d4fbc59c 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -16,6 +16,7 @@ import logging from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP import pipmaster as pm if not pm.is_installed("neo4j"): @@ -725,6 +726,47 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are fully consumed return edges_dict + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (n:base) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) + RETURN DISTINCT n + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (a:base)-[r]-(b:base) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10),
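For orientation: the `source_id` property stored on nodes and edges is a `<SEP>`-joined list of chunk IDs, which is why both the default implementation in `base.py` and the Neo4j Cypher above split `source_id` on `GRAPH_FIELD_SEP` before testing membership. A minimal sketch of that membership test in isolation (the node dict and chunk IDs below are made up for illustration):

from lightrag.constants import GRAPH_FIELD_SEP  # "<SEP>"

# A hypothetical node, shaped like what a graph backend would return
node = {"entity_id": "Alan Turing", "source_id": f"chunk-1{GRAPH_FIELD_SEP}chunk-7"}
deleted_chunks = {"chunk-7", "chunk-9"}

# The same disjointness test used by get_nodes_by_chunk_ids
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(deleted_chunks):
    print(f"{node['entity_id']} is affected by the deletion")  # printed: chunk-7 matches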
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index c92bbd30..a4c46122 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -5,6 +5,7 @@ from typing import final from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from lightrag.base import BaseGraphStorage +from lightrag.constants import GRAPH_FIELD_SEP import pipmaster as pm @@ -357,6 +358,33 @@ class NetworkXStorage(BaseGraphStorage): ) return result + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + chunk_ids_set = set(chunk_ids) + graph = await self._get_graph() + matching_nodes = [] + for node_id, node_data in graph.nodes(data=True): + if "source_id" in node_data: + node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP)) + if not node_source_ids.isdisjoint(chunk_ids_set): + node_data_with_id = node_data.copy() + node_data_with_id["id"] = node_id + matching_nodes.append(node_data_with_id) + return matching_nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + chunk_ids_set = set(chunk_ids) + graph = await self._get_graph() + matching_edges = [] + for u, v, edge_data in graph.edges(data=True): + if "source_id" in edge_data: + edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP)) + if not edge_source_ids.isdisjoint(chunk_ids_set): + edge_data_with_nodes = edge_data.copy() + edge_data_with_nodes["source"] = u + edge_data_with_nodes["target"] = v + matching_edges.append(edge_data_with_nodes) + return matching_edges + async def index_done_callback(self) -> bool: """Save data to disk""" async with self._storage_lock:
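The NetworkX variant above is a straight linear scan over the in-memory graph. A self-contained sketch of the same scan on a throwaway graph (the entity and chunk names are invented for illustration):

import networkx as nx

GRAPH_FIELD_SEP = "<SEP>"  # mirrors lightrag.constants

G = nx.Graph()
G.add_node("Turing", source_id=f"chunk-1{GRAPH_FIELD_SEP}chunk-2")
G.add_node("Enigma", source_id="chunk-3")
G.add_edge("Turing", "Enigma", source_id="chunk-2")

deleted = {"chunk-2"}
affected = [
    node_id
    for node_id, data in G.nodes(data=True)
    if "source_id" in data
    and not set(data["source_id"].split(GRAPH_FIELD_SEP)).isdisjoint(deleted)
]
print(affected)  # ['Turing']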
diff --git a/lightrag/kg/postgres_impl.py index bacd8894..888b97c7 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -27,6 +27,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger +from ..constants import GRAPH_FIELD_SEP import pipmaster as pm @@ -1422,8 +1423,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(node_dict, str): try: - import json - node_dict = json.loads(node_dict) except json.JSONDecodeError: logger.warning(f"Failed to parse node string: {node_dict}") @@ -1479,8 +1478,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(result, str): try: - import json - result = json.loads(result) except json.JSONDecodeError: logger.warning(f"Failed to parse edge string: {result}") @@ -1697,8 +1694,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(node_dict, str): try: - import json - node_dict = json.loads(node_dict) except json.JSONDecodeError: logger.warning( @@ -1861,8 +1856,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(edge_props, str): try: - import json - edge_props = json.loads(edge_props) except json.JSONDecodeError: logger.warning( @@ -1879,8 +1872,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(edge_props, str): try: - import json - edge_props = json.loads(edge_props) except json.JSONDecodeError: logger.warning( @@ -1975,6 +1966,102 @@ class PGGraphStorage(BaseGraphStorage): labels.append(result["label"]) return labels + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """ + Retrieves nodes from the graph that are associated with a given list of chunk IDs. + This method uses a Cypher query with UNWIND to efficiently find all nodes + where the `source_id` property contains any of the specified chunk IDs. + """ + # The string representation of the list for the cypher query + chunk_ids_str = json.dumps(chunk_ids) + + query = f""" + SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND {chunk_ids_str} AS chunk_id + MATCH (n:base) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}') + RETURN n + $$) AS (n agtype); + """ + results = await self._query(query) + + # Build result list + nodes = [] + for result in results: + if result["n"]: + node_dict = result["n"]["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(node_dict, str): + try: + node_dict = json.loads(node_dict) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {node_dict}" + ) + + node_dict["id"] = node_dict["entity_id"] + nodes.append(node_dict) + + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """ + Retrieves edges from the graph that are associated with a given list of chunk IDs. + This method uses a Cypher query with UNWIND to efficiently find all edges + where the `source_id` property contains any of the specified chunk IDs. + """ + chunk_ids_str = json.dumps(chunk_ids) + + query = f""" + SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND {chunk_ids_str} AS chunk_id + MATCH (a:base)-[r]-(b:base) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}') + RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target + $$) AS (edge agtype, source agtype, target agtype); + """ + results = await self._query(query) + edges = [] + if results: + for item in results: + edge_agtype = item["edge"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(edge_agtype, str): + try: + edge_agtype = json.loads(edge_agtype) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge string in batch: {edge_agtype}" + ) + + source_agtype = item["source"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(source_agtype, str): + try: + source_agtype = json.loads(source_agtype) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {source_agtype}" + ) + + target_agtype = item["target"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(target_agtype, str): + try: + target_agtype = json.loads(target_agtype) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {target_agtype}" + ) + + if edge_agtype and source_agtype and target_agtype: + edge_properties = edge_agtype + edge_properties["source"] = source_agtype["entity_id"] + edge_properties["target"] = target_agtype["entity_id"] + edges.append(edge_properties) + return edges + async def _bfs_subgraph( self, node_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: diff --git a/lightrag/lightrag.py index f631992d..b94709f2 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -60,7 +60,7 @@ from .operate import ( query_with_keywords, _rebuild_knowledge_from_chunks, ) -from .prompt import GRAPH_FIELD_SEP +from .constants import GRAPH_FIELD_SEP from .utils import ( Tokenizer, TiktokenTokenizer, @@ -1761,68 +1761,54 @@ class LightRAG: # Use graph database lock to ensure atomic merges and updates
graph_db_lock = get_graph_db_lock(enable_logging=False) async with graph_db_lock: - # Process entities - # TODO There is performance when iterating get_all_labels for PostgresSQL - all_labels = await self.chunk_entity_relation_graph.get_all_labels() - for node_label in all_labels: - node_data = await self.chunk_entity_relation_graph.get_node( - node_label + # Get all affected nodes and edges in batch + affected_nodes = ( + await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids( + list(chunk_ids) ) - if node_data and "source_id" in node_data: - # Split source_id using GRAPH_FIELD_SEP + ) + affected_edges = ( + await self.chunk_entity_relation_graph.get_edges_by_chunk_ids( + list(chunk_ids) + ) + ) + + # logger.info(f"chunk_ids: {chunk_ids}") + # logger.info(f"affected_nodes: {affected_nodes}") + # logger.info(f"affected_edges: {affected_edges}") + + # Process entities + for node_data in affected_nodes: + node_label = node_data.get("entity_id") + if node_label and "source_id" in node_data: sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP)) remaining_sources = sources - chunk_ids if not remaining_sources: entities_to_delete.add(node_label) - logger.debug( - f"Entity {node_label} marked for deletion - no remaining sources" - ) elif remaining_sources != sources: - # Entity needs to be rebuilt from remaining chunks entities_to_rebuild[node_label] = remaining_sources - logger.debug( - f"Entity {node_label} will be rebuilt from {len(remaining_sources)} remaining chunks" - ) # Process relationships - # TODO There is performance when iterating get_all_labels for PostgresSQL - for node_label in all_labels: - node_edges = await self.chunk_entity_relation_graph.get_node_edges( - node_label - ) - if node_edges: - for src, tgt in node_edges: - # To avoid processing the same edge twice in an undirected graph - if (tgt, src) in relationships_to_delete or ( - tgt, - src, - ) in relationships_to_rebuild: - continue + for edge_data in affected_edges: + src = edge_data.get("source") + tgt = edge_data.get("target") - edge_data = await self.chunk_entity_relation_graph.get_edge( - src, tgt - ) - if edge_data and "source_id" in edge_data: - # Split source_id using GRAPH_FIELD_SEP - sources = set( - edge_data["source_id"].split(GRAPH_FIELD_SEP) - ) - remaining_sources = sources - chunk_ids + if src and tgt and "source_id" in edge_data: + edge_tuple = tuple(sorted((src, tgt))) + if ( + edge_tuple in relationships_to_delete + or edge_tuple in relationships_to_rebuild + ): + continue - if not remaining_sources: - relationships_to_delete.add((src, tgt)) - logger.debug( - f"Relationship {src}-{tgt} marked for deletion - no remaining sources" - ) - elif remaining_sources != sources: - # Relationship needs to be rebuilt from remaining chunks - relationships_to_rebuild[(src, tgt)] = ( - remaining_sources - ) - logger.debug( - f"Relationship {src}-{tgt} will be rebuilt from {len(remaining_sources)} remaining chunks" - ) + sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP)) + remaining_sources = sources - chunk_ids + + if not remaining_sources: + relationships_to_delete.add(edge_tuple) + elif remaining_sources != sources: + relationships_to_rebuild[edge_tuple] = remaining_sources # 5. Delete chunks from storage if chunk_ids:
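The delete-versus-rebuild decision in the hunk above reduces to set arithmetic on each node's or edge's supporting chunks. A small worked sketch of just that logic, with made-up chunk IDs:

GRAPH_FIELD_SEP = "<SEP>"  # same separator as lightrag.constants

chunk_ids = {"chunk-2", "chunk-3"}  # chunks being deleted
edge_data = {"source_id": f"chunk-1{GRAPH_FIELD_SEP}chunk-2"}

sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
remaining_sources = sources - chunk_ids

if not remaining_sources:
    print("delete edge")      # every supporting chunk is gone
elif remaining_sources != sources:
    print("rebuild edge")     # taken here: chunk-1 still supports the edge
else:
    print("edge untouched")   # the deletion does not reference this edge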
diff --git a/lightrag/operate.py b/lightrag/operate.py index b19f739c..d5026203 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -33,7 +33,8 @@ from .base import ( TextChunkSchema, QueryParam, ) -from .prompt import GRAPH_FIELD_SEP, PROMPTS +from .prompt import PROMPTS +from .constants import GRAPH_FIELD_SEP import time from dotenv import load_dotenv diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 5ed630f9..a4641480 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any -GRAPH_FIELD_SEP = "<SEP>" PROMPTS: dict[str, Any] = {} From 109c2b48bef93a69b4ca3010d6351e0b42ddb8d2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 25 Jun 2025 12:39:43 +0800 Subject: [PATCH 3/5] Fix linting --- lightrag/kg/postgres_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 888b97c7..dd93d624 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -2000,7 +2000,7 @@ class PGGraphStorage(BaseGraphStorage): f"Failed to parse node string in batch: {node_dict}" ) - node_dict["id"] = node_dict["entity_id"] + node_dict["id"] = node_dict["entity_id"] nodes.append(node_dict) return nodes From 492269ac446db0a5ca4e3dee6e0823c42db37c50 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 25 Jun 2025 12:39:57 +0800 Subject: [PATCH 4/5] Bump core version to 1.3.10 --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 1b88de1e..392b3f60 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.3.9" +__version__ = "1.3.10" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From 72384f87c4f26c2e30d3a0fa635bba8a7de6d4b7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 25 Jun 2025 12:53:07 +0800 Subject: [PATCH 5/5] Remove deprecated code from postgres_impl.py - Stop filtering out 'base' node labels - Match any edge type in query to improve performance --- lightrag/kg/postgres_impl.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index dd93d624..0ddc7948 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1701,10 +1701,11 @@ class PGGraphStorage(BaseGraphStorage): ) # Remove the 'base' label if present in a 'labels' property - if "labels" in node_dict: - node_dict["labels"] = [ - label for label in node_dict["labels"] if label != "base" - ] + # if "labels" in node_dict: + # node_dict["labels"] = [ + # label for label in node_dict["labels"] if label != "base" + # ] + nodes_dict[result["node_id"]] = node_dict return nodes_dict @@ -1833,14 +1834,14 @@ class PGGraphStorage(BaseGraphStorage): forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ WITH [{src_array}] AS sources, [{tgt_array}] AS targets UNWIND range(0, size(sources)-1) AS i - MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}}) + MATCH (a:base {{entity_id: sources[i]}})-[r]->(b:base {{entity_id: targets[i]}}) RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties $$) AS (source text, target text, edge_properties agtype)""" backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ WITH [{src_array}] AS sources, 
[{tgt_array}] AS targets UNWIND range(0, size(sources)-1) AS i - MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}}) + MATCH (a:base {{entity_id: sources[i]}})<-[r]-(b:base {{entity_id: targets[i]}}) RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties $$) AS (source text, target text, edge_properties agtype)"""
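Taken together, the series replaces the per-label full-graph scan in `adelete_by_doc_id` with two targeted batch lookups. A minimal usage sketch of the new API; `graph` is assumed to be an already-initialized `BaseGraphStorage` implementation (NetworkX, Neo4j, or Postgres AGE), and `preview_deletion` and the chunk IDs are illustrative, not part of the series:

from lightrag.base import BaseGraphStorage

async def preview_deletion(graph: BaseGraphStorage, chunk_ids: list[str]) -> None:
    # Batch lookups added by this series: each node dict carries an "id" key,
    # each edge dict carries "source" and "target" keys.
    nodes = await graph.get_nodes_by_chunk_ids(chunk_ids)
    edges = await graph.get_edges_by_chunk_ids(chunk_ids)
    print(f"{len(nodes)} entities and {len(edges)} relations reference these chunks")
    for edge in edges:
        print(edge["source"], "->", edge["target"])

# import asyncio; asyncio.run(preview_deletion(graph, ["chunk-1", "chunk-2"]))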