fix: Resolve default rerank config problem when env var missing

- Read config from selected_rerank_func when env var missing
- Make api_key optional for rerank function
- Add response format validation with proper error handling
- Update Cohere rerank default to official API endpoint
This commit is contained in:
yangdx 2025-08-23 01:07:59 +08:00
parent 580cb7906c
commit bf43e1b8c1
8 changed files with 50 additions and 142 deletions

View file

@ -35,6 +35,7 @@ from lightrag.constants import (
DEFAULT_EMBEDDING_BATCH_NUM,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_RERANK_BINDING,
)
# use the .env that is inside the current folder
@ -76,9 +77,7 @@ def parse_args() -> argparse.Namespace:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
parser = argparse.ArgumentParser(description="LightRAG API Server")
# Server configuration
parser.add_argument(
@ -228,15 +227,15 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--rerank-binding",
type=str,
default=get_env_value("RERANK_BINDING", "cohere"),
default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
choices=["cohere", "jina", "aliyun"],
help="Rerank binding type (default: from env or cohere)",
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
)
parser.add_argument(
"--enable-rerank",
action="store_true",
default=get_env_value("ENABLE_RERANK", True, bool),
help="Enable rerank functionality (default: from env or True)",
default=get_env_value("ENABLE_RERANK", False, bool),
help="Enable rerank functionality (default: from env or disalbed)",
)
# Conditionally add binding options defined in binding_options module
@ -350,7 +349,7 @@ def parse_args() -> argparse.Namespace:
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
# Rerank model configuration
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
args.rerank_model = get_env_value("RERANK_MODEL", None)
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
# Note: rerank_binding is already set by argparse, no need to override from env

View file

@ -11,6 +11,7 @@ import signal
import sys
import uvicorn
import pipmaster as pm
import inspect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
@ -408,6 +409,22 @@ def create_app(args):
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
# Get default values from selected_rerank_func if args values are None
if args.rerank_model is None or args.rerank_binding_host is None:
sig = inspect.signature(selected_rerank_func)
# Set default model if args.rerank_model is None
if args.rerank_model is None and "model" in sig.parameters:
default_model = sig.parameters["model"].default
if default_model != inspect.Parameter.empty:
args.rerank_model = default_model
# Set default base_url if args.rerank_binding_host is None
if args.rerank_binding_host is None and "base_url" in sig.parameters:
default_base_url = sig.parameters["base_url"].default
if default_base_url != inspect.Parameter.empty:
args.rerank_binding_host = default_base_url
async def server_rerank_func(
query: str, documents: list, top_n: int = None, extra_body: dict = None
):
@ -415,19 +432,19 @@ def create_app(args):
return await selected_rerank_func(
query=query,
documents=documents,
top_n=top_n,
api_key=args.rerank_binding_api_key,
model=args.rerank_model,
base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key,
top_n=top_n,
extra_body=extra_body,
)
rerank_model_func = server_rerank_func
logger.info(
f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
)
else:
logger.info("Rerank disabled")
logger.info("Reranking is disabled")
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
@ -635,7 +652,6 @@ def create_app(args):
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": args.enable_rerank,
"rerank_configured": rerank_model_func is not None,
"rerank_binding": args.rerank_binding
if args.enable_rerank
else None,

View file

@ -22,7 +22,6 @@ from .constants import (
DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_HISTORY_TURNS,
DEFAULT_ENABLE_RERANK,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_OLLAMA_MODEL_SIZE,
@ -158,9 +157,7 @@ class QueryParam:
If provided, this will be used instead of the default value from the prompt template.
"""
enable_rerank: bool = (
os.getenv("ENABLE_RERANK", str(DEFAULT_ENABLE_RERANK).lower()).lower() == "true"
)
enable_rerank: bool = os.getenv("ENABLE_RERANK", "false").lower() == "true"
"""Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
Default is False; set ENABLE_RERANK=true to enable reranking when a rerank model is available.
"""

View file

@ -32,8 +32,8 @@ DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
DEFAULT_HISTORY_TURNS = 0
# Rerank configuration defaults
DEFAULT_ENABLE_RERANK = True
DEFAULT_MIN_RERANK_SCORE = 0.0
DEFAULT_RERANK_BINDING = "cohere"
# File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
DEFAULT_MAX_FILE_PATH_LENGTH = 32768

View file

@ -525,14 +525,6 @@ class LightRAG:
)
)
# Init Rerank
if self.rerank_model_func:
logger.info("Rerank model initialized for improved retrieval quality")
else:
logger.warning(
"Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
)
self._storages_status = StoragesStatus.CREATED
async def initialize_storages(self):

View file

@ -1,101 +0,0 @@
from __future__ import annotations
from typing import Callable, Any
from pydantic import BaseModel, Field
class Model(BaseModel):
    """
    Pydantic model describing one language-model endpoint for use with MultiModel.

    Attributes:
        gen_func (Callable[[Any], str]): Callable that produces the model's
            response text. It is awaited by ``MultiModel.llm_model_func``, so
            it is expected to be an async callable returning a string.
        kwargs (dict[str, Any]): Keyword arguments forwarded to ``gen_func``
            on every call — e.g. the model name and API key.

    Example:
        Model(gen_func=openai_complete_if_cache,
              kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
    """

    gen_func: Callable[[Any], str] = Field(
        ...,
        description="A function that generates the response from the llm. The response must be a string",
    )
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
    )

    class Config:
        # Allow non-pydantic types (plain callables) as field values.
        arbitrary_types_allowed = True
class MultiModel:
    """
    Round-robin load balancer over multiple language models.

    Useful for working around low per-key rate limits (especially on free
    tiers) by rotating requests across several API keys, or for splitting
    load across different models or providers.

    Usage example:
        ```python
        models = [
            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
        ]
        multi_model = MultiModel(models)
        rag = LightRAG(
            llm_model_func=multi_model.llm_model_func
            # ...other args
        )
        ```
    """

    def __init__(self, models: list[Model]):
        self._models = models
        self._current_model = 0

    def _next_model(self) -> Model:
        # Advance the round-robin pointer, wrapping back to the first model.
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> str:
        """Generate a response using the next model in the rotation.

        Args:
            prompt: User prompt to send to the model.
            system_prompt: Optional system prompt.
            history_messages: Prior conversation turns. Defaults to an empty
                list; a ``None`` sentinel is used instead of ``[]`` to avoid
                the shared-mutable-default pitfall.
            **kwargs: Extra generation arguments. On a key collision, the
                selected model's own ``kwargs`` take precedence.

        Returns:
            The generated response text.
        """
        if history_messages is None:
            history_messages = []
        # Strip arguments that would override per-model configuration.
        kwargs.pop("model", None)  # stop from overwriting the custom model name
        kwargs.pop("keyword_extraction", None)
        kwargs.pop("mode", None)
        next_model = self._next_model()
        # Dict literal (not dict(...)) so duplicate keys between caller kwargs
        # and the model's own kwargs merge (model wins) instead of raising
        # "got multiple values for keyword argument".
        args = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            "history_messages": history_messages,
            **kwargs,
            **next_model.kwargs,
        }
        return await next_model.gen_func(**args)
if __name__ == "__main__":
    import asyncio

    async def main():
        # Manual smoke test: issue a single completion against the OpenAI
        # backend (requires valid OpenAI credentials in the environment).
        from lightrag.llm.openai import gpt_4o_mini_complete

        result = await gpt_4o_mini_complete("How are you?")
        print(result)

    asyncio.run(main())

View file

@ -32,7 +32,7 @@ async def generic_rerank_api(
documents: List[str],
model: str,
base_url: str,
api_key: str,
api_key: Optional[str],
top_n: Optional[int] = None,
return_documents: Optional[bool] = None,
extra_body: Optional[Dict[str, Any]] = None,
@ -56,13 +56,12 @@ async def generic_rerank_api(
Returns:
List of dictionary of ["index": int, "relevance_score": float]
"""
if not api_key:
raise ValueError("API key is required")
if not base_url:
raise ValueError("Base URL is required")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
# Build request payload based on request format
if request_format == "aliyun":
@ -119,7 +118,6 @@ async def generic_rerank_api(
error_text.strip().startswith("<!DOCTYPE html>")
or "text/html" in content_type
)
if is_html_error:
if response.status == 502:
clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
@ -131,7 +129,6 @@ async def generic_rerank_api(
clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
else:
clean_error = error_text
logger.error(f"Rerank API error {response.status}: {clean_error}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
@ -142,17 +139,25 @@ async def generic_rerank_api(
response_json = await response.json()
# Handle different response formats
if response_format == "aliyun":
# Aliyun format: {"output": {"results": [...]}}
output = response_json.get("output", {})
results = output.get("results", [])
results = response_json.get("output", {}).get("results", [])
if not isinstance(results, list):
logger.warning(
f"Expected 'output.results' to be list, got {type(results)}: {results}"
)
results = []
elif response_format == "standard":
# Standard format: {"results": [...]}
results = response_json.get("results", [])
if not isinstance(results, list):
logger.warning(
f"Expected 'results' to be list, got {type(results)}: {results}"
)
results = []
else:
raise ValueError(f"Unsupported response format: {response_format}")
if not results:
logger.warning("Rerank API returned empty results")
return []
@ -170,7 +175,7 @@ async def cohere_rerank(
top_n: Optional[int] = None,
api_key: Optional[str] = None,
model: str = "rerank-v3.5",
base_url: str = "https://ai.znipower.com:5017/rerank",
base_url: str = "https://api.cohere.com/v2/rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""

View file

@ -1995,7 +1995,7 @@ async def apply_rerank_if_enabled(
rerank_results = await rerank_func(
query=query,
documents=document_texts,
top_n=top_n or len(retrieved_docs),
top_n=top_n,
)
# Process rerank results based on return format