Move max_graph_nodes to global config
This commit is contained in:
parent
cb14ce6ff3
commit
ef79088f60
7 changed files with 58 additions and 27 deletions
|
|
@ -244,6 +244,9 @@ def parse_args() -> argparse.Namespace:
|
|||
# Get MAX_PARALLEL_INSERT from environment
|
||||
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
||||
|
||||
# Get MAX_GRAPH_NODES from environment
|
||||
args.max_graph_nodes = get_env_value("MAX_GRAPH_NODES", 1000, int)
|
||||
|
||||
# Handle openai-ollama special case
|
||||
if args.llm_binding == "openai-ollama":
|
||||
args.llm_binding = "openai"
|
||||
|
|
|
|||
|
|
@ -326,6 +326,7 @@ def create_app(args):
|
|||
enable_llm_cache=args.enable_llm_cache,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
addon_params={"language": args.summary_language},
|
||||
)
|
||||
else: # azure_openai
|
||||
|
|
@ -353,6 +354,7 @@ def create_app(args):
|
|||
enable_llm_cache=args.enable_llm_cache,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
addon_params={"language": args.summary_language},
|
||||
)
|
||||
|
||||
|
|
@ -475,7 +477,7 @@ def create_app(args):
|
|||
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
||||
"enable_llm_cache": args.enable_llm_cache,
|
||||
"workspace": args.workspace,
|
||||
"max_graph_nodes": os.getenv("MAX_GRAPH_NODES"),
|
||||
"max_graph_nodes": args.max_graph_nodes,
|
||||
},
|
||||
"auth_mode": auth_mode,
|
||||
"pipeline_busy": pipeline_status.get("busy", False),
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@ from pymongo.errors import PyMongoError # type: ignore
|
|||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
|
||||
|
||||
|
||||
|
|
@ -883,7 +881,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
)
|
||||
|
||||
async def get_knowledge_graph_all_by_degree(
|
||||
self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
|
||||
self, max_depth: int, max_nodes: int
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
It's possible that the node with one or multiple relationships is retrieved,
|
||||
|
|
@ -961,9 +959,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
node_labels: list[str],
|
||||
seen_nodes: set[str],
|
||||
result: KnowledgeGraph,
|
||||
depth: int = 0,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
depth: int,
|
||||
max_depth: int,
|
||||
max_nodes: int,
|
||||
) -> KnowledgeGraph:
|
||||
if depth > max_depth or len(result.nodes) > max_nodes:
|
||||
return result
|
||||
|
|
@ -1006,9 +1004,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
async def get_knowledge_subgraph_bidirectional_bfs(
|
||||
self,
|
||||
node_label: str,
|
||||
depth=0,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
depth: int,
|
||||
max_depth: int,
|
||||
max_nodes: int,
|
||||
) -> KnowledgeGraph:
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
|
@ -1038,7 +1036,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
return result
|
||||
|
||||
async def get_knowledge_subgraph_in_out_bound_bfs(
|
||||
self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
|
||||
self, node_label: str, max_depth: int, max_nodes: int
|
||||
) -> KnowledgeGraph:
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
|
@ -1152,7 +1150,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
max_nodes: int = None,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
|
@ -1160,7 +1158,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
Args:
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maxiumu nodes to return, Defaults to 1000
|
||||
max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
|
|
@ -1184,6 +1182,13 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
C → B
|
||||
C → D
|
||||
"""
|
||||
# Use global_config max_graph_nodes as default if max_nodes is None
|
||||
if max_nodes is None:
|
||||
max_nodes = self.global_config.get("max_graph_nodes", 1000)
|
||||
else:
|
||||
# Limit max_nodes to not exceed global_config max_graph_nodes
|
||||
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
|
||||
|
||||
result = KnowledgeGraph()
|
||||
start = time.perf_counter()
|
||||
|
||||
|
|
|
|||
|
|
@ -36,9 +36,6 @@ from dotenv import load_dotenv
|
|||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
|
@ -902,7 +899,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
max_nodes: int = None,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
|
@ -916,6 +913,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
# Get max_nodes from global_config if not provided
|
||||
if max_nodes is None:
|
||||
max_nodes = self.global_config.get("max_graph_nodes", 1000)
|
||||
else:
|
||||
# Limit max_nodes to not exceed global_config max_graph_nodes
|
||||
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
|
||||
|
||||
workspace_label = self._get_workspace_label()
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
|
|
|
|||
|
|
@ -26,8 +26,6 @@ from dotenv import load_dotenv
|
|||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
|
|
@ -218,7 +216,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
max_nodes: int = None,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
|
@ -232,6 +230,13 @@ class NetworkXStorage(BaseGraphStorage):
|
|||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
# Get max_nodes from global_config if not provided
|
||||
if max_nodes is None:
|
||||
max_nodes = self.global_config.get("max_graph_nodes", 1000)
|
||||
else:
|
||||
# Limit max_nodes to not exceed global_config max_graph_nodes
|
||||
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
|
||||
|
||||
graph = await self._get_graph()
|
||||
|
||||
result = KnowledgeGraph()
|
||||
|
|
|
|||
|
|
@ -45,9 +45,6 @@ from dotenv import load_dotenv
|
|||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
|
||||
|
||||
class PostgreSQLDB:
|
||||
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
||||
|
|
@ -2819,7 +2816,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
max_nodes: int = None,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
|
@ -2827,12 +2824,18 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
Args:
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maxiumu nodes to return, Defaults to 1000
|
||||
max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
# Use global_config max_graph_nodes as default if max_nodes is None
|
||||
if max_nodes is None:
|
||||
max_nodes = self.global_config.get("max_graph_nodes", 1000)
|
||||
else:
|
||||
# Limit max_nodes to not exceed global_config max_graph_nodes
|
||||
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
|
||||
kg = KnowledgeGraph()
|
||||
|
||||
# Handle wildcard query - get all nodes
|
||||
|
|
|
|||
|
|
@ -258,6 +258,9 @@ class LightRAG:
|
|||
max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2)))
|
||||
"""Maximum number of parallel insert operations."""
|
||||
|
||||
max_graph_nodes: int = field(default=get_env_value("MAX_GRAPH_NODES", 1000, int))
|
||||
"""Maximum number of graph nodes to return in knowledge graph queries."""
|
||||
|
||||
addon_params: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"language": get_env_value("SUMMARY_LANGUAGE", "English", str)
|
||||
|
|
@ -526,18 +529,24 @@ class LightRAG:
|
|||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = 1000,
|
||||
max_nodes: int = None,
|
||||
) -> KnowledgeGraph:
|
||||
"""Get knowledge graph for a given label
|
||||
|
||||
Args:
|
||||
node_label (str): Label to get knowledge graph for
|
||||
max_depth (int): Maximum depth of graph
|
||||
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
|
||||
max_nodes (int, optional): Maximum number of nodes to return. Defaults to self.max_graph_nodes.
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: Knowledge graph containing nodes and edges
|
||||
"""
|
||||
# Use self.max_graph_nodes as default if max_nodes is None
|
||||
if max_nodes is None:
|
||||
max_nodes = self.max_graph_nodes
|
||||
else:
|
||||
# Limit max_nodes to not exceed self.max_graph_nodes
|
||||
max_nodes = min(max_nodes, self.max_graph_nodes)
|
||||
|
||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
||||
node_label, max_depth, max_nodes
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue