diff --git a/env.example b/env.example
index fda14cbb..ab12cc41 100644
--- a/env.example
+++ b/env.example
@@ -7,7 +7,6 @@ HOST=0.0.0.0
 PORT=9621
 WEBUI_TITLE='My Graph KB'
 WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
-OLLAMA_EMULATING_MODEL_TAG=latest
 
 # WORKERS=2
 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
@@ -21,6 +20,10 @@
 # INPUT_DIR=
 # WORKING_DIR=
 
+### Ollama Emulating Model and Tag
+# OLLAMA_EMULATING_MODEL_NAME=lightrag
+OLLAMA_EMULATING_MODEL_TAG=latest
+
 ### Max nodes returned from graph retrieval in webui
 # MAX_GRAPH_NODES=1000
 
diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index e56b1749..2302981d 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -26,6 +26,11 @@ from lightrag.constants import (
     DEFAULT_SUMMARY_LANGUAGE,
     DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
     DEFAULT_EMBEDDING_BATCH_NUM,
+    DEFAULT_OLLAMA_MODEL_NAME,
+    DEFAULT_OLLAMA_MODEL_TAG,
+    DEFAULT_OLLAMA_MODEL_SIZE,
+    DEFAULT_OLLAMA_CREATED_AT,
+    DEFAULT_OLLAMA_DIGEST,
 )
 
 # use the .env that is inside the current folder
@@ -35,13 +40,36 @@ load_dotenv(dotenv_path=".env", override=False)
 
 
 class OllamaServerInfos:
-    # Constants for emulated Ollama model information
-    LIGHTRAG_NAME = "lightrag"
-    LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
-    LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
-    LIGHTRAG_SIZE = 7365960935  # it's a dummy value
-    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
-    LIGHTRAG_DIGEST = "sha256:lightrag"
+    def __init__(self, name=None, tag=None):
+        self._lightrag_name = name or os.getenv(
+            "OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME
+        )
+        self._lightrag_tag = tag or os.getenv(
+            "OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG
+        )
+        self.LIGHTRAG_SIZE = DEFAULT_OLLAMA_MODEL_SIZE
+        self.LIGHTRAG_CREATED_AT = DEFAULT_OLLAMA_CREATED_AT
+        self.LIGHTRAG_DIGEST = DEFAULT_OLLAMA_DIGEST
+
+    @property
+    def LIGHTRAG_NAME(self):
+        return self._lightrag_name
+
+    @LIGHTRAG_NAME.setter
+    def LIGHTRAG_NAME(self, value):
+        self._lightrag_name = value
+
+    @property
+    def LIGHTRAG_TAG(self):
+        return self._lightrag_tag
+
+    @LIGHTRAG_TAG.setter
+    def LIGHTRAG_TAG(self, value):
+        self._lightrag_tag = value
+
+    @property
+    def LIGHTRAG_MODEL(self):
+        return f"{self._lightrag_name}:{self._lightrag_tag}"
 
 
 ollama_server_infos = OllamaServerInfos()
@@ -166,14 +194,19 @@ def parse_args() -> argparse.Namespace:
         help="Path to SSL private key file (required if --ssl is enabled)",
     )
 
-    # Ollama model name
+    # Ollama model configuration
     parser.add_argument(
         "--simulated-model-name",
         type=str,
-        default=get_env_value(
-            "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
-        ),
-        help="Number of conversation history turns to include (default: from env or 3)",
+        default=get_env_value("OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME),
+        help="Name for the simulated Ollama model (default: from env or lightrag)",
+    )
+
+    parser.add_argument(
+        "--simulated-model-tag",
+        type=str,
+        default=get_env_value("OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG),
+        help="Tag for the simulated Ollama model (default: from env or latest)",
     )
 
     # Namespace
@@ -333,7 +366,8 @@ def parse_args() -> argparse.Namespace:
         "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
     )
 
-    ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
+    ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
+    ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
 
     return args
 
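Reviewer note: a minimal sketch of the precedence the reworked OllamaServerInfos implements -- constructor argument first, then the OLLAMA_EMULATING_MODEL_NAME / OLLAMA_EMULATING_MODEL_TAG environment variables, then the constants -- and of LIGHTRAG_MODEL being derived on the fly. This assumes only the class as defined in the hunk above; the env manipulation is just to make the fallbacks observable:

    import os
    from lightrag.api.config import OllamaServerInfos

    # With a clean environment, everything falls back to the constants.
    os.environ.pop("OLLAMA_EMULATING_MODEL_NAME", None)
    os.environ.pop("OLLAMA_EMULATING_MODEL_TAG", None)
    assert OllamaServerInfos().LIGHTRAG_MODEL == "lightrag:latest"

    # Explicit constructor arguments win over the environment.
    os.environ["OLLAMA_EMULATING_MODEL_TAG"] = "v1"
    infos = OllamaServerInfos(name="mykb")
    assert infos.LIGHTRAG_MODEL == "mykb:v1"  # name from arg, tag from env

    # LIGHTRAG_MODEL is a property, so setter updates are reflected at once.
    infos.LIGHTRAG_TAG = "v2"
    assert infos.LIGHTRAG_MODEL == "mykb:v2"
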
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 2ae2d87b..75e6526f 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -335,6 +335,13 @@ def create_app(args):
             "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
         )
 
+    # Create ollama_server_infos from command line arguments
+    from lightrag.api.config import OllamaServerInfos
+
+    ollama_server_infos = OllamaServerInfos(
+        name=args.simulated_model_name, tag=args.simulated_model_tag
+    )
+
     # Initialize RAG
     if args.llm_binding in ["lollms", "ollama", "openai"]:
         rag = LightRAG(
@@ -373,6 +380,7 @@ def create_app(args):
             max_parallel_insert=args.max_parallel_insert,
             max_graph_nodes=args.max_graph_nodes,
             addon_params={"language": args.summary_language},
+            ollama_server_infos=ollama_server_infos,
         )
     else:  # azure_openai
         rag = LightRAG(
@@ -402,6 +410,7 @@ def create_app(args):
             max_parallel_insert=args.max_parallel_insert,
             max_graph_nodes=args.max_graph_nodes,
             addon_params={"language": args.summary_language},
+            ollama_server_infos=ollama_server_infos,
         )
 
     # Add routes
diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py
index 597009bc..c38018f6 100644
--- a/lightrag/api/routers/ollama_api.py
+++ b/lightrag/api/routers/ollama_api.py
@@ -11,7 +11,7 @@ import asyncio
 from ascii_colors import trace_exception
 from lightrag import LightRAG, QueryParam
 from lightrag.utils import TiktokenTokenizer
-from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
+from lightrag.api.utils_api import get_combined_auth_dependency
 
 from fastapi import Depends
 
@@ -221,7 +221,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
 class OllamaAPI:
     def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
         self.rag = rag
-        self.ollama_server_infos = ollama_server_infos
+        self.ollama_server_infos = rag.ollama_server_infos
         self.top_k = top_k
         self.api_key = api_key
         self.router = APIRouter(tags=["ollama"])
diff --git a/lightrag/constants.py b/lightrag/constants.py
index 7fc47e21..bea8a00f 100644
--- a/lightrag/constants.py
+++ b/lightrag/constants.py
@@ -47,3 +47,10 @@ DEFAULT_EMBEDDING_BATCH_NUM = 10  # Default batch size for embedding computation
 DEFAULT_LOG_MAX_BYTES = 10485760  # Default 10MB
 DEFAULT_LOG_BACKUP_COUNT = 5  # Default 5 backups
 DEFAULT_LOG_FILENAME = "lightrag.log"  # Default log filename
+
+# Ollama server configuration defaults
+DEFAULT_OLLAMA_MODEL_NAME = "lightrag"
+DEFAULT_OLLAMA_MODEL_TAG = "latest"
+DEFAULT_OLLAMA_MODEL_SIZE = 7365960935
+DEFAULT_OLLAMA_CREATED_AT = "2024-01-15T00:00:00Z"
+DEFAULT_OLLAMA_DIGEST = "sha256:lightrag"
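Since OllamaAPI now takes its model identity from rag.ollama_server_infos rather than the module-level singleton, two apps in the same process can emulate different models without stepping on each other. A quick illustration (sketch only; the knowledge-base names are hypothetical):

    from lightrag.api.config import OllamaServerInfos

    kb_a = OllamaServerInfos(name="sales-kb", tag="prod")
    kb_b = OllamaServerInfos(name="eng-kb", tag="staging")
    assert kb_a.LIGHTRAG_MODEL == "sales-kb:prod"
    assert kb_b.LIGHTRAG_MODEL == "eng-kb:staging"

    # Mutating one instance no longer leaks into the other -- something the
    # old class-attribute version (shared process-wide) could not guarantee.
    kb_a.LIGHTRAG_TAG = "canary"
    assert kb_b.LIGHTRAG_MODEL == "eng-kb:staging"
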
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index c930dc66..5df36bbb 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -41,6 +41,12 @@ from lightrag.kg import (
     verify_storage_implementation,
 )
 
+# Import for type annotation
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from lightrag.api.config import OllamaServerInfos
+
 from lightrag.kg.shared_storage import (
     get_namespace_data,
     get_pipeline_status_lock,
@@ -342,6 +348,9 @@ class LightRAG:
         default=float(os.getenv("COSINE_THRESHOLD", 0.2))
     )
 
+    ollama_server_infos: Optional["OllamaServerInfos"] = field(default=None)
+    """Configuration for Ollama server information."""
+
     _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
 
     def __post_init__(self):
@@ -403,6 +412,12 @@ class LightRAG:
         else:
             self.tokenizer = TiktokenTokenizer()
 
+        # Initialize ollama_server_infos if not provided
+        if self.ollama_server_infos is None:
+            from lightrag.api.config import OllamaServerInfos
+
+            self.ollama_server_infos = OllamaServerInfos()
+
         # Fix global_config now
         global_config = asdict(self)
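Taken together, the LightRAG side behaves as follows: an explicit OllamaServerInfos is used as-is, and omitting it makes __post_init__ build one from the environment/constants, so library users who never touch the API server still get a valid default. A standalone sketch of that guard (FakeRAG is a hypothetical stand-in so the snippet runs without LightRAG's storage and LLM setup):

    from typing import Optional
    from lightrag.api.config import OllamaServerInfos

    class FakeRAG:
        # Hypothetical stand-in for LightRAG, reproducing only the
        # __post_init__ guard added in the hunk above.
        def __init__(self, ollama_server_infos: Optional[OllamaServerInfos] = None):
            self.ollama_server_infos = ollama_server_infos
            if self.ollama_server_infos is None:
                self.ollama_server_infos = OllamaServerInfos()

    custom = OllamaServerInfos(name="mykb", tag="v2")
    assert FakeRAG(custom).ollama_server_infos is custom   # explicit wins
    assert FakeRAG().ollama_server_infos.LIGHTRAG_MODEL    # env/constant default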