Move max_graph_nodes to global config

This commit is contained in:
yangdx 2025-07-07 21:53:57 +08:00
parent cb14ce6ff3
commit ef79088f60
7 changed files with 58 additions and 27 deletions

View file

@ -244,6 +244,9 @@ def parse_args() -> argparse.Namespace:
# Get MAX_PARALLEL_INSERT from environment # Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
# Get MAX_GRAPH_NODES from environment
args.max_graph_nodes = get_env_value("MAX_GRAPH_NODES", 1000, int)
# Handle openai-ollama special case # Handle openai-ollama special case
if args.llm_binding == "openai-ollama": if args.llm_binding == "openai-ollama":
args.llm_binding = "openai" args.llm_binding = "openai"

View file

@ -326,6 +326,7 @@ def create_app(args):
enable_llm_cache=args.enable_llm_cache, enable_llm_cache=args.enable_llm_cache,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language}, addon_params={"language": args.summary_language},
) )
else: # azure_openai else: # azure_openai
@ -353,6 +354,7 @@ def create_app(args):
enable_llm_cache=args.enable_llm_cache, enable_llm_cache=args.enable_llm_cache,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language}, addon_params={"language": args.summary_language},
) )
@ -475,7 +477,7 @@ def create_app(args):
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
"enable_llm_cache": args.enable_llm_cache, "enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace, "workspace": args.workspace,
"max_graph_nodes": os.getenv("MAX_GRAPH_NODES"), "max_graph_nodes": args.max_graph_nodes,
}, },
"auth_mode": auth_mode, "auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False), "pipeline_busy": pipeline_status.get("busy", False),

View file

@ -34,8 +34,6 @@ from pymongo.errors import PyMongoError # type: ignore
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional") GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
@ -883,7 +881,7 @@ class MongoGraphStorage(BaseGraphStorage):
) )
async def get_knowledge_graph_all_by_degree( async def get_knowledge_graph_all_by_degree(
self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES self, max_depth: int, max_nodes: int
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
It's possible that the node with one or multiple relationships is retrieved, It's possible that the node with one or multiple relationships is retrieved,
@ -961,9 +959,9 @@ class MongoGraphStorage(BaseGraphStorage):
node_labels: list[str], node_labels: list[str],
seen_nodes: set[str], seen_nodes: set[str],
result: KnowledgeGraph, result: KnowledgeGraph,
depth: int = 0, depth: int,
max_depth: int = 3, max_depth: int,
max_nodes: int = MAX_GRAPH_NODES, max_nodes: int,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
if depth > max_depth or len(result.nodes) > max_nodes: if depth > max_depth or len(result.nodes) > max_nodes:
return result return result
@ -1006,9 +1004,9 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_knowledge_subgraph_bidirectional_bfs( async def get_knowledge_subgraph_bidirectional_bfs(
self, self,
node_label: str, node_label: str,
depth=0, depth: int,
max_depth: int = 3, max_depth: int,
max_nodes: int = MAX_GRAPH_NODES, max_nodes: int,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
@ -1038,7 +1036,7 @@ class MongoGraphStorage(BaseGraphStorage):
return result return result
async def get_knowledge_subgraph_in_out_bound_bfs( async def get_knowledge_subgraph_in_out_bound_bfs(
self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph: ) -> KnowledgeGraph:
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
@ -1152,7 +1150,7 @@ class MongoGraphStorage(BaseGraphStorage):
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES, max_nodes: int = None,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -1160,7 +1158,7 @@ class MongoGraphStorage(BaseGraphStorage):
Args: Args:
node_label: Label of the starting node, * means all nodes node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3 max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return, Defaults to 1000 max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes
Returns: Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag KnowledgeGraph object containing nodes and edges, with an is_truncated flag
@ -1184,6 +1182,13 @@ class MongoGraphStorage(BaseGraphStorage):
C B C B
C D C D
""" """
# Use global_config max_graph_nodes as default if max_nodes is None
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
result = KnowledgeGraph() result = KnowledgeGraph()
start = time.perf_counter() start = time.perf_counter()

View file

@ -36,9 +36,6 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file # the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False) load_dotenv(dotenv_path=".env", override=False)
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@ -902,7 +899,7 @@ class Neo4JStorage(BaseGraphStorage):
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES, max_nodes: int = None,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -916,6 +913,13 @@ class Neo4JStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit indicating whether the graph was truncated due to max_nodes limit
""" """
# Get max_nodes from global_config if not provided
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()

View file

@ -26,8 +26,6 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file # the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False) load_dotenv(dotenv_path=".env", override=False)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final @final
@dataclass @dataclass
@ -218,7 +216,7 @@ class NetworkXStorage(BaseGraphStorage):
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES, max_nodes: int = None,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -232,6 +230,13 @@ class NetworkXStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit indicating whether the graph was truncated due to max_nodes limit
""" """
# Get max_nodes from global_config if not provided
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
graph = await self._get_graph() graph = await self._get_graph()
result = KnowledgeGraph() result = KnowledgeGraph()

View file

@ -45,9 +45,6 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file # the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False) load_dotenv(dotenv_path=".env", override=False)
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
class PostgreSQLDB: class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any): def __init__(self, config: dict[str, Any], **kwargs: Any):
@ -2819,7 +2816,7 @@ class PGGraphStorage(BaseGraphStorage):
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES, max_nodes: int = None,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -2827,12 +2824,18 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_label: Label of the starting node, * means all nodes node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3 max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return, Defaults to 1000 max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes
Returns: Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit indicating whether the graph was truncated due to max_nodes limit
""" """
# Use global_config max_graph_nodes as default if max_nodes is None
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
kg = KnowledgeGraph() kg = KnowledgeGraph()
# Handle wildcard query - get all nodes # Handle wildcard query - get all nodes

View file

@ -258,6 +258,9 @@ class LightRAG:
max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2))) max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2)))
"""Maximum number of parallel insert operations.""" """Maximum number of parallel insert operations."""
max_graph_nodes: int = field(default=get_env_value("MAX_GRAPH_NODES", 1000, int))
"""Maximum number of graph nodes to return in knowledge graph queries."""
addon_params: dict[str, Any] = field( addon_params: dict[str, Any] = field(
default_factory=lambda: { default_factory=lambda: {
"language": get_env_value("SUMMARY_LANGUAGE", "English", str) "language": get_env_value("SUMMARY_LANGUAGE", "English", str)
@ -526,18 +529,24 @@ class LightRAG:
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
max_nodes: int = 1000, max_nodes: int = None,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
"""Get knowledge graph for a given label """Get knowledge graph for a given label
Args: Args:
node_label (str): Label to get knowledge graph for node_label (str): Label to get knowledge graph for
max_depth (int): Maximum depth of graph max_depth (int): Maximum depth of graph
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000. max_nodes (int, optional): Maximum number of nodes to return. Defaults to self.max_graph_nodes.
Returns: Returns:
KnowledgeGraph: Knowledge graph containing nodes and edges KnowledgeGraph: Knowledge graph containing nodes and edges
""" """
# Use self.max_graph_nodes as default if max_nodes is None
if max_nodes is None:
max_nodes = self.max_graph_nodes
else:
# Limit max_nodes to not exceed self.max_graph_nodes
max_nodes = min(max_nodes, self.max_graph_nodes)
return await self.chunk_entity_relation_graph.get_knowledge_graph( return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label, max_depth, max_nodes node_label, max_depth, max_nodes