Fix(Ollama option): change stop option from string to list and add fallback global temperature setting

This commit is contained in:
yangdx 2025-08-04 19:43:14 +08:00
parent f8a880ac66
commit e5e3f0f878
2 changed files with 147 additions and 36 deletions

View file

@ -121,11 +121,13 @@ LLM_BINDING_HOST=https://api.openai.com/v1
LLM_BINDING_API_KEY=your_api_key
### Most Commont Parameters for Ollama Server
### see also env.ollama-binding-options.example for fine tuning ollama
### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
# OLLAMA_LLM_NUM_CTX=32768
### Time out in seconds, None for infinite timeout
TIMEOUT=240
### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
OLLAMA_LLM_NUM_CTX=32768
### Stop sequences for Ollama LLM
# OLLAMA_LLM_STOP='["</s>", "Assistant:", "\n\n"]'
### see also env.ollama-binding-options.example for fine tuning ollama
### Optional for Azure
# AZURE_OPENAI_API_VERSION=2024-08-01-preview

View file

@ -7,10 +7,13 @@ bindings and integrations.
from argparse import ArgumentParser, Namespace
import argparse
from dataclasses import asdict, dataclass
from typing import Any, ClassVar
import json
import os
from dataclasses import asdict, dataclass, field, MISSING
from typing import Any, ClassVar, List
from lightrag.utils import get_env_value
from lightrag.utils import get_env_value, logger
from lightrag.constants import DEFAULT_TEMPERATURE
# =============================================================================
@ -96,34 +99,101 @@ class BindingOptions:
def add_args(cls, parser: ArgumentParser):
group = parser.add_argument_group(f"{cls._binding_name} binding options")
for arg_item in cls.args_env_name_type_value():
group.add_argument(
f"--{arg_item['argname']}",
type=arg_item["type"],
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
help=arg_item["help"],
)
# Handle JSON parsing for list types
if arg_item["type"] == List[str]:
def json_list_parser(value):
try:
parsed = json.loads(value)
if not isinstance(parsed, list):
raise argparse.ArgumentTypeError(
f"Expected JSON array, got {type(parsed).__name__}"
)
return parsed
except json.JSONDecodeError as e:
raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
# Get environment variable with JSON parsing
env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
if env_value is not argparse.SUPPRESS:
try:
env_value = json_list_parser(env_value)
except argparse.ArgumentTypeError:
env_value = argparse.SUPPRESS
group.add_argument(
f"--{arg_item['argname']}",
type=json_list_parser,
default=env_value,
help=arg_item["help"],
)
else:
group.add_argument(
f"--{arg_item['argname']}",
type=arg_item["type"],
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
help=arg_item["help"],
)
@classmethod
def args_env_name_type_value(cls):
import dataclasses
args_prefix = f"{cls._binding_name}".replace("_", "-")
env_var_prefix = f"{cls._binding_name}_".upper()
class_vars = {
key: value
for key, value in cls._all_class_vars(cls).items()
if not callable(value) and not key.startswith("_")
}
help = cls._help
for class_var in class_vars:
argdef = {
"argname": f"{args_prefix}-{class_var}",
"env_name": f"{env_var_prefix}{class_var.upper()}",
"type": type(class_vars[class_var]),
"default": class_vars[class_var],
"help": f"{cls._binding_name} -- " + help.get(class_var, ""),
# Check if this is a dataclass and use dataclass fields
if dataclasses.is_dataclass(cls):
for field in dataclasses.fields(cls):
# Skip private fields
if field.name.startswith("_"):
continue
# Get default value
if field.default is not dataclasses.MISSING:
default_value = field.default
elif field.default_factory is not dataclasses.MISSING:
default_value = field.default_factory()
else:
default_value = None
argdef = {
"argname": f"{args_prefix}-{field.name}",
"env_name": f"{env_var_prefix}{field.name.upper()}",
"type": field.type,
"default": default_value,
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
}
yield argdef
else:
# Fallback to old method for non-dataclass classes
class_vars = {
key: value
for key, value in cls._all_class_vars(cls).items()
if not callable(value) and not key.startswith("_")
}
yield argdef
# Get type hints to properly detect List[str] types
type_hints = {}
for base in cls.__mro__:
if hasattr(base, "__annotations__"):
type_hints.update(base.__annotations__)
for class_var in class_vars:
# Use type hint if available, otherwise fall back to type of value
var_type = type_hints.get(class_var, type(class_vars[class_var]))
argdef = {
"argname": f"{args_prefix}-{class_var}",
"env_name": f"{env_var_prefix}{class_var.upper()}",
"type": var_type,
"default": class_vars[class_var],
"help": f"{cls._binding_name} -- " + help.get(class_var, ""),
}
yield argdef
@classmethod
def generate_dot_env_sample(cls):
@ -164,9 +234,14 @@ class BindingOptions:
for arg_item in klass.args_env_name_type_value():
if arg_item["help"]:
sample_stream.write(f"# {arg_item['help']}\n")
sample_stream.write(
f"# {arg_item['env_name']}={arg_item['default']}\n\n"
)
# Handle JSON formatting for list types
if arg_item["type"] == List[str]:
default_value = json.dumps(arg_item["default"])
else:
default_value = arg_item["default"]
sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")
sample_stream.write(sample_bottom)
return sample_stream.getvalue()
@ -256,7 +331,7 @@ class _OllamaOptionsMixin:
seed: int = -1 # Random seed for generation (-1 for random)
# Sampling parameters
temperature: float = 0.8 # Controls randomness (0.0-2.0)
temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0-2.0)
top_k: int = 40 # Top-k sampling parameter
top_p: float = 0.9 # Top-p (nucleus) sampling parameter
tfs_z: float = 1.0 # Tail free sampling parameter
@ -295,7 +370,7 @@ class _OllamaOptionsMixin:
# Output control
penalize_newline: bool = True # Penalize newline tokens
stop: str = "" # Stop sequences (comma-separated)
stop: List[str] = field(default_factory=list) # Stop sequences
# optional help strings
_help: ClassVar[dict[str, str]] = {
@ -329,7 +404,7 @@ class _OllamaOptionsMixin:
"use_mlock": "Lock model in memory",
"embedding_only": "Only use for embeddings",
"penalize_newline": "Penalize newline tokens",
"stop": "Stop sequences (comma-separated string)",
"stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
}
@ -362,9 +437,44 @@ class OllamaEmbeddingOptions(_OllamaOptionsMixin, BindingOptions):
class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
"""Options for Ollama LLM with specialized configuration for LLM tasks."""
# Override temperature field to track if it was explicitly set
temperature: float = field(default_factory=lambda: MISSING)
# mandatory name of binding
_binding_name: ClassVar[str] = "ollama_llm"
def __post_init__(self):
"""Handle temperature parameter with correct priority logic"""
# If temperature was not explicitly set, apply priority logic
if self.temperature is MISSING:
# Check OLLAMA_LLM_TEMPERATURE first (highest priority for env vars)
ollama_temp = os.getenv("OLLAMA_LLM_TEMPERATURE")
if ollama_temp is not None:
try:
self.temperature = float(ollama_temp)
logger.debug(f"Using OLLAMA_LLM_TEMPERATURE: {self.temperature}")
return
except (ValueError, TypeError):
logger.warning(
f"Invalid OLLAMA_LLM_TEMPERATURE value: {ollama_temp}"
)
# Check TEMPERATURE as fallback
general_temp = os.getenv("TEMPERATURE")
if general_temp is not None:
try:
self.temperature = float(general_temp)
logger.debug(
f"Using TEMPERATURE environment variable: {self.temperature}"
)
return
except (ValueError, TypeError):
logger.warning(f"Invalid TEMPERATURE value: {general_temp}")
# Use default value
self.temperature = DEFAULT_TEMPERATURE
logger.debug(f"Using default temperature: {self.temperature}")
# =============================================================================
# Additional LLM Provider Bindings
@ -423,15 +533,14 @@ if __name__ == "__main__":
# from io import StringIO
dotenv.load_dotenv(dotenv_path=".env", override=False)
# env_strstream = StringIO(
# ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n")
# )
# # Load environment variables from .env file
# dotenv.load_dotenv(stream=env_strstream)
if len(sys.argv) > 1 and sys.argv[1] == "test":
# Add arguments for OllamaEmbeddingOptions and OllamaLLMOptions
parser = ArgumentParser(description="Test Ollama binding")
OllamaEmbeddingOptions.add_args(parser)
@ -444,8 +553,8 @@ if __name__ == "__main__":
"1024",
"--ollama-llm-num_ctx",
"2048",
"--ollama-llm-stop",
"<end>",
# "--ollama-llm-stop",
# '["</s>", "\\n\\n"]',
]
)
print("Final args for LLM and Embedding:")