Move max_graph_nodes to global config

This commit is contained in:
yangdx 2025-07-07 21:53:57 +08:00
parent cb14ce6ff3
commit ef79088f60
7 changed files with 58 additions and 27 deletions

View file

@ -244,6 +244,9 @@ def parse_args() -> argparse.Namespace:
# Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
# Get MAX_GRAPH_NODES from environment
args.max_graph_nodes = get_env_value("MAX_GRAPH_NODES", 1000, int)
# Handle openai-ollama special case
if args.llm_binding == "openai-ollama":
args.llm_binding = "openai"

View file

@ -326,6 +326,7 @@ def create_app(args):
enable_llm_cache=args.enable_llm_cache,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language},
)
else: # azure_openai
@ -353,6 +354,7 @@ def create_app(args):
enable_llm_cache=args.enable_llm_cache,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language},
)
@ -475,7 +477,7 @@ def create_app(args):
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
"enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace,
"max_graph_nodes": os.getenv("MAX_GRAPH_NODES"),
"max_graph_nodes": args.max_graph_nodes,
},
"auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False),

View file

@ -34,8 +34,6 @@ from pymongo.errors import PyMongoError # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
@ -883,7 +881,7 @@ class MongoGraphStorage(BaseGraphStorage):
)
async def get_knowledge_graph_all_by_degree(
self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
self, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
"""
It's possible that the node with one or multiple relationships is retrieved,
@ -961,9 +959,9 @@ class MongoGraphStorage(BaseGraphStorage):
node_labels: list[str],
seen_nodes: set[str],
result: KnowledgeGraph,
depth: int = 0,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
depth: int,
max_depth: int,
max_nodes: int,
) -> KnowledgeGraph:
if depth > max_depth or len(result.nodes) > max_nodes:
return result
@ -1006,9 +1004,9 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_knowledge_subgraph_bidirectional_bfs(
self,
node_label: str,
depth=0,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
depth: int,
max_depth: int,
max_nodes: int,
) -> KnowledgeGraph:
seen_nodes = set()
seen_edges = set()
@ -1038,7 +1036,7 @@ class MongoGraphStorage(BaseGraphStorage):
return result
async def get_knowledge_subgraph_in_out_bound_bfs(
self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
seen_nodes = set()
seen_edges = set()
@ -1152,7 +1150,7 @@ class MongoGraphStorage(BaseGraphStorage):
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
max_nodes: int = None,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -1160,7 +1158,7 @@ class MongoGraphStorage(BaseGraphStorage):
Args:
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return, Defaults to 1000
max_nodes: Maximum nodes to return. Defaults to global_config max_graph_nodes; larger values are capped to that limit
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
@ -1184,6 +1182,13 @@ class MongoGraphStorage(BaseGraphStorage):
C B
C D
"""
# Use global_config max_graph_nodes as default if max_nodes is None
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
result = KnowledgeGraph()
start = time.perf_counter()

View file

@ -36,9 +36,6 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@ -902,7 +899,7 @@ class Neo4JStorage(BaseGraphStorage):
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
max_nodes: int = None,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -916,6 +913,13 @@ class Neo4JStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
# Get max_nodes from global_config if not provided
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
workspace_label = self._get_workspace_label()
result = KnowledgeGraph()
seen_nodes = set()

View file

@ -26,8 +26,6 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final
@dataclass
@ -218,7 +216,7 @@ class NetworkXStorage(BaseGraphStorage):
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
max_nodes: int = None,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -232,6 +230,13 @@ class NetworkXStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
# Get max_nodes from global_config if not provided
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
graph = await self._get_graph()
result = KnowledgeGraph()

View file

@ -45,9 +45,6 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any):
@ -2819,7 +2816,7 @@ class PGGraphStorage(BaseGraphStorage):
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
max_nodes: int = None,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@ -2827,12 +2824,18 @@ class PGGraphStorage(BaseGraphStorage):
Args:
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return, Defaults to 1000
max_nodes: Maximum nodes to return. Defaults to global_config max_graph_nodes; larger values are capped to that limit
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
# Use global_config max_graph_nodes as default if max_nodes is None
if max_nodes is None:
max_nodes = self.global_config.get("max_graph_nodes", 1000)
else:
# Limit max_nodes to not exceed global_config max_graph_nodes
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
kg = KnowledgeGraph()
# Handle wildcard query - get all nodes

View file

@ -258,6 +258,9 @@ class LightRAG:
max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2)))
"""Maximum number of parallel insert operations."""
max_graph_nodes: int = field(default=get_env_value("MAX_GRAPH_NODES", 1000, int))
"""Maximum number of graph nodes to return in knowledge graph queries."""
addon_params: dict[str, Any] = field(
default_factory=lambda: {
"language": get_env_value("SUMMARY_LANGUAGE", "English", str)
@ -526,18 +529,24 @@ class LightRAG:
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = 1000,
max_nodes: int = None,
) -> KnowledgeGraph:
"""Get knowledge graph for a given label
Args:
node_label (str): Label to get knowledge graph for
max_depth (int): Maximum depth of graph
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
max_nodes (int, optional): Maximum number of nodes to return. Defaults to self.max_graph_nodes; larger values are capped to that limit.
Returns:
KnowledgeGraph: Knowledge graph containing nodes and edges
"""
# Use self.max_graph_nodes as default if max_nodes is None
if max_nodes is None:
max_nodes = self.max_graph_nodes
else:
# Limit max_nodes to not exceed self.max_graph_nodes
max_nodes = min(max_nodes, self.max_graph_nodes)
return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label, max_depth, max_nodes