feat: refactor ollama server configuration management

- Add ollama_server_infos attribute to LightRAG class with default initialization
- Move default values to constants.py for centralized configuration
- Refactor OllamaServerInfos class with property accessors and CLI support
- Update OllamaAPI to get configuration through rag object instead of direct import
- Add command line arguments for simulated model name and tag
- Fix type imports to avoid circular dependencies
This commit is contained in:
yangdx 2025-07-28 01:38:35 +08:00
parent 598eecd06d
commit f2ffff063b
6 changed files with 84 additions and 16 deletions

View file

@@ -7,7 +7,6 @@ HOST=0.0.0.0
PORT=9621
WEBUI_TITLE='My Graph KB'
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
OLLAMA_EMULATING_MODEL_TAG=latest
# WORKERS=2
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
@@ -21,6 +20,10 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# INPUT_DIR=<absolute_path_for_doc_input_dir>
# WORKING_DIR=<absolute_path_for_working_dir>
### Ollama Emulating Model and Tag
# OLLAMA_EMULATING_MODEL_NAME=lightrag
OLLAMA_EMULATING_MODEL_TAG=latest
### Max nodes returned from graph retrieval in webui
# MAX_GRAPH_NODES=1000

View file

@@ -26,6 +26,11 @@ from lightrag.constants import (
DEFAULT_SUMMARY_LANGUAGE,
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
DEFAULT_EMBEDDING_BATCH_NUM,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_OLLAMA_MODEL_SIZE,
DEFAULT_OLLAMA_CREATED_AT,
DEFAULT_OLLAMA_DIGEST,
)
# use the .env that is inside the current folder
@@ -35,13 +40,36 @@ load_dotenv(dotenv_path=".env", override=False)
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
def __init__(self, name=None, tag=None):
self._lightrag_name = name or os.getenv(
"OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME
)
self._lightrag_tag = tag or os.getenv(
"OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG
)
self.LIGHTRAG_SIZE = DEFAULT_OLLAMA_MODEL_SIZE
self.LIGHTRAG_CREATED_AT = DEFAULT_OLLAMA_CREATED_AT
self.LIGHTRAG_DIGEST = DEFAULT_OLLAMA_DIGEST
@property
def LIGHTRAG_NAME(self):
return self._lightrag_name
@LIGHTRAG_NAME.setter
def LIGHTRAG_NAME(self, value):
self._lightrag_name = value
@property
def LIGHTRAG_TAG(self):
return self._lightrag_tag
@LIGHTRAG_TAG.setter
def LIGHTRAG_TAG(self, value):
self._lightrag_tag = value
@property
def LIGHTRAG_MODEL(self):
return f"{self._lightrag_name}:{self._lightrag_tag}"
ollama_server_infos = OllamaServerInfos()
@@ -166,14 +194,19 @@ def parse_args() -> argparse.Namespace:
help="Path to SSL private key file (required if --ssl is enabled)",
)
# Ollama model name
# Ollama model configuration
parser.add_argument(
"--simulated-model-name",
type=str,
default=get_env_value(
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
),
help="Number of conversation history turns to include (default: from env or 3)",
default=get_env_value("OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME),
help="Name for the simulated Ollama model (default: from env or lightrag)",
)
parser.add_argument(
"--simulated-model-tag",
type=str,
default=get_env_value("OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG),
help="Tag for the simulated Ollama model (default: from env or latest)",
)
# Namespace
@@ -333,7 +366,8 @@ def parse_args() -> argparse.Namespace:
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
)
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
return args

View file

@@ -335,6 +335,13 @@ def create_app(args):
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
)
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
ollama_server_infos = OllamaServerInfos(
name=args.simulated_model_name, tag=args.simulated_model_tag
)
# Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai"]:
rag = LightRAG(
@@ -373,6 +380,7 @@ def create_app(args):
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language},
ollama_server_infos=ollama_server_infos,
)
else: # azure_openai
rag = LightRAG(
@@ -402,6 +410,7 @@ def create_app(args):
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language},
ollama_server_infos=ollama_server_infos,
)
# Add routes

View file

@@ -11,7 +11,7 @@ import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
from lightrag.utils import TiktokenTokenizer
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
from lightrag.api.utils_api import get_combined_auth_dependency
from fastapi import Depends
@@ -221,7 +221,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
class OllamaAPI:
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
self.rag = rag
self.ollama_server_infos = ollama_server_infos
self.ollama_server_infos = rag.ollama_server_infos
self.top_k = top_k
self.api_key = api_key
self.router = APIRouter(tags=["ollama"])

View file

@@ -47,3 +47,10 @@ DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computation
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups
DEFAULT_LOG_FILENAME = "lightrag.log" # Default log filename
# Ollama server configuration defaults
DEFAULT_OLLAMA_MODEL_NAME = "lightrag"
DEFAULT_OLLAMA_MODEL_TAG = "latest"
DEFAULT_OLLAMA_MODEL_SIZE = 7365960935
DEFAULT_OLLAMA_CREATED_AT = "2024-01-15T00:00:00Z"
DEFAULT_OLLAMA_DIGEST = "sha256:lightrag"

View file

@@ -41,6 +41,12 @@ from lightrag.kg import (
verify_storage_implementation,
)
# Import for type annotation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from lightrag.api.config import OllamaServerInfos
from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
@@ -342,6 +348,9 @@ class LightRAG:
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
)
ollama_server_infos: Optional["OllamaServerInfos"] = field(default=None)
"""Configuration for Ollama server information."""
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self):
@@ -403,6 +412,12 @@ class LightRAG:
else:
self.tokenizer = TiktokenTokenizer()
# Initialize ollama_server_infos if not provided
if self.ollama_server_infos is None:
from lightrag.api.config import OllamaServerInfos
self.ollama_server_infos = OllamaServerInfos()
# Fix global_config now
global_config = asdict(self)