feat: refactor ollama server configuration management
- Add ollama_server_infos attribute to LightRAG class with default initialization
- Move default values to constants.py for centralized configuration
- Refactor OllamaServerInfos class with property accessors and CLI support
- Update OllamaAPI to get configuration through rag object instead of direct import
- Add command line arguments for simulated model name and tag
- Fix type imports to avoid circular dependencies
This commit is contained in:
parent
598eecd06d
commit
f2ffff063b
6 changed files with 84 additions and 16 deletions
|
|
@ -7,7 +7,6 @@ HOST=0.0.0.0
|
||||||
PORT=9621
|
PORT=9621
|
||||||
WEBUI_TITLE='My Graph KB'
|
WEBUI_TITLE='My Graph KB'
|
||||||
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
|
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
|
||||||
OLLAMA_EMULATING_MODEL_TAG=latest
|
|
||||||
# WORKERS=2
|
# WORKERS=2
|
||||||
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
|
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
|
||||||
|
|
||||||
|
|
@ -21,6 +20,10 @@ OLLAMA_EMULATING_MODEL_TAG=latest
|
||||||
# INPUT_DIR=<absolute_path_for_doc_input_dir>
|
# INPUT_DIR=<absolute_path_for_doc_input_dir>
|
||||||
# WORKING_DIR=<absolute_path_for_working_dir>
|
# WORKING_DIR=<absolute_path_for_working_dir>
|
||||||
|
|
||||||
|
### Ollama Emulating Model and Tag
|
||||||
|
# OLLAMA_EMULATING_MODEL_NAME=lightrag
|
||||||
|
OLLAMA_EMULATING_MODEL_TAG=latest
|
||||||
|
|
||||||
### Max nodes returned from graph retrieval in webui
|
### Max nodes returned from graph retrieval in webui
|
||||||
# MAX_GRAPH_NODES=1000
|
# MAX_GRAPH_NODES=1000
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,11 @@ from lightrag.constants import (
|
||||||
DEFAULT_SUMMARY_LANGUAGE,
|
DEFAULT_SUMMARY_LANGUAGE,
|
||||||
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
|
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
|
||||||
DEFAULT_EMBEDDING_BATCH_NUM,
|
DEFAULT_EMBEDDING_BATCH_NUM,
|
||||||
|
DEFAULT_OLLAMA_MODEL_NAME,
|
||||||
|
DEFAULT_OLLAMA_MODEL_TAG,
|
||||||
|
DEFAULT_OLLAMA_MODEL_SIZE,
|
||||||
|
DEFAULT_OLLAMA_CREATED_AT,
|
||||||
|
DEFAULT_OLLAMA_DIGEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
# use the .env that is inside the current folder
|
# use the .env that is inside the current folder
|
||||||
|
|
@ -35,13 +40,36 @@ load_dotenv(dotenv_path=".env", override=False)
|
||||||
|
|
||||||
|
|
||||||
class OllamaServerInfos:
|
class OllamaServerInfos:
|
||||||
# Constants for emulated Ollama model information
|
def __init__(self, name=None, tag=None):
|
||||||
LIGHTRAG_NAME = "lightrag"
|
self._lightrag_name = name or os.getenv(
|
||||||
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
"OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME
|
||||||
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
)
|
||||||
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
self._lightrag_tag = tag or os.getenv(
|
||||||
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
"OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG
|
||||||
LIGHTRAG_DIGEST = "sha256:lightrag"
|
)
|
||||||
|
self.LIGHTRAG_SIZE = DEFAULT_OLLAMA_MODEL_SIZE
|
||||||
|
self.LIGHTRAG_CREATED_AT = DEFAULT_OLLAMA_CREATED_AT
|
||||||
|
self.LIGHTRAG_DIGEST = DEFAULT_OLLAMA_DIGEST
|
||||||
|
|
||||||
|
@property
|
||||||
|
def LIGHTRAG_NAME(self):
|
||||||
|
return self._lightrag_name
|
||||||
|
|
||||||
|
@LIGHTRAG_NAME.setter
|
||||||
|
def LIGHTRAG_NAME(self, value):
|
||||||
|
self._lightrag_name = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def LIGHTRAG_TAG(self):
|
||||||
|
return self._lightrag_tag
|
||||||
|
|
||||||
|
@LIGHTRAG_TAG.setter
|
||||||
|
def LIGHTRAG_TAG(self, value):
|
||||||
|
self._lightrag_tag = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def LIGHTRAG_MODEL(self):
|
||||||
|
return f"{self._lightrag_name}:{self._lightrag_tag}"
|
||||||
|
|
||||||
|
|
||||||
ollama_server_infos = OllamaServerInfos()
|
ollama_server_infos = OllamaServerInfos()
|
||||||
|
|
@ -166,14 +194,19 @@ def parse_args() -> argparse.Namespace:
|
||||||
help="Path to SSL private key file (required if --ssl is enabled)",
|
help="Path to SSL private key file (required if --ssl is enabled)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ollama model name
|
# Ollama model configuration
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--simulated-model-name",
|
"--simulated-model-name",
|
||||||
type=str,
|
type=str,
|
||||||
default=get_env_value(
|
default=get_env_value("OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME),
|
||||||
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
help="Name for the simulated Ollama model (default: from env or lightrag)",
|
||||||
),
|
)
|
||||||
help="Number of conversation history turns to include (default: from env or 3)",
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulated-model-tag",
|
||||||
|
type=str,
|
||||||
|
default=get_env_value("OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG),
|
||||||
|
help="Tag for the simulated Ollama model (default: from env or latest)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Namespace
|
# Namespace
|
||||||
|
|
@ -333,7 +366,8 @@ def parse_args() -> argparse.Namespace:
|
||||||
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
||||||
)
|
)
|
||||||
|
|
||||||
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
||||||
|
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -335,6 +335,13 @@ def create_app(args):
|
||||||
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
|
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create ollama_server_infos from command line arguments
|
||||||
|
from lightrag.api.config import OllamaServerInfos
|
||||||
|
|
||||||
|
ollama_server_infos = OllamaServerInfos(
|
||||||
|
name=args.simulated_model_name, tag=args.simulated_model_tag
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize RAG
|
# Initialize RAG
|
||||||
if args.llm_binding in ["lollms", "ollama", "openai"]:
|
if args.llm_binding in ["lollms", "ollama", "openai"]:
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
|
|
@ -373,6 +380,7 @@ def create_app(args):
|
||||||
max_parallel_insert=args.max_parallel_insert,
|
max_parallel_insert=args.max_parallel_insert,
|
||||||
max_graph_nodes=args.max_graph_nodes,
|
max_graph_nodes=args.max_graph_nodes,
|
||||||
addon_params={"language": args.summary_language},
|
addon_params={"language": args.summary_language},
|
||||||
|
ollama_server_infos=ollama_server_infos,
|
||||||
)
|
)
|
||||||
else: # azure_openai
|
else: # azure_openai
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
|
|
@ -402,6 +410,7 @@ def create_app(args):
|
||||||
max_parallel_insert=args.max_parallel_insert,
|
max_parallel_insert=args.max_parallel_insert,
|
||||||
max_graph_nodes=args.max_graph_nodes,
|
max_graph_nodes=args.max_graph_nodes,
|
||||||
addon_params={"language": args.summary_language},
|
addon_params={"language": args.summary_language},
|
||||||
|
ollama_server_infos=ollama_server_infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add routes
|
# Add routes
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import asyncio
|
||||||
from ascii_colors import trace_exception
|
from ascii_colors import trace_exception
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.utils import TiktokenTokenizer
|
from lightrag.utils import TiktokenTokenizer
|
||||||
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -221,7 +221,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
|
||||||
class OllamaAPI:
|
class OllamaAPI:
|
||||||
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
|
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
|
||||||
self.rag = rag
|
self.rag = rag
|
||||||
self.ollama_server_infos = ollama_server_infos
|
self.ollama_server_infos = rag.ollama_server_infos
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.router = APIRouter(tags=["ollama"])
|
self.router = APIRouter(tags=["ollama"])
|
||||||
|
|
|
||||||
|
|
@ -47,3 +47,10 @@ DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computation
|
||||||
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
|
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
|
||||||
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups
|
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups
|
||||||
DEFAULT_LOG_FILENAME = "lightrag.log" # Default log filename
|
DEFAULT_LOG_FILENAME = "lightrag.log" # Default log filename
|
||||||
|
|
||||||
|
# Ollama server configuration defaults
|
||||||
|
DEFAULT_OLLAMA_MODEL_NAME = "lightrag"
|
||||||
|
DEFAULT_OLLAMA_MODEL_TAG = "latest"
|
||||||
|
DEFAULT_OLLAMA_MODEL_SIZE = 7365960935
|
||||||
|
DEFAULT_OLLAMA_CREATED_AT = "2024-01-15T00:00:00Z"
|
||||||
|
DEFAULT_OLLAMA_DIGEST = "sha256:lightrag"
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,12 @@ from lightrag.kg import (
|
||||||
verify_storage_implementation,
|
verify_storage_implementation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import for type annotation
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from lightrag.api.config import OllamaServerInfos
|
||||||
|
|
||||||
from lightrag.kg.shared_storage import (
|
from lightrag.kg.shared_storage import (
|
||||||
get_namespace_data,
|
get_namespace_data,
|
||||||
get_pipeline_status_lock,
|
get_pipeline_status_lock,
|
||||||
|
|
@ -342,6 +348,9 @@ class LightRAG:
|
||||||
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
|
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ollama_server_infos: Optional["OllamaServerInfos"] = field(default=None)
|
||||||
|
"""Configuration for Ollama server information."""
|
||||||
|
|
||||||
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
|
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
@ -403,6 +412,12 @@ class LightRAG:
|
||||||
else:
|
else:
|
||||||
self.tokenizer = TiktokenTokenizer()
|
self.tokenizer = TiktokenTokenizer()
|
||||||
|
|
||||||
|
# Initialize ollama_server_infos if not provided
|
||||||
|
if self.ollama_server_infos is None:
|
||||||
|
from lightrag.api.config import OllamaServerInfos
|
||||||
|
|
||||||
|
self.ollama_server_infos = OllamaServerInfos()
|
||||||
|
|
||||||
# Fix global_config now
|
# Fix global_config now
|
||||||
global_config = asdict(self)
|
global_config = asdict(self)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue