Fix(Ollama option): change stop option from string to list and add fallback global temperature setting
parent f8a880ac66
commit e5e3f0f878
2 changed files with 147 additions and 36 deletions
@@ -121,11 +121,13 @@ LLM_BINDING_HOST=https://api.openai.com/v1
LLM_BINDING_API_KEY=your_api_key

### Most Common Parameters for Ollama Server
### see also env.ollama-binding-options.example for fine tuning ollama
### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
# OLLAMA_LLM_NUM_CTX=32768
### Time out in seconds, None for infinite timeout
TIMEOUT=240
### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
OLLAMA_LLM_NUM_CTX=32768
### Stop sequences for Ollama LLM
# OLLAMA_LLM_STOP='["</s>", "Assistant:", "\n\n"]'
### see also env.ollama-binding-options.example for fine tuning ollama

### Optional for Azure
# AZURE_OPENAI_API_VERSION=2024-08-01-preview
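Note on the new variable: OLLAMA_LLM_STOP is now expected to hold a JSON array rather than a comma-separated string. A minimal sketch of how such a value round-trips through the environment (the example value is illustrative; the actual parsing lives in BindingOptions.add_args below):

import json
import os

# Hypothetical value, matching the commented sample above
os.environ["OLLAMA_LLM_STOP"] = '["</s>", "Assistant:", "\\n\\n"]'

stop_sequences = json.loads(os.environ["OLLAMA_LLM_STOP"])
print(stop_sequences)  # ['</s>', 'Assistant:', '\n\n']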
@@ -7,10 +7,13 @@ bindings and integrations.

from argparse import ArgumentParser, Namespace
import argparse
from dataclasses import asdict, dataclass
from typing import Any, ClassVar
import json
import os
from dataclasses import asdict, dataclass, field, MISSING
from typing import Any, ClassVar, List

from lightrag.utils import get_env_value
from lightrag.utils import get_env_value, logger
from lightrag.constants import DEFAULT_TEMPERATURE


# =============================================================================
@@ -96,34 +99,101 @@ class BindingOptions:
    def add_args(cls, parser: ArgumentParser):
        group = parser.add_argument_group(f"{cls._binding_name} binding options")
        for arg_item in cls.args_env_name_type_value():
            group.add_argument(
                f"--{arg_item['argname']}",
                type=arg_item["type"],
                default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
                help=arg_item["help"],
            )
            # Handle JSON parsing for list types
            if arg_item["type"] == List[str]:

                def json_list_parser(value):
                    try:
                        parsed = json.loads(value)
                        if not isinstance(parsed, list):
                            raise argparse.ArgumentTypeError(
                                f"Expected JSON array, got {type(parsed).__name__}"
                            )
                        return parsed
                    except json.JSONDecodeError as e:
                        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")

                # Get environment variable with JSON parsing
                env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
                if env_value is not argparse.SUPPRESS:
                    try:
                        env_value = json_list_parser(env_value)
                    except argparse.ArgumentTypeError:
                        env_value = argparse.SUPPRESS

                group.add_argument(
                    f"--{arg_item['argname']}",
                    type=json_list_parser,
                    default=env_value,
                    help=arg_item["help"],
                )
            else:
                group.add_argument(
                    f"--{arg_item['argname']}",
                    type=arg_item["type"],
                    default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
                    help=arg_item["help"],
                )

    @classmethod
    def args_env_name_type_value(cls):
        import dataclasses

        args_prefix = f"{cls._binding_name}".replace("_", "-")
        env_var_prefix = f"{cls._binding_name}_".upper()
        class_vars = {
            key: value
            for key, value in cls._all_class_vars(cls).items()
            if not callable(value) and not key.startswith("_")
        }
        help = cls._help

        for class_var in class_vars:
            argdef = {
                "argname": f"{args_prefix}-{class_var}",
                "env_name": f"{env_var_prefix}{class_var.upper()}",
                "type": type(class_vars[class_var]),
                "default": class_vars[class_var],
                "help": f"{cls._binding_name} -- " + help.get(class_var, ""),
        # Check if this is a dataclass and use dataclass fields
        if dataclasses.is_dataclass(cls):
            for field in dataclasses.fields(cls):
                # Skip private fields
                if field.name.startswith("_"):
                    continue

                # Get default value
                if field.default is not dataclasses.MISSING:
                    default_value = field.default
                elif field.default_factory is not dataclasses.MISSING:
                    default_value = field.default_factory()
                else:
                    default_value = None

                argdef = {
                    "argname": f"{args_prefix}-{field.name}",
                    "env_name": f"{env_var_prefix}{field.name.upper()}",
                    "type": field.type,
                    "default": default_value,
                    "help": f"{cls._binding_name} -- " + help.get(field.name, ""),
                }

                yield argdef
        else:
            # Fallback to old method for non-dataclass classes
            class_vars = {
                key: value
                for key, value in cls._all_class_vars(cls).items()
                if not callable(value) and not key.startswith("_")
            }

            yield argdef
            # Get type hints to properly detect List[str] types
            type_hints = {}
            for base in cls.__mro__:
                if hasattr(base, "__annotations__"):
                    type_hints.update(base.__annotations__)

            for class_var in class_vars:
                # Use type hint if available, otherwise fall back to type of value
                var_type = type_hints.get(class_var, type(class_vars[class_var]))

                argdef = {
                    "argname": f"{args_prefix}-{class_var}",
                    "env_name": f"{env_var_prefix}{class_var.upper()}",
                    "type": var_type,
                    "default": class_vars[class_var],
                    "help": f"{cls._binding_name} -- " + help.get(class_var, ""),
                }

                yield argdef

    @classmethod
    def generate_dot_env_sample(cls):
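The List[str] branch above accepts values from both the CLI and the environment as JSON arrays. A standalone sketch of that behavior, outside the BindingOptions machinery (option name and inputs are illustrative):

import argparse
import json


def json_list_parser(value):
    # Mirror of the parser added above: accept only a JSON array
    try:
        parsed = json.loads(value)
        if not isinstance(parsed, list):
            raise argparse.ArgumentTypeError(
                f"Expected JSON array, got {type(parsed).__name__}"
            )
        return parsed
    except json.JSONDecodeError as e:
        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")


parser = argparse.ArgumentParser()
parser.add_argument("--ollama-llm-stop", type=json_list_parser, default=[])

args = parser.parse_args(['--ollama-llm-stop', '["</s>", "\\n\\n"]'])
print(args.ollama_llm_stop)  # ['</s>', '\n\n']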
@@ -164,9 +234,14 @@ class BindingOptions:
        for arg_item in klass.args_env_name_type_value():
            if arg_item["help"]:
                sample_stream.write(f"# {arg_item['help']}\n")
            sample_stream.write(
                f"# {arg_item['env_name']}={arg_item['default']}\n\n"
            )

            # Handle JSON formatting for list types
            if arg_item["type"] == List[str]:
                default_value = json.dumps(arg_item["default"])
            else:
                default_value = arg_item["default"]

            sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")

        sample_stream.write(sample_bottom)
        return sample_stream.getvalue()
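Because list defaults are now emitted with json.dumps, the generated .env sample stays copy-paste compatible with the JSON parser above. A small sketch of the formatting step (the arg_item dict here is a hand-built stand-in for what args_env_name_type_value yields):

import json
from typing import List

arg_item = {"env_name": "OLLAMA_LLM_STOP", "type": List[str], "default": ["</s>"]}

if arg_item["type"] == List[str]:
    default_value = json.dumps(arg_item["default"])  # '["</s>"]' (valid JSON), not "['</s>']"
else:
    default_value = arg_item["default"]

print(f"# {arg_item['env_name']}={default_value}")  # -> # OLLAMA_LLM_STOP=["</s>"]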
@@ -256,7 +331,7 @@ class _OllamaOptionsMixin:
    seed: int = -1  # Random seed for generation (-1 for random)

    # Sampling parameters
    temperature: float = 0.8  # Controls randomness (0.0-2.0)
    temperature: float = DEFAULT_TEMPERATURE  # Controls randomness (0.0-2.0)
    top_k: int = 40  # Top-k sampling parameter
    top_p: float = 0.9  # Top-p (nucleus) sampling parameter
    tfs_z: float = 1.0  # Tail free sampling parameter
@@ -295,7 +370,7 @@ class _OllamaOptionsMixin:

    # Output control
    penalize_newline: bool = True  # Penalize newline tokens
    stop: str = ""  # Stop sequences (comma-separated)
    stop: List[str] = field(default_factory=list)  # Stop sequences

    # optional help strings
    _help: ClassVar[dict[str, str]] = {
@@ -329,7 +404,7 @@ class _OllamaOptionsMixin:
        "use_mlock": "Lock model in memory",
        "embedding_only": "Only use for embeddings",
        "penalize_newline": "Penalize newline tokens",
        "stop": "Stop sequences (comma-separated string)",
        "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
    }

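With stop now a real list field, the dataclass default is an empty list (via default_factory, so instances do not share one list) and asdict() produces a value that can be handed to the Ollama options dict unchanged. A sketch using a stand-in dataclass rather than the real OllamaLLMOptions:

from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class StopDemo:  # stand-in for the mixin field above
    stop: List[str] = field(default_factory=list)


opts = StopDemo(stop=["</s>", "\n\n"])
print(asdict(opts))     # {'stop': ['</s>', '\n\n']}
print(StopDemo().stop)  # [] -- each instance gets its own list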
@@ -362,9 +437,44 @@ class OllamaEmbeddingOptions(_OllamaOptionsMixin, BindingOptions):
class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
    """Options for Ollama LLM with specialized configuration for LLM tasks."""

    # Override temperature field to track if it was explicitly set
    temperature: float = field(default_factory=lambda: MISSING)

    # mandatory name of binding
    _binding_name: ClassVar[str] = "ollama_llm"

    def __post_init__(self):
        """Handle temperature parameter with correct priority logic"""
        # If temperature was not explicitly set, apply priority logic
        if self.temperature is MISSING:
            # Check OLLAMA_LLM_TEMPERATURE first (highest priority for env vars)
            ollama_temp = os.getenv("OLLAMA_LLM_TEMPERATURE")
            if ollama_temp is not None:
                try:
                    self.temperature = float(ollama_temp)
                    logger.debug(f"Using OLLAMA_LLM_TEMPERATURE: {self.temperature}")
                    return
                except (ValueError, TypeError):
                    logger.warning(
                        f"Invalid OLLAMA_LLM_TEMPERATURE value: {ollama_temp}"
                    )

            # Check TEMPERATURE as fallback
            general_temp = os.getenv("TEMPERATURE")
            if general_temp is not None:
                try:
                    self.temperature = float(general_temp)
                    logger.debug(
                        f"Using TEMPERATURE environment variable: {self.temperature}"
                    )
                    return
                except (ValueError, TypeError):
                    logger.warning(f"Invalid TEMPERATURE value: {general_temp}")

            # Use default value
            self.temperature = DEFAULT_TEMPERATURE
            logger.debug(f"Using default temperature: {self.temperature}")


# =============================================================================
# Additional LLM Provider Bindings
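The __post_init__ above resolves temperature in a fixed order: an explicitly set value wins, then OLLAMA_LLM_TEMPERATURE, then the global TEMPERATURE, then DEFAULT_TEMPERATURE. A standalone sketch of that fallback chain (the default value here is a placeholder, not the constant from lightrag.constants):

import os

DEFAULT_TEMPERATURE = 1.0  # placeholder, not lightrag.constants.DEFAULT_TEMPERATURE


def resolve_temperature() -> float:
    # Binding-specific variable first, then the global fallback, then the default
    for env_name in ("OLLAMA_LLM_TEMPERATURE", "TEMPERATURE"):
        raw = os.getenv(env_name)
        if raw is not None:
            try:
                return float(raw)
            except (ValueError, TypeError):
                pass  # invalid value: fall through, as __post_init__ does after logging
    return DEFAULT_TEMPERATURE


os.environ["TEMPERATURE"] = "0.2"             # global fallback
print(resolve_temperature())                  # 0.2
os.environ["OLLAMA_LLM_TEMPERATURE"] = "0.7"  # binding-specific value wins
print(resolve_temperature())                  # 0.7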
@@ -423,15 +533,14 @@ if __name__ == "__main__":
    # from io import StringIO

    dotenv.load_dotenv(dotenv_path=".env", override=False)


    # env_strstream = StringIO(
    # ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n")
    # )
    # # Load environment variables from .env file
    # dotenv.load_dotenv(stream=env_strstream)


    if len(sys.argv) > 1 and sys.argv[1] == "test":

        # Add arguments for OllamaEmbeddingOptions and OllamaLLMOptions
        parser = ArgumentParser(description="Test Ollama binding")
        OllamaEmbeddingOptions.add_args(parser)
@@ -444,8 +553,8 @@ if __name__ == "__main__":
                "1024",
                "--ollama-llm-num_ctx",
                "2048",
                "--ollama-llm-stop",
                "<end>",
                # "--ollama-llm-stop",
                # '["</s>", "\\n\\n"]',
            ]
        )
        print("Final args for LLM and Embedding:")