Fix(Ollama option): change stop option from string to list and add fallback global temperature setting

This commit is contained in:
yangdx 2025-08-04 19:43:14 +08:00
parent f8a880ac66
commit e5e3f0f878
2 changed files with 147 additions and 36 deletions

View file

@ -121,11 +121,13 @@ LLM_BINDING_HOST=https://api.openai.com/v1
LLM_BINDING_API_KEY=your_api_key LLM_BINDING_API_KEY=your_api_key
### Most Common Parameters for Ollama Server ### Most Common Parameters for Ollama Server
### see also env.ollama-binding-options.example for fine tuning ollama
### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
# OLLAMA_LLM_NUM_CTX=32768
### Time out in seconds, None for infinite timeout ### Time out in seconds, None for infinite timeout
TIMEOUT=240 TIMEOUT=240
### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
OLLAMA_LLM_NUM_CTX=32768
### Stop sequences for Ollama LLM
# OLLAMA_LLM_STOP='["</s>", "Assistant:", "\n\n"]'
### see also env.ollama-binding-options.example for fine tuning ollama
### Optional for Azure ### Optional for Azure
# AZURE_OPENAI_API_VERSION=2024-08-01-preview # AZURE_OPENAI_API_VERSION=2024-08-01-preview

View file

@ -7,10 +7,13 @@ bindings and integrations.
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
import argparse import argparse
from dataclasses import asdict, dataclass import json
from typing import Any, ClassVar import os
from dataclasses import asdict, dataclass, field, MISSING
from typing import Any, ClassVar, List
from lightrag.utils import get_env_value from lightrag.utils import get_env_value, logger
from lightrag.constants import DEFAULT_TEMPERATURE
# ============================================================================= # =============================================================================
@ -96,34 +99,101 @@ class BindingOptions:
def add_args(cls, parser: ArgumentParser):
    """Register one command-line argument per option of this binding.

    All arguments are added to an argument group named
    ``"<binding name> binding options"``. For every definition yielded by
    ``cls.args_env_name_type_value()`` the default is read from the matching
    environment variable via ``get_env_value``; when that yields
    ``argparse.SUPPRESS`` the attribute is simply omitted from the parsed
    namespace.

    Options typed ``List[str]`` are given as a JSON array of strings, both
    on the command line and in the environment variable.

    Args:
        parser: The argument parser that receives the new argument group.
    """

    def json_list_parser(value):
        """Parse *value* as a JSON array of strings.

        Raises:
            argparse.ArgumentTypeError: If *value* is not valid JSON, is not
                an array, or contains a non-string element.
        """
        try:
            parsed = json.loads(value)
        except json.JSONDecodeError as e:
            raise argparse.ArgumentTypeError(f"Invalid JSON: {e}") from e
        if not isinstance(parsed, list):
            raise argparse.ArgumentTypeError(
                f"Expected JSON array, got {type(parsed).__name__}"
            )
        if not all(isinstance(item, str) for item in parsed):
            raise argparse.ArgumentTypeError(
                "Expected JSON array of strings, got a non-string element"
            )
        return parsed

    group = parser.add_argument_group(f"{cls._binding_name} binding options")
    for arg_item in cls.args_env_name_type_value():
        arg_type = arg_item["type"]
        env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)

        if arg_type == List[str]:
            # argparse cannot call a typing generic as a converter, so list
            # options are parsed from JSON instead.
            arg_type = json_list_parser
            if env_value is not argparse.SUPPRESS:
                try:
                    env_value = json_list_parser(env_value)
                except argparse.ArgumentTypeError as e:
                    # Best effort: a malformed environment value must not
                    # abort start-up, but it should not vanish silently.
                    logger.warning(
                        f"Ignoring invalid {arg_item['env_name']} value: {e}"
                    )
                    env_value = argparse.SUPPRESS

        group.add_argument(
            f"--{arg_item['argname']}",
            type=arg_type,
            default=env_value,
            help=arg_item["help"],
        )
@classmethod
def args_env_name_type_value(cls):
    """Yield one argument definition per public option of this binding.

    Each definition is a dict with the keys:

    * ``argname``  -- ``<binding-name-with-dashes>-<option>``
    * ``env_name`` -- ``<BINDING_NAME>_<OPTION>`` in upper case
    * ``type``     -- the option's annotated type (or, for a non-dataclass
      without an annotation, the type of its class-level value)
    * ``default``  -- the option's default value
    * ``help``     -- ``"<binding name> -- "`` plus the text from ``cls._help``

    Dataclasses are described through ``dataclasses.fields``; any other
    class falls back to ``cls._all_class_vars(cls)``.
    """
    import dataclasses

    args_prefix = f"{cls._binding_name}".replace("_", "-")
    env_var_prefix = f"{cls._binding_name}_".upper()
    help_texts = cls._help

    def make_argdef(name, var_type, default_value):
        """Assemble the definition dict for a single option."""
        return {
            "argname": f"{args_prefix}-{name}",
            "env_name": f"{env_var_prefix}{name.upper()}",
            "type": var_type,
            "default": default_value,
            "help": f"{cls._binding_name} -- " + help_texts.get(name, ""),
        }

    if dataclasses.is_dataclass(cls):
        for spec in dataclasses.fields(cls):
            # Skip private fields
            if spec.name.startswith("_"):
                continue

            if spec.default is not dataclasses.MISSING:
                default_value = spec.default
            elif spec.default_factory is not dataclasses.MISSING:
                default_value = spec.default_factory()
                if default_value is dataclasses.MISSING:
                    # The factory returned the MISSING sentinel, i.e. the
                    # real value is resolved later on the instance. Do not
                    # leak the sentinel object into help text or .env
                    # samples: use a class-level default inherited from a
                    # base class if one exists, otherwise None.
                    inherited = getattr(cls, spec.name, None)
                    if inherited is dataclasses.MISSING or isinstance(
                        inherited, dataclasses.Field
                    ):
                        inherited = None
                    default_value = inherited
            else:
                default_value = None

            yield make_argdef(spec.name, spec.type, default_value)
        return

    # Fallback to the class-variable scan for non-dataclass classes
    class_vars = {
        key: value
        for key, value in cls._all_class_vars(cls).items()
        if not callable(value) and not key.startswith("_")
    }

    # Collect type hints from the least to the most derived class so that a
    # subclass annotation overrides the one it inherits.
    type_hints = {}
    for base in reversed(cls.__mro__):
        type_hints.update(getattr(base, "__annotations__", {}))

    for name, value in class_vars.items():
        # Use the type hint if available, otherwise the type of the value
        yield make_argdef(name, type_hints.get(name, type(value)), value)
@classmethod @classmethod
def generate_dot_env_sample(cls): def generate_dot_env_sample(cls):
@ -164,9 +234,14 @@ class BindingOptions:
for arg_item in klass.args_env_name_type_value(): for arg_item in klass.args_env_name_type_value():
if arg_item["help"]: if arg_item["help"]:
sample_stream.write(f"# {arg_item['help']}\n") sample_stream.write(f"# {arg_item['help']}\n")
sample_stream.write(
f"# {arg_item['env_name']}={arg_item['default']}\n\n" # Handle JSON formatting for list types
) if arg_item["type"] == List[str]:
default_value = json.dumps(arg_item["default"])
else:
default_value = arg_item["default"]
sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")
sample_stream.write(sample_bottom) sample_stream.write(sample_bottom)
return sample_stream.getvalue() return sample_stream.getvalue()
@ -256,7 +331,7 @@ class _OllamaOptionsMixin:
seed: int = -1 # Random seed for generation (-1 for random) seed: int = -1 # Random seed for generation (-1 for random)
# Sampling parameters # Sampling parameters
temperature: float = 0.8 # Controls randomness (0.0-2.0) temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0-2.0)
top_k: int = 40 # Top-k sampling parameter top_k: int = 40 # Top-k sampling parameter
top_p: float = 0.9 # Top-p (nucleus) sampling parameter top_p: float = 0.9 # Top-p (nucleus) sampling parameter
tfs_z: float = 1.0 # Tail free sampling parameter tfs_z: float = 1.0 # Tail free sampling parameter
@ -295,7 +370,7 @@ class _OllamaOptionsMixin:
# Output control # Output control
penalize_newline: bool = True # Penalize newline tokens penalize_newline: bool = True # Penalize newline tokens
stop: str = "" # Stop sequences (comma-separated) stop: List[str] = field(default_factory=list) # Stop sequences
# optional help strings # optional help strings
_help: ClassVar[dict[str, str]] = { _help: ClassVar[dict[str, str]] = {
@ -329,7 +404,7 @@ class _OllamaOptionsMixin:
"use_mlock": "Lock model in memory", "use_mlock": "Lock model in memory",
"embedding_only": "Only use for embeddings", "embedding_only": "Only use for embeddings",
"penalize_newline": "Penalize newline tokens", "penalize_newline": "Penalize newline tokens",
"stop": "Stop sequences (comma-separated string)", "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
} }
@ -362,9 +437,44 @@ class OllamaEmbeddingOptions(_OllamaOptionsMixin, BindingOptions):
class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
    """Options for Ollama LLM with specialized configuration for LLM tasks."""

    # The MISSING sentinel marks "temperature was not passed to the
    # constructor"; __post_init__ replaces it with a real value.
    temperature: float = field(default_factory=lambda: MISSING)

    # mandatory name of binding
    _binding_name: ClassVar[str] = "ollama_llm"

    def __post_init__(self):
        """Resolve ``temperature`` when the caller did not supply one.

        Sources are tried in this order:

        1. the ``OLLAMA_LLM_TEMPERATURE`` environment variable,
        2. the ``TEMPERATURE`` environment variable,
        3. ``DEFAULT_TEMPERATURE``.

        An environment value that cannot be converted to ``float`` is
        reported with a warning and the next source is tried.
        """
        # An explicitly supplied temperature is left untouched.
        if self.temperature is not MISSING:
            return

        # Binding-specific variable takes precedence over the global one.
        ollama_temp = os.getenv("OLLAMA_LLM_TEMPERATURE")
        if ollama_temp is not None:
            try:
                self.temperature = float(ollama_temp)
            except (ValueError, TypeError):
                logger.warning(
                    f"Invalid OLLAMA_LLM_TEMPERATURE value: {ollama_temp}"
                )
            else:
                logger.debug(f"Using OLLAMA_LLM_TEMPERATURE: {self.temperature}")
                return

        # Global temperature setting as fallback.
        general_temp = os.getenv("TEMPERATURE")
        if general_temp is not None:
            try:
                self.temperature = float(general_temp)
            except (ValueError, TypeError):
                logger.warning(f"Invalid TEMPERATURE value: {general_temp}")
            else:
                logger.debug(
                    f"Using TEMPERATURE environment variable: {self.temperature}"
                )
                return

        # Neither variable produced a usable value.
        self.temperature = DEFAULT_TEMPERATURE
        logger.debug(f"Using default temperature: {self.temperature}")
# ============================================================================= # =============================================================================
# Additional LLM Provider Bindings # Additional LLM Provider Bindings
@ -423,15 +533,14 @@ if __name__ == "__main__":
# from io import StringIO # from io import StringIO
dotenv.load_dotenv(dotenv_path=".env", override=False) dotenv.load_dotenv(dotenv_path=".env", override=False)
# env_strstream = StringIO( # env_strstream = StringIO(
# ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n") # ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n")
# ) # )
# # Load environment variables from .env file # # Load environment variables from .env file
# dotenv.load_dotenv(stream=env_strstream) # dotenv.load_dotenv(stream=env_strstream)
if len(sys.argv) > 1 and sys.argv[1] == "test": if len(sys.argv) > 1 and sys.argv[1] == "test":
# Add arguments for OllamaEmbeddingOptions and OllamaLLMOptions # Add arguments for OllamaEmbeddingOptions and OllamaLLMOptions
parser = ArgumentParser(description="Test Ollama binding") parser = ArgumentParser(description="Test Ollama binding")
OllamaEmbeddingOptions.add_args(parser) OllamaEmbeddingOptions.add_args(parser)
@ -444,8 +553,8 @@ if __name__ == "__main__":
"1024", "1024",
"--ollama-llm-num_ctx", "--ollama-llm-num_ctx",
"2048", "2048",
"--ollama-llm-stop", # "--ollama-llm-stop",
"<end>", # '["</s>", "\\n\\n"]',
] ]
) )
print("Final args for LLM and Embedding:") print("Final args for LLM and Embedding:")