Refactor LLM temperature handling to be provider-specific
• Remove global temperature parameter • Add provider-specific temp configs • Update env example with new settings • Fix Bedrock temperature handling • Clean up splash screen display
This commit is contained in:
parent
df7bcb1e3d
commit
aa22772721
9 changed files with 20 additions and 54 deletions
|
|
@ -127,9 +127,7 @@ MAX_PARALLEL_INSERT=2
|
||||||
### LLM Configuration
|
### LLM Configuration
|
||||||
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
|
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
|
||||||
###########################################################
|
###########################################################
|
||||||
### LLM temperature and timeout setting for all llm binding (openai, azure_openai, ollama)
|
### LLM request timeout setting for all llm (set to TIMEOUT if not specified, 0 means no timeout for Ollma)
|
||||||
# TEMPERATURE=1.0
|
|
||||||
### LLM request timeout setting for all llm (set to TIMEOUT if not specified)
|
|
||||||
# LLM_TIMEOUT=150
|
# LLM_TIMEOUT=150
|
||||||
### Some models like o1-mini require temperature to be set to 1, some LLM can fall into output loops with low temperature
|
### Some models like o1-mini require temperature to be set to 1, some LLM can fall into output loops with low temperature
|
||||||
|
|
||||||
|
|
@ -151,6 +149,7 @@ LLM_BINDING_API_KEY=your_api_key
|
||||||
### OpenAI Specific Parameters
|
### OpenAI Specific Parameters
|
||||||
### Apply frequency penalty to prevent the LLM from generating repetitive or looping outputs
|
### Apply frequency penalty to prevent the LLM from generating repetitive or looping outputs
|
||||||
# OPENAI_LLM_FREQUENCY_PENALTY=1.1
|
# OPENAI_LLM_FREQUENCY_PENALTY=1.1
|
||||||
|
# OPENAI_LLM_TEMPERATURE=1.0
|
||||||
### use the following command to see all support options for openai and azure_openai
|
### use the following command to see all support options for openai and azure_openai
|
||||||
### lightrag-server --llm-binding openai --help
|
### lightrag-server --llm-binding openai --help
|
||||||
|
|
||||||
|
|
@ -164,6 +163,9 @@ OLLAMA_LLM_NUM_CTX=32768
|
||||||
### use the following command to see all support options for Ollama LLM
|
### use the following command to see all support options for Ollama LLM
|
||||||
### lightrag-server --llm-binding ollama --help
|
### lightrag-server --llm-binding ollama --help
|
||||||
|
|
||||||
|
### Bedrock Specific Parameters
|
||||||
|
# BEDROCK_LLM_TEMPERATURE=1.0
|
||||||
|
|
||||||
####################################################################################
|
####################################################################################
|
||||||
### Embedding Configuration (Should not be changed after the first file processed)
|
### Embedding Configuration (Should not be changed after the first file processed)
|
||||||
####################################################################################
|
####################################################################################
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,6 @@ from lightrag.constants import (
|
||||||
DEFAULT_EMBEDDING_BATCH_NUM,
|
DEFAULT_EMBEDDING_BATCH_NUM,
|
||||||
DEFAULT_OLLAMA_MODEL_NAME,
|
DEFAULT_OLLAMA_MODEL_NAME,
|
||||||
DEFAULT_OLLAMA_MODEL_TAG,
|
DEFAULT_OLLAMA_MODEL_TAG,
|
||||||
DEFAULT_TEMPERATURE,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# use the .env that is inside the current folder
|
# use the .env that is inside the current folder
|
||||||
|
|
@ -264,14 +263,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
|
elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
|
||||||
OpenAILLMOptions.add_args(parser)
|
OpenAILLMOptions.add_args(parser)
|
||||||
|
|
||||||
# Add global temperature command line argument
|
|
||||||
parser.add_argument(
|
|
||||||
"--temperature",
|
|
||||||
type=float,
|
|
||||||
default=get_env_value("TEMPERATURE", DEFAULT_TEMPERATURE, float),
|
|
||||||
help="Global temperature setting for LLM (default: from env TEMPERATURE or 0.1)",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# convert relative path to absolute path
|
# convert relative path to absolute path
|
||||||
|
|
@ -330,32 +321,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
)
|
)
|
||||||
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
|
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
|
||||||
|
|
||||||
# Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
|
|
||||||
if args.llm_binding == "ollama":
|
|
||||||
# Priority order (highest to lowest):
|
|
||||||
# 1. --ollama-llm-temperature command argument
|
|
||||||
# 2. OLLAMA_LLM_TEMPERATURE environment variable
|
|
||||||
# 3. --temperature command argument
|
|
||||||
# 4. TEMPERATURE environment variable
|
|
||||||
|
|
||||||
# Check if --ollama-llm-temperature was explicitly provided in command line
|
|
||||||
if "--ollama-llm-temperature" not in sys.argv:
|
|
||||||
# Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
|
|
||||||
args.ollama_llm_temperature = args.temperature
|
|
||||||
|
|
||||||
# Handle OpenAI LLM temperature with priority cascade when llm-binding is openai or azure_openai
|
|
||||||
if args.llm_binding in ["openai", "azure_openai"]:
|
|
||||||
# Priority order (highest to lowest):
|
|
||||||
# 1. --openai-llm-temperature command argument
|
|
||||||
# 2. OPENAI_LLM_TEMPERATURE environment variable
|
|
||||||
# 3. --temperature command argument
|
|
||||||
# 4. TEMPERATURE environment variable
|
|
||||||
|
|
||||||
# Check if --openai-llm-temperature was explicitly provided in command line
|
|
||||||
if "--openai-llm-temperature" not in sys.argv:
|
|
||||||
# Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
|
|
||||||
args.openai_llm_temperature = args.temperature
|
|
||||||
|
|
||||||
# Select Document loading tool (DOCLING, DEFAULT)
|
# Select Document loading tool (DOCLING, DEFAULT)
|
||||||
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -327,7 +327,7 @@ def create_app(args):
|
||||||
history_messages = []
|
history_messages = []
|
||||||
|
|
||||||
# Use global temperature for Bedrock
|
# Use global temperature for Bedrock
|
||||||
kwargs["temperature"] = args.temperature
|
kwargs["temperature"] = get_env_value("BEDROCK_LLM_TEMPERATURE", 1.0, float)
|
||||||
|
|
||||||
return await bedrock_complete_if_cache(
|
return await bedrock_complete_if_cache(
|
||||||
args.llm_model,
|
args.llm_model,
|
||||||
|
|
@ -479,9 +479,6 @@ def create_app(args):
|
||||||
llm_model_func=azure_openai_model_complete,
|
llm_model_func=azure_openai_model_complete,
|
||||||
chunk_token_size=int(args.chunk_size),
|
chunk_token_size=int(args.chunk_size),
|
||||||
chunk_overlap_token_size=int(args.chunk_overlap_size),
|
chunk_overlap_token_size=int(args.chunk_overlap_size),
|
||||||
llm_model_kwargs={
|
|
||||||
"timeout": llm_timeout,
|
|
||||||
},
|
|
||||||
llm_model_name=args.llm_model,
|
llm_model_name=args.llm_model,
|
||||||
llm_model_max_async=args.max_async,
|
llm_model_max_async=args.max_async,
|
||||||
summary_max_tokens=args.max_tokens,
|
summary_max_tokens=args.max_tokens,
|
||||||
|
|
|
||||||
|
|
@ -201,6 +201,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||||
ASCIIColors.yellow(f"{args.port}")
|
ASCIIColors.yellow(f"{args.port}")
|
||||||
ASCIIColors.white(" ├─ Workers: ", end="")
|
ASCIIColors.white(" ├─ Workers: ", end="")
|
||||||
ASCIIColors.yellow(f"{args.workers}")
|
ASCIIColors.yellow(f"{args.workers}")
|
||||||
|
ASCIIColors.white(" ├─ Timeout: ", end="")
|
||||||
|
ASCIIColors.yellow(f"{args.timeout}")
|
||||||
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
||||||
ASCIIColors.yellow(f"{args.cors_origins}")
|
ASCIIColors.yellow(f"{args.cors_origins}")
|
||||||
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
||||||
|
|
@ -238,14 +240,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||||
ASCIIColors.yellow(f"{args.llm_binding_host}")
|
ASCIIColors.yellow(f"{args.llm_binding_host}")
|
||||||
ASCIIColors.white(" ├─ Model: ", end="")
|
ASCIIColors.white(" ├─ Model: ", end="")
|
||||||
ASCIIColors.yellow(f"{args.llm_model}")
|
ASCIIColors.yellow(f"{args.llm_model}")
|
||||||
ASCIIColors.white(" ├─ Temperature: ", end="")
|
|
||||||
ASCIIColors.yellow(f"{args.temperature}")
|
|
||||||
ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
|
ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
|
||||||
ASCIIColors.yellow(f"{args.max_async}")
|
ASCIIColors.yellow(f"{args.max_async}")
|
||||||
ASCIIColors.white(" ├─ Max Tokens: ", end="")
|
ASCIIColors.white(" ├─ Max Tokens: ", end="")
|
||||||
ASCIIColors.yellow(f"{args.max_tokens}")
|
ASCIIColors.yellow(f"{args.max_tokens}")
|
||||||
ASCIIColors.white(" ├─ Timeout: ", end="")
|
|
||||||
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
|
|
||||||
ASCIIColors.white(" ├─ LLM Cache Enabled: ", end="")
|
ASCIIColors.white(" ├─ LLM Cache Enabled: ", end="")
|
||||||
ASCIIColors.yellow(f"{args.enable_llm_cache}")
|
ASCIIColors.yellow(f"{args.enable_llm_cache}")
|
||||||
ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
|
ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
|
||||||
llm_instance = OpenAI(
|
llm_instance = OpenAI(
|
||||||
model="gpt-4",
|
model="gpt-4",
|
||||||
api_key="your-openai-key",
|
api_key="your-openai-key",
|
||||||
temperature=0.7,
|
|
||||||
)
|
)
|
||||||
kwargs['llm_instance'] = llm_instance
|
kwargs['llm_instance'] = llm_instance
|
||||||
|
|
||||||
|
|
@ -91,7 +90,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
|
||||||
model=f"openai/{settings.LLM_MODEL}", # Format: "provider/model_name"
|
model=f"openai/{settings.LLM_MODEL}", # Format: "provider/model_name"
|
||||||
api_base=settings.LITELLM_URL,
|
api_base=settings.LITELLM_URL,
|
||||||
api_key=settings.LITELLM_KEY,
|
api_key=settings.LITELLM_KEY,
|
||||||
temperature=0.7,
|
|
||||||
)
|
)
|
||||||
kwargs['llm_instance'] = llm_instance
|
kwargs['llm_instance'] = llm_instance
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -82,10 +82,15 @@ async def anthropic_complete_if_cache(
|
||||||
timeout = kwargs.pop("timeout", None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
|
||||||
anthropic_async_client = (
|
anthropic_async_client = (
|
||||||
AsyncAnthropic(default_headers=default_headers, api_key=api_key, timeout=timeout)
|
AsyncAnthropic(
|
||||||
|
default_headers=default_headers, api_key=api_key, timeout=timeout
|
||||||
|
)
|
||||||
if base_url is None
|
if base_url is None
|
||||||
else AsyncAnthropic(
|
else AsyncAnthropic(
|
||||||
base_url=base_url, default_headers=default_headers, api_key=api_key, timeout=timeout
|
base_url=base_url,
|
||||||
|
default_headers=default_headers,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ async def azure_openai_complete_if_cache(
|
||||||
kwargs.pop("hashing_kv", None)
|
kwargs.pop("hashing_kv", None)
|
||||||
kwargs.pop("keyword_extraction", None)
|
kwargs.pop("keyword_extraction", None)
|
||||||
timeout = kwargs.pop("timeout", None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
|
||||||
openai_async_client = AsyncAzureOpenAI(
|
openai_async_client = AsyncAzureOpenAI(
|
||||||
azure_endpoint=base_url,
|
azure_endpoint=base_url,
|
||||||
azure_deployment=deployment,
|
azure_deployment=deployment,
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ async def lollms_model_if_cache(
|
||||||
"personality": kwargs.get("personality", -1),
|
"personality": kwargs.get("personality", -1),
|
||||||
"n_predict": kwargs.get("n_predict", None),
|
"n_predict": kwargs.get("n_predict", None),
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"temperature": kwargs.get("temperature", 0.8),
|
"temperature": kwargs.get("temperature", 1.0),
|
||||||
"top_k": kwargs.get("top_k", 50),
|
"top_k": kwargs.get("top_k", 50),
|
||||||
"top_p": kwargs.get("top_p", 0.95),
|
"top_p": kwargs.get("top_p", 0.95),
|
||||||
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
|
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
|
||||||
|
|
|
||||||
|
|
@ -158,10 +158,11 @@ async def openai_complete_if_cache(
|
||||||
|
|
||||||
# Create the OpenAI client
|
# Create the OpenAI client
|
||||||
openai_async_client = create_openai_async_client(
|
openai_async_client = create_openai_async_client(
|
||||||
api_key=api_key, base_url=base_url, client_configs=client_configs,
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
client_configs=client_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Prepare messages
|
# Prepare messages
|
||||||
messages: list[dict[str, Any]] = []
|
messages: list[dict[str, Any]] = []
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue