diff --git a/lightrag/base.py b/lightrag/base.py index 3cf40136..bae0728b 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -19,7 +19,6 @@ from typing import ( from .utils import EmbeddingFunc from .types import KnowledgeGraph from .constants import ( - GRAPH_FIELD_SEP, DEFAULT_TOP_K, DEFAULT_CHUNK_TOP_K, DEFAULT_MAX_ENTITY_TOKENS, @@ -528,56 +527,6 @@ class BaseGraphStorage(StorageNameSpace, ABC): result[node_id] = edges if edges is not None else [] return result - @abstractmethod - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """Get all nodes that are associated with the given chunk_ids. - - Args: - chunk_ids (list[str]): A list of chunk IDs to find associated nodes for. - - Returns: - list[dict]: A list of nodes, where each node is a dictionary of its properties. - An empty list if no matching nodes are found. - """ - - @abstractmethod - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """Get all edges that are associated with the given chunk_ids. - - Args: - chunk_ids (list[str]): A list of chunk IDs to find associated edges for. - - Returns: - list[dict]: A list of edges, where each edge is a dictionary of its properties. - An empty list if no matching edges are found. - """ - # Default implementation iterates through all nodes and their edges, which is inefficient. - # This method should be overridden by subclasses for better performance. - all_edges = [] - all_labels = await self.get_all_labels() - processed_edges = set() - - for label in all_labels: - edges = await self.get_node_edges(label) - if edges: - for src_id, tgt_id in edges: - # Avoid processing the same edge twice in an undirected graph - edge_tuple = tuple(sorted((src_id, tgt_id))) - if edge_tuple in processed_edges: - continue - processed_edges.add(edge_tuple) - - edge = await self.get_edge(src_id, tgt_id) - if edge and "source_id" in edge: - source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP)) - if not source_ids.isdisjoint(chunk_ids): - # Add source and target to the edge dict for easier processing later - edge_with_nodes = edge.copy() - edge_with_nodes["source"] = src_id - edge_with_nodes["target"] = tgt_id - all_edges.append(edge_with_nodes) - return all_edges - @abstractmethod async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """Insert a new node or update an existing node in the graph. diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index f48ea20f..d81c2ebd 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -8,7 +8,6 @@ import configparser from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock import pipmaster as pm @@ -784,79 +783,6 @@ class MemgraphStorage(BaseGraphStorage): degrees = int(src_degree) + int(trg_degree) return degrees - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """Get all nodes that are associated with the given chunk_ids. - - Args: - chunk_ids: List of chunk IDs to find associated nodes for - - Returns: - list[dict]: A list of nodes, where each node is a dictionary of its properties. - An empty list if no matching nodes are found. - """ - if self._driver is None: - raise RuntimeError( - "Memgraph driver is not initialized. Call 'await initialize()' first." - ) - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = f""" - UNWIND $chunk_ids AS chunk_id - MATCH (n:`{workspace_label}`) - WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) - RETURN DISTINCT n - """ - result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) - nodes = [] - async for record in result: - node = record["n"] - node_dict = dict(node) - node_dict["id"] = node_dict.get("entity_id") - nodes.append(node_dict) - await result.consume() - return nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """Get all edges that are associated with the given chunk_ids. - - Args: - chunk_ids: List of chunk IDs to find associated edges for - - Returns: - list[dict]: A list of edges, where each edge is a dictionary of its properties. - An empty list if no matching edges are found. - """ - if self._driver is None: - raise RuntimeError( - "Memgraph driver is not initialized. Call 'await initialize()' first." - ) - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = f""" - UNWIND $chunk_ids AS chunk_id - MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) - WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) - WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id - // Ensure we only return each unique edge once by ordering the source and target - WITH a, b, r, - CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source, - CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target - RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties - """ - result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) - edges = [] - async for record in result: - edge_properties = record["properties"] - edge_properties["source"] = record["source"] - edge_properties["target"] = record["target"] - edges.append(edge_properties) - await result.consume() - return edges - async def get_knowledge_graph( self, node_label: str, diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index e55062f1..30452c74 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1036,45 +1036,6 @@ class MongoGraphStorage(BaseGraphStorage): return result - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """Get all nodes that are associated with the given chunk_ids. - - Args: - chunk_ids (list[str]): A list of chunk IDs to find associated nodes for. - - Returns: - list[dict]: A list of nodes, where each node is a dictionary of its properties. - An empty list if no matching nodes are found. - """ - if not chunk_ids: - return [] - - cursor = self.collection.find({"source_ids": {"$in": chunk_ids}}) - return [doc async for doc in cursor] - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """Get all edges that are associated with the given chunk_ids. - - Args: - chunk_ids (list[str]): A list of chunk IDs to find associated edges for. - - Returns: - list[dict]: A list of edges, where each edge is a dictionary of its properties. - An empty list if no matching edges are found. - """ - if not chunk_ids: - return [] - - cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}}) - - edges = [] - async for edge in cursor: - edge["source"] = edge["source_node_id"] - edge["target"] = edge["target_node_id"] - edges.append(edge) - - return edges - # # ------------------------------------------------------------------------- # UPSERTS diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 896e5973..76fa11f2 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -16,7 +16,6 @@ import logging from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock import pipmaster as pm @@ -904,49 +903,6 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are fully consumed return edges_dict - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = f""" - UNWIND $chunk_ids AS chunk_id - MATCH (n:`{workspace_label}`) - WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) - RETURN DISTINCT n - """ - result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) - nodes = [] - async for record in result: - node = record["n"] - node_dict = dict(node) - # Add node id (entity_id) to the dictionary for easier access - node_dict["id"] = node_dict.get("entity_id") - nodes.append(node_dict) - await result.consume() - return nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = f""" - UNWIND $chunk_ids AS chunk_id - MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) - WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) - RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties - """ - result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) - edges = [] - async for record in result: - edge_properties = record["properties"] - edge_properties["source"] = record["source"] - edge_properties["target"] = record["target"] - edges.append(edge_properties) - await result.consume() - return edges - @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 88a182d6..48a2d2af 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -5,7 +5,6 @@ from typing import final from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from lightrag.base import BaseGraphStorage -from lightrag.constants import GRAPH_FIELD_SEP import networkx as nx from .shared_storage import ( get_storage_lock, @@ -470,33 +469,6 @@ class NetworkXStorage(BaseGraphStorage): ) return result - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - chunk_ids_set = set(chunk_ids) - graph = await self._get_graph() - matching_nodes = [] - for node_id, node_data in graph.nodes(data=True): - if "source_id" in node_data: - node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP)) - if not node_source_ids.isdisjoint(chunk_ids_set): - node_data_with_id = node_data.copy() - node_data_with_id["id"] = node_id - matching_nodes.append(node_data_with_id) - return matching_nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - chunk_ids_set = set(chunk_ids) - graph = await self._get_graph() - matching_edges = [] - for u, v, edge_data in graph.edges(data=True): - if "source_id" in edge_data: - edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP)) - if not edge_source_ids.isdisjoint(chunk_ids_set): - edge_data_with_nodes = edge_data.copy() - edge_data_with_nodes["source"] = u - edge_data_with_nodes["target"] = v - matching_edges.append(edge_data_with_nodes) - return matching_edges - async def get_all_nodes(self) -> list[dict]: """Get all nodes in the graph. diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 723de69f..2a7c6158 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -33,7 +33,6 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger -from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock import pipmaster as pm @@ -4175,102 +4174,6 @@ class PGGraphStorage(BaseGraphStorage): labels.append(result["label"]) return labels - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves nodes from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all nodes - where the `source_id` property contains any of the specified chunk IDs. - """ - # The string representation of the list for the cypher query - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH (n:base) - WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}') - RETURN n - $$) AS (n agtype); - """ - results = await self._query(query) - - # Build result list - nodes = [] - for result in results: - if result["n"]: - node_dict = result["n"]["properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" - ) - - node_dict["id"] = node_dict["entity_id"] - nodes.append(node_dict) - - return nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves edges from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all edges - where the `source_id` property contains any of the specified chunk IDs. - """ - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH ()-[r]-() - WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}') - RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target - $$) AS (edge agtype, source agtype, target agtype); - """ - results = await self._query(query) - edges = [] - if results: - for item in results: - edge_agtype = item["edge"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(edge_agtype, str): - try: - edge_agtype = json.loads(edge_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}" - ) - - source_agtype = item["source"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(source_agtype, str): - try: - source_agtype = json.loads(source_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}" - ) - - target_agtype = item["target"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(target_agtype, str): - try: - target_agtype = json.loads(target_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}" - ) - - if edge_agtype and source_agtype and target_agtype: - edge_properties = edge_agtype - edge_properties["source"] = source_agtype["entity_id"] - edge_properties["target"] = target_agtype["entity_id"] - edges.append(edge_properties) - return edges - async def _bfs_subgraph( self, node_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: