diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 2e15fd3a..ad0e670b 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -244,6 +244,9 @@ def parse_args() -> argparse.Namespace: # Get MAX_PARALLEL_INSERT from environment args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) + # Get MAX_GRAPH_NODES from environment + args.max_graph_nodes = get_env_value("MAX_GRAPH_NODES", 1000, int) + # Handle openai-ollama special case if args.llm_binding == "openai-ollama": args.llm_binding = "openai" diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 7a34ab5c..cd87af22 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -326,6 +326,7 @@ def create_app(args): enable_llm_cache=args.enable_llm_cache, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, + max_graph_nodes=args.max_graph_nodes, addon_params={"language": args.summary_language}, ) else: # azure_openai @@ -353,6 +354,7 @@ def create_app(args): enable_llm_cache=args.enable_llm_cache, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, + max_graph_nodes=args.max_graph_nodes, addon_params={"language": args.summary_language}, ) @@ -475,7 +477,7 @@ def create_app(args): "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache": args.enable_llm_cache, "workspace": args.workspace, - "max_graph_nodes": os.getenv("MAX_GRAPH_NODES"), + "max_graph_nodes": args.max_graph_nodes, }, "auth_mode": auth_mode, "pipeline_busy": pipeline_status.get("busy", False), diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index ace28432..dcf99327 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -34,8 +34,6 @@ from pymongo.errors import PyMongoError # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") -# Get maximum number of graph nodes from environment variable, default is 1000 -MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional") @@ -883,7 +881,7 @@ class MongoGraphStorage(BaseGraphStorage): ) async def get_knowledge_graph_all_by_degree( - self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES + self, max_depth: int, max_nodes: int ) -> KnowledgeGraph: """ It's possible that the node with one or multiple relationships is retrieved, @@ -961,9 +959,9 @@ class MongoGraphStorage(BaseGraphStorage): node_labels: list[str], seen_nodes: set[str], result: KnowledgeGraph, - depth: int = 0, - max_depth: int = 3, - max_nodes: int = MAX_GRAPH_NODES, + depth: int, + max_depth: int, + max_nodes: int, ) -> KnowledgeGraph: if depth > max_depth or len(result.nodes) > max_nodes: return result @@ -1006,9 +1004,9 @@ class MongoGraphStorage(BaseGraphStorage): async def get_knowledge_subgraph_bidirectional_bfs( self, node_label: str, - depth=0, - max_depth: int = 3, - max_nodes: int = MAX_GRAPH_NODES, + depth: int, + max_depth: int, + max_nodes: int, ) -> KnowledgeGraph: seen_nodes = set() seen_edges = set() @@ -1038,7 +1036,7 @@ class MongoGraphStorage(BaseGraphStorage): return result async def get_knowledge_subgraph_in_out_bound_bfs( - self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES + self, node_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: seen_nodes = set() seen_edges = set() @@ -1152,7 +1150,7 @@ class MongoGraphStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 3, - max_nodes: int = MAX_GRAPH_NODES, + max_nodes: int = None, ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. @@ -1160,7 +1158,7 @@ class MongoGraphStorage(BaseGraphStorage): Args: node_label: Label of the starting node, * means all nodes max_depth: Maximum depth of the subgraph, Defaults to 3 - max_nodes: Maxiumu nodes to return, Defaults to 1000 + max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes Returns: KnowledgeGraph object containing nodes and edges, with an is_truncated flag @@ -1184,6 +1182,13 @@ class MongoGraphStorage(BaseGraphStorage): C → B C → D """ + # Use global_config max_graph_nodes as default if max_nodes is None + if max_nodes is None: + max_nodes = self.global_config.get("max_graph_nodes", 1000) + else: + # Limit max_nodes to not exceed global_config max_graph_nodes + max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) + result = KnowledgeGraph() start = time.perf_counter() diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 63a3300a..b0efc4b5 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -36,9 +36,6 @@ from dotenv import load_dotenv # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) -# Get maximum number of graph nodes from environment variable, default is 1000 -MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - config = configparser.ConfigParser() config.read("config.ini", "utf-8") @@ -902,7 +899,7 @@ class Neo4JStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 3, - max_nodes: int = MAX_GRAPH_NODES, + max_nodes: int = None, ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. @@ -916,6 +913,13 @@ class Neo4JStorage(BaseGraphStorage): KnowledgeGraph object containing nodes and edges, with an is_truncated flag indicating whether the graph was truncated due to max_nodes limit """ + # Get max_nodes from global_config if not provided + if max_nodes is None: + max_nodes = self.global_config.get("max_graph_nodes", 1000) + else: + # Limit max_nodes to not exceed global_config max_graph_nodes + max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) + workspace_label = self._get_workspace_label() result = KnowledgeGraph() seen_nodes = set() diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index faff6a96..01b34566 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -26,8 +26,6 @@ from dotenv import load_dotenv # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) -MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - @final @dataclass @@ -218,7 +216,7 @@ class NetworkXStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 3, - max_nodes: int = MAX_GRAPH_NODES, + max_nodes: int = None, ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. @@ -232,6 +230,13 @@ class NetworkXStorage(BaseGraphStorage): KnowledgeGraph object containing nodes and edges, with an is_truncated flag indicating whether the graph was truncated due to max_nodes limit """ + # Get max_nodes from global_config if not provided + if max_nodes is None: + max_nodes = self.global_config.get("max_graph_nodes", 1000) + else: + # Limit max_nodes to not exceed global_config max_graph_nodes + max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) + graph = await self._get_graph() result = KnowledgeGraph() diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 5750fd6e..7e696c77 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -45,9 +45,6 @@ from dotenv import load_dotenv # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) -# Get maximum number of graph nodes from environment variable, default is 1000 -MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - class PostgreSQLDB: def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -2819,7 +2816,7 @@ class PGGraphStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 3, - max_nodes: int = MAX_GRAPH_NODES, + max_nodes: int = None, ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. @@ -2827,12 +2824,18 @@ class PGGraphStorage(BaseGraphStorage): Args: node_label: Label of the starting node, * means all nodes max_depth: Maximum depth of the subgraph, Defaults to 3 - max_nodes: Maxiumu nodes to return, Defaults to 1000 + max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes Returns: KnowledgeGraph object containing nodes and edges, with an is_truncated flag indicating whether the graph was truncated due to max_nodes limit """ + # Use global_config max_graph_nodes as default if max_nodes is None + if max_nodes is None: + max_nodes = self.global_config.get("max_graph_nodes", 1000) + else: + # Limit max_nodes to not exceed global_config max_graph_nodes + max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) kg = KnowledgeGraph() # Handle wildcard query - get all nodes diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5d96aeba..cbb5e2a8 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -258,6 +258,9 @@ class LightRAG: max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2))) """Maximum number of parallel insert operations.""" + max_graph_nodes: int = field(default=get_env_value("MAX_GRAPH_NODES", 1000, int)) + """Maximum number of graph nodes to return in knowledge graph queries.""" + addon_params: dict[str, Any] = field( default_factory=lambda: { "language": get_env_value("SUMMARY_LANGUAGE", "English", str) @@ -526,18 +529,24 @@ class LightRAG: self, node_label: str, max_depth: int = 3, - max_nodes: int = 1000, + max_nodes: int = None, ) -> KnowledgeGraph: """Get knowledge graph for a given label Args: node_label (str): Label to get knowledge graph for max_depth (int): Maximum depth of graph - max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000. + max_nodes (int, optional): Maximum number of nodes to return. Defaults to self.max_graph_nodes. Returns: KnowledgeGraph: Knowledge graph containing nodes and edges """ + # Use self.max_graph_nodes as default if max_nodes is None + if max_nodes is None: + max_nodes = self.max_graph_nodes + else: + # Limit max_nodes to not exceed self.max_graph_nodes + max_nodes = min(max_nodes, self.max_graph_nodes) return await self.chunk_entity_relation_graph.get_knowledge_graph( node_label, max_depth, max_nodes