feat: refactor ollama server configuration management

- Add ollama_server_infos attribute to LightRAG class with default initialization
- Move default values to constants.py for centralized configuration
- Refactor OllamaServerInfos class with property accessors and CLI support
- Update OllamaAPI to get configuration through rag object instead of direct import
- Add command line arguments for simulated model name and tag
- Fix type imports to avoid circular dependencies
This commit is contained in:
yangdx 2025-07-28 01:38:35 +08:00
parent 598eecd06d
commit f2ffff063b
6 changed files with 84 additions and 16 deletions

View file

@ -7,7 +7,6 @@ HOST=0.0.0.0
PORT=9621 PORT=9621
WEBUI_TITLE='My Graph KB' WEBUI_TITLE='My Graph KB'
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System" WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
OLLAMA_EMULATING_MODEL_TAG=latest
# WORKERS=2 # WORKERS=2
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
@ -21,6 +20,10 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# INPUT_DIR=<absolute_path_for_doc_input_dir> # INPUT_DIR=<absolute_path_for_doc_input_dir>
# WORKING_DIR=<absolute_path_for_working_dir> # WORKING_DIR=<absolute_path_for_working_dir>
### Ollama Emulating Model and Tag
# OLLAMA_EMULATING_MODEL_NAME=lightrag
OLLAMA_EMULATING_MODEL_TAG=latest
### Max nodes return from grap retrieval in webui ### Max nodes return from grap retrieval in webui
# MAX_GRAPH_NODES=1000 # MAX_GRAPH_NODES=1000

View file

@ -26,6 +26,11 @@ from lightrag.constants import (
DEFAULT_SUMMARY_LANGUAGE, DEFAULT_SUMMARY_LANGUAGE,
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
DEFAULT_EMBEDDING_BATCH_NUM, DEFAULT_EMBEDDING_BATCH_NUM,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_OLLAMA_MODEL_SIZE,
DEFAULT_OLLAMA_CREATED_AT,
DEFAULT_OLLAMA_DIGEST,
) )
# use the .env that is inside the current folder # use the .env that is inside the current folder
@ -35,13 +40,36 @@ load_dotenv(dotenv_path=".env", override=False)
class OllamaServerInfos: class OllamaServerInfos:
# Constants for emulated Ollama model information def __init__(self, name=None, tag=None):
LIGHTRAG_NAME = "lightrag" self._lightrag_name = name or os.getenv(
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") "OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" )
LIGHTRAG_SIZE = 7365960935 # it's a dummy value self._lightrag_tag = tag or os.getenv(
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" "OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG
LIGHTRAG_DIGEST = "sha256:lightrag" )
self.LIGHTRAG_SIZE = DEFAULT_OLLAMA_MODEL_SIZE
self.LIGHTRAG_CREATED_AT = DEFAULT_OLLAMA_CREATED_AT
self.LIGHTRAG_DIGEST = DEFAULT_OLLAMA_DIGEST
@property
def LIGHTRAG_NAME(self):
return self._lightrag_name
@LIGHTRAG_NAME.setter
def LIGHTRAG_NAME(self, value):
self._lightrag_name = value
@property
def LIGHTRAG_TAG(self):
return self._lightrag_tag
@LIGHTRAG_TAG.setter
def LIGHTRAG_TAG(self, value):
self._lightrag_tag = value
@property
def LIGHTRAG_MODEL(self):
return f"{self._lightrag_name}:{self._lightrag_tag}"
ollama_server_infos = OllamaServerInfos() ollama_server_infos = OllamaServerInfos()
@ -166,14 +194,19 @@ def parse_args() -> argparse.Namespace:
help="Path to SSL private key file (required if --ssl is enabled)", help="Path to SSL private key file (required if --ssl is enabled)",
) )
# Ollama model name # Ollama model configuration
parser.add_argument( parser.add_argument(
"--simulated-model-name", "--simulated-model-name",
type=str, type=str,
default=get_env_value( default=get_env_value("OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME),
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL help="Name for the simulated Ollama model (default: from env or lightrag)",
), )
help="Number of conversation history turns to include (default: from env or 3)",
parser.add_argument(
"--simulated-model-tag",
type=str,
default=get_env_value("OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG),
help="Tag for the simulated Ollama model (default: from env or latest)",
) )
# Namespace # Namespace
@ -333,7 +366,8 @@ def parse_args() -> argparse.Namespace:
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
) )
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
return args return args

View file

@ -335,6 +335,13 @@ def create_app(args):
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking." "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
) )
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
ollama_server_infos = OllamaServerInfos(
name=args.simulated_model_name, tag=args.simulated_model_tag
)
# Initialize RAG # Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai"]: if args.llm_binding in ["lollms", "ollama", "openai"]:
rag = LightRAG( rag = LightRAG(
@ -373,6 +380,7 @@ def create_app(args):
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes, max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language}, addon_params={"language": args.summary_language},
ollama_server_infos=ollama_server_infos,
) )
else: # azure_openai else: # azure_openai
rag = LightRAG( rag = LightRAG(
@ -402,6 +410,7 @@ def create_app(args):
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes, max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language}, addon_params={"language": args.summary_language},
ollama_server_infos=ollama_server_infos,
) )
# Add routes # Add routes

View file

@ -11,7 +11,7 @@ import asyncio
from ascii_colors import trace_exception from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.utils import TiktokenTokenizer from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
from fastapi import Depends from fastapi import Depends
@ -221,7 +221,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
class OllamaAPI: class OllamaAPI:
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None): def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
self.rag = rag self.rag = rag
self.ollama_server_infos = ollama_server_infos self.ollama_server_infos = rag.ollama_server_infos
self.top_k = top_k self.top_k = top_k
self.api_key = api_key self.api_key = api_key
self.router = APIRouter(tags=["ollama"]) self.router = APIRouter(tags=["ollama"])

View file

@ -47,3 +47,10 @@ DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computation
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups
DEFAULT_LOG_FILENAME = "lightrag.log" # Default log filename DEFAULT_LOG_FILENAME = "lightrag.log" # Default log filename
# Ollama server configuration defaults
DEFAULT_OLLAMA_MODEL_NAME = "lightrag"
DEFAULT_OLLAMA_MODEL_TAG = "latest"
DEFAULT_OLLAMA_MODEL_SIZE = 7365960935
DEFAULT_OLLAMA_CREATED_AT = "2024-01-15T00:00:00Z"
DEFAULT_OLLAMA_DIGEST = "sha256:lightrag"

View file

@ -41,6 +41,12 @@ from lightrag.kg import (
verify_storage_implementation, verify_storage_implementation,
) )
# Import for type annotation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from lightrag.api.config import OllamaServerInfos
from lightrag.kg.shared_storage import ( from lightrag.kg.shared_storage import (
get_namespace_data, get_namespace_data,
get_pipeline_status_lock, get_pipeline_status_lock,
@ -342,6 +348,9 @@ class LightRAG:
default=float(os.getenv("COSINE_THRESHOLD", 0.2)) default=float(os.getenv("COSINE_THRESHOLD", 0.2))
) )
ollama_server_infos: Optional["OllamaServerInfos"] = field(default=None)
"""Configuration for Ollama server information."""
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self): def __post_init__(self):
@ -403,6 +412,12 @@ class LightRAG:
else: else:
self.tokenizer = TiktokenTokenizer() self.tokenizer = TiktokenTokenizer()
# Initialize ollama_server_infos if not provided
if self.ollama_server_infos is None:
from lightrag.api.config import OllamaServerInfos
self.ollama_server_infos = OllamaServerInfos()
# Fix global_config now # Fix global_config now
global_config = asdict(self) global_config = asdict(self)