feat: refactor ollama server configuration management
- Add ollama_server_infos attribute to LightRAG class with default initialization - Move default values to constants.py for centralized configuration - Refactor OllamaServerInfos class with property accessors and CLI support - Update OllamaAPI to get configuration through rag object instead of direct import - Add command line arguments for simulated model name and tag - Fix type imports to avoid circular dependencies
This commit is contained in:
parent
598eecd06d
commit
f2ffff063b
6 changed files with 84 additions and 16 deletions
|
|
@ -7,7 +7,6 @@ HOST=0.0.0.0
|
|||
PORT=9621
|
||||
WEBUI_TITLE='My Graph KB'
|
||||
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
|
||||
OLLAMA_EMULATING_MODEL_TAG=latest
|
||||
# WORKERS=2
|
||||
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
|
||||
|
||||
|
|
@ -21,6 +20,10 @@ OLLAMA_EMULATING_MODEL_TAG=latest
|
|||
# INPUT_DIR=<absolute_path_for_doc_input_dir>
|
||||
# WORKING_DIR=<absolute_path_for_working_dir>
|
||||
|
||||
### Ollama Emulating Model and Tag
|
||||
# OLLAMA_EMULATING_MODEL_NAME=lightrag
|
||||
OLLAMA_EMULATING_MODEL_TAG=latest
|
||||
|
||||
### Max nodes return from grap retrieval in webui
|
||||
# MAX_GRAPH_NODES=1000
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,11 @@ from lightrag.constants import (
|
|||
DEFAULT_SUMMARY_LANGUAGE,
|
||||
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
|
||||
DEFAULT_EMBEDDING_BATCH_NUM,
|
||||
DEFAULT_OLLAMA_MODEL_NAME,
|
||||
DEFAULT_OLLAMA_MODEL_TAG,
|
||||
DEFAULT_OLLAMA_MODEL_SIZE,
|
||||
DEFAULT_OLLAMA_CREATED_AT,
|
||||
DEFAULT_OLLAMA_DIGEST,
|
||||
)
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
|
|
@ -35,13 +40,36 @@ load_dotenv(dotenv_path=".env", override=False)
|
|||
|
||||
|
||||
class OllamaServerInfos:
|
||||
# Constants for emulated Ollama model information
|
||||
LIGHTRAG_NAME = "lightrag"
|
||||
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
||||
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
||||
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
||||
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
||||
LIGHTRAG_DIGEST = "sha256:lightrag"
|
||||
def __init__(self, name=None, tag=None):
|
||||
self._lightrag_name = name or os.getenv(
|
||||
"OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME
|
||||
)
|
||||
self._lightrag_tag = tag or os.getenv(
|
||||
"OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG
|
||||
)
|
||||
self.LIGHTRAG_SIZE = DEFAULT_OLLAMA_MODEL_SIZE
|
||||
self.LIGHTRAG_CREATED_AT = DEFAULT_OLLAMA_CREATED_AT
|
||||
self.LIGHTRAG_DIGEST = DEFAULT_OLLAMA_DIGEST
|
||||
|
||||
@property
|
||||
def LIGHTRAG_NAME(self):
|
||||
return self._lightrag_name
|
||||
|
||||
@LIGHTRAG_NAME.setter
|
||||
def LIGHTRAG_NAME(self, value):
|
||||
self._lightrag_name = value
|
||||
|
||||
@property
|
||||
def LIGHTRAG_TAG(self):
|
||||
return self._lightrag_tag
|
||||
|
||||
@LIGHTRAG_TAG.setter
|
||||
def LIGHTRAG_TAG(self, value):
|
||||
self._lightrag_tag = value
|
||||
|
||||
@property
|
||||
def LIGHTRAG_MODEL(self):
|
||||
return f"{self._lightrag_name}:{self._lightrag_tag}"
|
||||
|
||||
|
||||
ollama_server_infos = OllamaServerInfos()
|
||||
|
|
@ -166,14 +194,19 @@ def parse_args() -> argparse.Namespace:
|
|||
help="Path to SSL private key file (required if --ssl is enabled)",
|
||||
)
|
||||
|
||||
# Ollama model name
|
||||
# Ollama model configuration
|
||||
parser.add_argument(
|
||||
"--simulated-model-name",
|
||||
type=str,
|
||||
default=get_env_value(
|
||||
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
||||
),
|
||||
help="Number of conversation history turns to include (default: from env or 3)",
|
||||
default=get_env_value("OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME),
|
||||
help="Name for the simulated Ollama model (default: from env or lightrag)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulated-model-tag",
|
||||
type=str,
|
||||
default=get_env_value("OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG),
|
||||
help="Tag for the simulated Ollama model (default: from env or latest)",
|
||||
)
|
||||
|
||||
# Namespace
|
||||
|
|
@ -333,7 +366,8 @@ def parse_args() -> argparse.Namespace:
|
|||
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
||||
)
|
||||
|
||||
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
||||
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
||||
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
||||
|
||||
return args
|
||||
|
||||
|
|
|
|||
|
|
@ -335,6 +335,13 @@ def create_app(args):
|
|||
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
|
||||
)
|
||||
|
||||
# Create ollama_server_infos from command line arguments
|
||||
from lightrag.api.config import OllamaServerInfos
|
||||
|
||||
ollama_server_infos = OllamaServerInfos(
|
||||
name=args.simulated_model_name, tag=args.simulated_model_tag
|
||||
)
|
||||
|
||||
# Initialize RAG
|
||||
if args.llm_binding in ["lollms", "ollama", "openai"]:
|
||||
rag = LightRAG(
|
||||
|
|
@ -373,6 +380,7 @@ def create_app(args):
|
|||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
addon_params={"language": args.summary_language},
|
||||
ollama_server_infos=ollama_server_infos,
|
||||
)
|
||||
else: # azure_openai
|
||||
rag = LightRAG(
|
||||
|
|
@ -402,6 +410,7 @@ def create_app(args):
|
|||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
addon_params={"language": args.summary_language},
|
||||
ollama_server_infos=ollama_server_infos,
|
||||
)
|
||||
|
||||
# Add routes
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import asyncio
|
|||
from ascii_colors import trace_exception
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.utils import TiktokenTokenizer
|
||||
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
from fastapi import Depends
|
||||
|
||||
|
||||
|
|
@ -221,7 +221,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
|
|||
class OllamaAPI:
|
||||
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
|
||||
self.rag = rag
|
||||
self.ollama_server_infos = ollama_server_infos
|
||||
self.ollama_server_infos = rag.ollama_server_infos
|
||||
self.top_k = top_k
|
||||
self.api_key = api_key
|
||||
self.router = APIRouter(tags=["ollama"])
|
||||
|
|
|
|||
|
|
@ -47,3 +47,10 @@ DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computation
|
|||
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
|
||||
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups
|
||||
DEFAULT_LOG_FILENAME = "lightrag.log" # Default log filename
|
||||
|
||||
# Ollama server configuration defaults
|
||||
DEFAULT_OLLAMA_MODEL_NAME = "lightrag"
|
||||
DEFAULT_OLLAMA_MODEL_TAG = "latest"
|
||||
DEFAULT_OLLAMA_MODEL_SIZE = 7365960935
|
||||
DEFAULT_OLLAMA_CREATED_AT = "2024-01-15T00:00:00Z"
|
||||
DEFAULT_OLLAMA_DIGEST = "sha256:lightrag"
|
||||
|
|
|
|||
|
|
@ -41,6 +41,12 @@ from lightrag.kg import (
|
|||
verify_storage_implementation,
|
||||
)
|
||||
|
||||
# Import for type annotation
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lightrag.api.config import OllamaServerInfos
|
||||
|
||||
from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_pipeline_status_lock,
|
||||
|
|
@ -342,6 +348,9 @@ class LightRAG:
|
|||
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
|
||||
)
|
||||
|
||||
ollama_server_infos: Optional["OllamaServerInfos"] = field(default=None)
|
||||
"""Configuration for Ollama server information."""
|
||||
|
||||
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
|
||||
|
||||
def __post_init__(self):
|
||||
|
|
@ -403,6 +412,12 @@ class LightRAG:
|
|||
else:
|
||||
self.tokenizer = TiktokenTokenizer()
|
||||
|
||||
# Initialize ollama_server_infos if not provided
|
||||
if self.ollama_server_infos is None:
|
||||
from lightrag.api.config import OllamaServerInfos
|
||||
|
||||
self.ollama_server_infos = OllamaServerInfos()
|
||||
|
||||
# Fix global_config now
|
||||
global_config = asdict(self)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue