From e5e3f0f8783387d3ee24173baedca9ae3f11b134 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 4 Aug 2025 19:43:14 +0800
Subject: [PATCH] Fix(Ollama option): change stop option from string to list
 and add fallback global temperature setting

---
 env.example                     |   8 +-
 lightrag/llm/binding_options.py | 175 ++++++++++++++++++++++++++------
 2 files changed, 147 insertions(+), 36 deletions(-)

diff --git a/env.example b/env.example
index 04819679..b926bc38 100644
--- a/env.example
+++ b/env.example
@@ -121,11 +121,13 @@ LLM_BINDING_HOST=https://api.openai.com/v1
 LLM_BINDING_API_KEY=your_api_key
 
 ### Most Commont Parameters for Ollama Server
-### see also env.ollama-binding-options.example for fine tuning ollama
-### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
-# OLLAMA_LLM_NUM_CTX=32768
 ### Time out in seconds, None for infinite timeout
 TIMEOUT=240
+### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
+OLLAMA_LLM_NUM_CTX=32768
+### Stop sequences for Ollama LLM
+# OLLAMA_LLM_STOP='["</s>", "Assistant:", "\n\n"]'
+### see also env.ollama-binding-options.example for fine tuning ollama
 
 ### Optional for Azure
 # AZURE_OPENAI_API_VERSION=2024-08-01-preview

diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py
index 7791ff56..5bde88a1 100644
--- a/lightrag/llm/binding_options.py
+++ b/lightrag/llm/binding_options.py
@@ -7,10 +7,13 @@ bindings and integrations.
 
 from argparse import ArgumentParser, Namespace
 import argparse
-from dataclasses import asdict, dataclass
-from typing import Any, ClassVar
+import json
+import os
+from dataclasses import asdict, dataclass, field, MISSING
+from typing import Any, ClassVar, List
 
-from lightrag.utils import get_env_value
+from lightrag.utils import get_env_value, logger
+from lightrag.constants import DEFAULT_TEMPERATURE
 
 
 # =============================================================================
@@ -96,34 +99,101 @@ class BindingOptions:
     def add_args(cls, parser: ArgumentParser):
         group = parser.add_argument_group(f"{cls._binding_name} binding options")
         for arg_item in cls.args_env_name_type_value():
-            group.add_argument(
-                f"--{arg_item['argname']}",
-                type=arg_item["type"],
-                default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
-                help=arg_item["help"],
-            )
+            # Handle JSON parsing for list types
+            if arg_item["type"] == List[str]:
+
+                def json_list_parser(value):
+                    try:
+                        parsed = json.loads(value)
+                        if not isinstance(parsed, list):
+                            raise argparse.ArgumentTypeError(
+                                f"Expected JSON array, got {type(parsed).__name__}"
+                            )
+                        return parsed
+                    except json.JSONDecodeError as e:
+                        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
+
+                # Get environment variable with JSON parsing
+                env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
+                if env_value is not argparse.SUPPRESS:
+                    try:
+                        env_value = json_list_parser(env_value)
+                    except argparse.ArgumentTypeError:
+                        env_value = argparse.SUPPRESS
+
+                group.add_argument(
+                    f"--{arg_item['argname']}",
+                    type=json_list_parser,
+                    default=env_value,
+                    help=arg_item["help"],
+                )
+            else:
+                group.add_argument(
+                    f"--{arg_item['argname']}",
+                    type=arg_item["type"],
+                    default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
+                    help=arg_item["help"],
+                )
 
     @classmethod
     def args_env_name_type_value(cls):
+        import dataclasses
+
         args_prefix = f"{cls._binding_name}".replace("_", "-")
         env_var_prefix = f"{cls._binding_name}_".upper()
-        class_vars = {
-            key: value
-            for key, value in cls._all_class_vars(cls).items()
-            if not callable(value) and not key.startswith("_")
-        }
         help = cls._help
-        for class_var in class_vars:
-            argdef = {
-                "argname": f"{args_prefix}-{class_var}",
-                "env_name": f"{env_var_prefix}{class_var.upper()}",
-                "type": type(class_vars[class_var]),
-                "default": class_vars[class_var],
-                "help": f"{cls._binding_name} -- " + help.get(class_var, ""),
+        # Check if this is a dataclass and use dataclass fields
+        if dataclasses.is_dataclass(cls):
+            for field in dataclasses.fields(cls):
+                # Skip private fields
+                if field.name.startswith("_"):
+                    continue
+
+                # Get default value
+                if field.default is not dataclasses.MISSING:
+                    default_value = field.default
+                elif field.default_factory is not dataclasses.MISSING:
+                    default_value = field.default_factory()
+                else:
+                    default_value = None
+
+                argdef = {
+                    "argname": f"{args_prefix}-{field.name}",
+                    "env_name": f"{env_var_prefix}{field.name.upper()}",
+                    "type": field.type,
+                    "default": default_value,
+                    "help": f"{cls._binding_name} -- " + help.get(field.name, ""),
+                }
+
+                yield argdef
+        else:
+            # Fallback to old method for non-dataclass classes
+            class_vars = {
+                key: value
+                for key, value in cls._all_class_vars(cls).items()
+                if not callable(value) and not key.startswith("_")
             }
-            yield argdef
+            # Get type hints to properly detect List[str] types
+            type_hints = {}
+            for base in cls.__mro__:
+                if hasattr(base, "__annotations__"):
+                    type_hints.update(base.__annotations__)
+
+            for class_var in class_vars:
+                # Use type hint if available, otherwise fall back to type of value
+                var_type = type_hints.get(class_var, type(class_vars[class_var]))
+
+                argdef = {
+                    "argname": f"{args_prefix}-{class_var}",
+                    "env_name": f"{env_var_prefix}{class_var.upper()}",
+                    "type": var_type,
+                    "default": class_vars[class_var],
+                    "help": f"{cls._binding_name} -- " + help.get(class_var, ""),
+                }
+
+                yield argdef
 
     @classmethod
     def generate_dot_env_sample(cls):
@@ -164,9 +234,14 @@ class BindingOptions:
         for arg_item in klass.args_env_name_type_value():
             if arg_item["help"]:
                 sample_stream.write(f"# {arg_item['help']}\n")
-            sample_stream.write(
-                f"# {arg_item['env_name']}={arg_item['default']}\n\n"
-            )
+
+            # Handle JSON formatting for list types
+            if arg_item["type"] == List[str]:
+                default_value = json.dumps(arg_item["default"])
+            else:
+                default_value = arg_item["default"]
+
+            sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")
 
         sample_stream.write(sample_bottom)
         return sample_stream.getvalue()
@@ -256,7 +331,7 @@ class _OllamaOptionsMixin:
     seed: int = -1  # Random seed for generation (-1 for random)
 
     # Sampling parameters
-    temperature: float = 0.8  # Controls randomness (0.0-2.0)
+    temperature: float = DEFAULT_TEMPERATURE  # Controls randomness (0.0-2.0)
     top_k: int = 40  # Top-k sampling parameter
     top_p: float = 0.9  # Top-p (nucleus) sampling parameter
     tfs_z: float = 1.0  # Tail free sampling parameter
@@ -295,7 +370,7 @@ class _OllamaOptionsMixin:
 
     # Output control
     penalize_newline: bool = True  # Penalize newline tokens
-    stop: str = ""  # Stop sequences (comma-separated)
+    stop: List[str] = field(default_factory=list)  # Stop sequences
 
     # optional help strings
     _help: ClassVar[dict[str, str]] = {
@@ -329,7 +404,7 @@ class _OllamaOptionsMixin:
         "use_mlock": "Lock model in memory",
         "embedding_only": "Only use for embeddings",
         "penalize_newline": "Penalize newline tokens",
-        "stop": "Stop sequences (comma-separated string)",
+        "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
     }
 
 
@@ -362,9 +437,44 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
     """Options for Ollama LLM with specialized configuration for LLM tasks."""
 
+    # Override temperature field to track if it was explicitly set
+    temperature: float = field(default_factory=lambda: MISSING)
+
     # mandatory name of binding
     _binding_name: ClassVar[str] = "ollama_llm"
 
+    def __post_init__(self):
+        """Handle temperature parameter with correct priority logic"""
+        # If temperature was not explicitly set, apply priority logic
+        if self.temperature is MISSING:
+            # Check OLLAMA_LLM_TEMPERATURE first (highest priority for env vars)
+            ollama_temp = os.getenv("OLLAMA_LLM_TEMPERATURE")
+            if ollama_temp is not None:
+                try:
+                    self.temperature = float(ollama_temp)
+                    logger.debug(f"Using OLLAMA_LLM_TEMPERATURE: {self.temperature}")
+                    return
+                except (ValueError, TypeError):
+                    logger.warning(
+                        f"Invalid OLLAMA_LLM_TEMPERATURE value: {ollama_temp}"
+                    )
+
+            # Check TEMPERATURE as fallback
+            general_temp = os.getenv("TEMPERATURE")
+            if general_temp is not None:
+                try:
+                    self.temperature = float(general_temp)
+                    logger.debug(
+                        f"Using TEMPERATURE environment variable: {self.temperature}"
+                    )
+                    return
+                except (ValueError, TypeError):
+                    logger.warning(f"Invalid TEMPERATURE value: {general_temp}")
+
+            # Use default value
+            self.temperature = DEFAULT_TEMPERATURE
+            logger.debug(f"Using default temperature: {self.temperature}")
+
 
 # =============================================================================
 # Additional LLM Provider Bindings
@@ -423,15 +533,14 @@ if __name__ == "__main__":
     # from io import StringIO
 
     dotenv.load_dotenv(dotenv_path=".env", override=False)
-    
+
     # env_strstream = StringIO(
     #     ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n")
     # )
     # # Load environment variables from .env file
     # dotenv.load_dotenv(stream=env_strstream)
-    
+
     if len(sys.argv) > 1 and sys.argv[1] == "test":
-
         # Add arguments for OllamaEmbeddingOptions and OllamaLLMOptions
         parser = ArgumentParser(description="Test Ollama binding")
         OllamaEmbeddingOptions.add_args(parser)
@@ -444,8 +553,8 @@ if __name__ == "__main__":
                 "1024",
                 "--ollama-llm-num_ctx",
                 "2048",
-                "--ollama-llm-stop",
-                "",
+                # "--ollama-llm-stop",
+                # '["</s>", "\\n\\n"]',
             ]
         )
         print("Final args for LLM and Embedding:")
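
Below is a minimal, self-contained sketch of the configuration behavior this patch introduces: OLLAMA_LLM_STOP is read as a JSON array of strings, and the LLM temperature falls back from OLLAMA_LLM_TEMPERATURE to the global TEMPERATURE and finally to the library default. The sketch deliberately avoids importing lightrag; ASSUMED_DEFAULT_TEMPERATURE and the two resolve_* helpers are illustrative stand-ins, not names from the patch.

# Illustrative sketch only -- mirrors the env-var handling added above without
# importing lightrag. ASSUMED_DEFAULT_TEMPERATURE stands in for
# lightrag.constants.DEFAULT_TEMPERATURE, whose real value lives in that module.
import json
import os
from typing import List

ASSUMED_DEFAULT_TEMPERATURE = 1.0  # stand-in value, not the real constant


def resolve_stop_sequences() -> List[str]:
    # OLLAMA_LLM_STOP is now expected to be a JSON array of strings,
    # e.g. OLLAMA_LLM_STOP='["</s>", "Assistant:", "\n\n"]'
    raw = os.getenv("OLLAMA_LLM_STOP")
    if raw is None:
        return []
    parsed = json.loads(raw)
    if not isinstance(parsed, list):
        raise ValueError(f"Expected JSON array, got {type(parsed).__name__}")
    return parsed


def resolve_temperature() -> float:
    # Priority mirrors OllamaLLMOptions.__post_init__:
    # OLLAMA_LLM_TEMPERATURE, then TEMPERATURE, then the default.
    for var in ("OLLAMA_LLM_TEMPERATURE", "TEMPERATURE"):
        value = os.getenv(var)
        if value is None:
            continue
        try:
            return float(value)
        except ValueError:
            continue  # invalid values fall through to the next source
    return ASSUMED_DEFAULT_TEMPERATURE


if __name__ == "__main__":
    os.environ["OLLAMA_LLM_STOP"] = '["</s>", "Assistant:"]'
    os.environ["TEMPERATURE"] = "0.5"
    print(resolve_stop_sequences())  # ['</s>', 'Assistant:']
    print(resolve_temperature())     # 0.5 (TEMPERATURE used; OLLAMA_LLM_TEMPERATURE unset)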