fix: Resolve default rerank config problem when env var missing
- Read config from selected_rerank_func when env var missing
- Make api_key optional for rerank function
- Add response format validation with proper error handling
- Update Cohere rerank default to official API endpoint
parent 580cb7906c
commit bf43e1b8c1

8 changed files with 50 additions and 142 deletions
@@ -35,6 +35,7 @@ from lightrag.constants import (
     DEFAULT_EMBEDDING_BATCH_NUM,
     DEFAULT_OLLAMA_MODEL_NAME,
     DEFAULT_OLLAMA_MODEL_TAG,
+    DEFAULT_RERANK_BINDING,
 )

 # use the .env that is inside the current folder
@@ -76,9 +77,7 @@ def parse_args() -> argparse.Namespace:
         argparse.Namespace: Parsed arguments
     """

-    parser = argparse.ArgumentParser(
-        description="LightRAG FastAPI Server with separate working and input directories"
-    )
+    parser = argparse.ArgumentParser(description="LightRAG API Server")

     # Server configuration
     parser.add_argument(
@@ -228,15 +227,15 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--rerank-binding",
         type=str,
-        default=get_env_value("RERANK_BINDING", "cohere"),
+        default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
         choices=["cohere", "jina", "aliyun"],
-        help="Rerank binding type (default: from env or cohere)",
+        help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
     )
     parser.add_argument(
         "--enable-rerank",
         action="store_true",
-        default=get_env_value("ENABLE_RERANK", True, bool),
-        help="Enable rerank functionality (default: from env or True)",
+        default=get_env_value("ENABLE_RERANK", False, bool),
+        help="Enable rerank functionality (default: from env or disabled)",
     )

     # Conditionally add binding options defined in binding_options module
@@ -350,7 +349,7 @@ def parse_args() -> argparse.Namespace:
     args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")

     # Rerank model configuration
-    args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
+    args.rerank_model = get_env_value("RERANK_MODEL", None)
    args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
    args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
    # Note: rerank_binding is already set by argparse, no need to override from env
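Switching the default to `None` is what later signals "resolve from the rerank function's signature". The sketch below is a hypothetical re-implementation of the `get_env_value` pattern for illustration only; the real helper lives in the LightRAG API config and may differ:

```python
import os

def get_env_value(name: str, default, value_type=str):
    """Hypothetical stand-in for LightRAG's env helper, for illustration."""
    raw = os.environ.get(name)
    if raw is None:
        return default  # env var missing -> keep the supplied default
    if value_type is bool:
        return raw.strip().lower() in ("true", "1", "yes")
    return value_type(raw)

# With RERANK_MODEL unset this now yields None instead of a hardcoded model,
# deferring the choice to the selected rerank function's own defaults.
rerank_model = get_env_value("RERANK_MODEL", None)
print(rerank_model)
```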
@@ -11,6 +11,7 @@ import signal
 import sys
 import uvicorn
 import pipmaster as pm
+import inspect
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import RedirectResponse
 from pathlib import Path
@@ -408,6 +409,22 @@ def create_app(args):
            logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
            raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")

+        # Get default values from selected_rerank_func if args values are None
+        if args.rerank_model is None or args.rerank_binding_host is None:
+            sig = inspect.signature(selected_rerank_func)
+
+            # Set default model if args.rerank_model is None
+            if args.rerank_model is None and "model" in sig.parameters:
+                default_model = sig.parameters["model"].default
+                if default_model != inspect.Parameter.empty:
+                    args.rerank_model = default_model
+
+            # Set default base_url if args.rerank_binding_host is None
+            if args.rerank_binding_host is None and "base_url" in sig.parameters:
+                default_base_url = sig.parameters["base_url"].default
+                if default_base_url != inspect.Parameter.empty:
+                    args.rerank_binding_host = default_base_url
+
        async def server_rerank_func(
            query: str, documents: list, top_n: int = None, extra_body: dict = None
        ):
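For context on the fallback above: `inspect.signature` exposes a callable's declared keyword defaults, which is how the server can recover a usable model and base_url when neither the env var nor the CLI flag was given. A minimal, self-contained sketch (`my_rerank` is a hypothetical stand-in for the selected binding function):

```python
import inspect

# Hypothetical stand-in for the configured rerank binding function.
async def my_rerank(query, documents, model="rerank-v3.5",
                    base_url="https://api.cohere.com/v2/rerank", **kwargs):
    ...

sig = inspect.signature(my_rerank)

# Read the declared default for "model"; Parameter.empty means "no default".
default_model = sig.parameters["model"].default
if default_model is not inspect.Parameter.empty:
    print(default_model)  # -> rerank-v3.5
```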
@@ -415,19 +432,19 @@ def create_app(args):
            return await selected_rerank_func(
                query=query,
                documents=documents,
-                top_n=top_n,
-                api_key=args.rerank_binding_api_key,
                model=args.rerank_model,
                base_url=args.rerank_binding_host,
+                api_key=args.rerank_binding_api_key,
+                top_n=top_n,
                extra_body=extra_body,
            )

        rerank_model_func = server_rerank_func
        logger.info(
-            f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
+            f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
        )
    else:
-        logger.info("Rerank disabled")
+        logger.info("Reranking is disabled")

    # Create ollama_server_infos from command line arguments
    from lightrag.api.config import OllamaServerInfos
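The wrapper in this hunk is a closure that captures server-level settings once, so callers only pass per-request arguments. A hedged sketch of that pattern; the factory below is illustrative rather than the actual implementation, though the keyword names mirror the diff:

```python
from typing import Any, Awaitable, Callable, Optional

def make_server_rerank_func(
    selected_rerank_func: Callable[..., Awaitable[list]],
    model: Optional[str],
    base_url: Optional[str],
    api_key: Optional[str],
) -> Callable[..., Awaitable[list]]:
    # Server-owned settings (model, base_url, api_key) are closed over here;
    # callers supply only query, documents, top_n, and extra_body.
    async def server_rerank_func(
        query: str, documents: list, top_n: int = None, extra_body: dict = None
    ):
        return await selected_rerank_func(
            query=query,
            documents=documents,
            model=model,
            base_url=base_url,
            api_key=api_key,
            top_n=top_n,
            extra_body=extra_body,
        )

    return server_rerank_func
```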
@@ -635,7 +652,6 @@ def create_app(args):
        "max_graph_nodes": args.max_graph_nodes,
        # Rerank configuration
        "enable_rerank": args.enable_rerank,
-        "rerank_configured": rerank_model_func is not None,
        "rerank_binding": args.rerank_binding
        if args.enable_rerank
        else None,
@@ -22,7 +22,6 @@ from .constants import (
    DEFAULT_MAX_RELATION_TOKENS,
    DEFAULT_MAX_TOTAL_TOKENS,
    DEFAULT_HISTORY_TURNS,
-    DEFAULT_ENABLE_RERANK,
    DEFAULT_OLLAMA_MODEL_NAME,
    DEFAULT_OLLAMA_MODEL_TAG,
    DEFAULT_OLLAMA_MODEL_SIZE,
@@ -158,9 +157,7 @@ class QueryParam:
    If provided, this will be used instead of the default value from the prompt template.
    """

-    enable_rerank: bool = (
-        os.getenv("ENABLE_RERANK", str(DEFAULT_ENABLE_RERANK).lower()).lower() == "true"
-    )
+    enable_rerank: bool = os.getenv("ENABLE_RERANK", "false").lower() == "true"
    """Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
    Default is True to enable reranking when rerank model is available.
    """
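A minimal sketch of the parsing rule behind the new default: the flag is on only when the environment variable is the literal string "true" (case-insensitive), and an unset variable now means disabled.

```python
import os

# Unset, "false", "0", or any other value -> False; only "true"/"TRUE"/... -> True.
enable_rerank = os.getenv("ENABLE_RERANK", "false").lower() == "true"
print(enable_rerank)  # False unless ENABLE_RERANK=true is exported
```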
@@ -32,8 +32,8 @@ DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
DEFAULT_HISTORY_TURNS = 0

# Rerank configuration defaults
-DEFAULT_ENABLE_RERANK = True
DEFAULT_MIN_RERANK_SCORE = 0.0
+DEFAULT_RERANK_BINDING = "cohere"

# File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
DEFAULT_MAX_FILE_PATH_LENGTH = 32768
@@ -525,14 +525,6 @@ class LightRAG:
            )
        )

-        # Init Rerank
-        if self.rerank_model_func:
-            logger.info("Rerank model initialized for improved retrieval quality")
-        else:
-            logger.warning(
-                "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
-            )
-
        self._storages_status = StoragesStatus.CREATED

    async def initialize_storages(self):
lightrag/llm.py (101 deletions)
@@ -1,101 +0,0 @@
-from __future__ import annotations
-
-from typing import Callable, Any
-from pydantic import BaseModel, Field
-
-
-class Model(BaseModel):
-    """
-    This is a Pydantic model class named 'Model' that is used to define a custom language model.
-
-    Attributes:
-        gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
-            The function should take any argument and return a string.
-        kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
-            This could include parameters such as the model name, API key, etc.
-
-    Example usage:
-        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
-
-    In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
-    The 'kwargs' dictionary contains the model name and API key to be passed to the function.
-    """
-
-    gen_func: Callable[[Any], str] = Field(
-        ...,
-        description="A function that generates the response from the llm. The response must be a string",
-    )
-    kwargs: dict[str, Any] = Field(
-        ...,
-        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
-    )
-
-    class Config:
-        arbitrary_types_allowed = True
-
-
-class MultiModel:
-    """
-    Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
-    Could also be used for spliting across diffrent models or providers.
-
-    Attributes:
-        models (List[Model]): A list of language models to be used.
-
-    Usage example:
-    ```python
-    models = [
-        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
-        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
-        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
-        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
-        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
-    ]
-    multi_model = MultiModel(models)
-    rag = LightRAG(
-        llm_model_func=multi_model.llm_model_func
-        / ..other args
-    )
-    ```
-    """
-
-    def __init__(self, models: list[Model]):
-        self._models = models
-        self._current_model = 0
-
-    def _next_model(self):
-        self._current_model = (self._current_model + 1) % len(self._models)
-        return self._models[self._current_model]
-
-    async def llm_model_func(
-        self,
-        prompt: str,
-        system_prompt: str | None = None,
-        history_messages: list[dict[str, Any]] = [],
-        **kwargs: Any,
-    ) -> str:
-        kwargs.pop("model", None)  # stop from overwriting the custom model name
-        kwargs.pop("keyword_extraction", None)
-        kwargs.pop("mode", None)
-        next_model = self._next_model()
-        args = dict(
-            prompt=prompt,
-            system_prompt=system_prompt,
-            history_messages=history_messages,
-            **kwargs,
-            **next_model.kwargs,
-        )
-
-        return await next_model.gen_func(**args)
-
-
-if __name__ == "__main__":
-    import asyncio
-
-    async def main():
-        from lightrag.llm.openai import gpt_4o_mini_complete
-
-        result = await gpt_4o_mini_complete("How are you?")
-        print(result)
-
-    asyncio.run(main())
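For reference, the round-robin that the deleted `MultiModel` implemented reduces to a modulo counter over a model list. A standalone sketch with placeholder entries:

```python
# Placeholder model identifiers; the deleted class cycled Model objects.
models = ["model-key-1", "model-key-2", "model-key-3"]
current = 0

def next_model():
    """Advance the cursor and return the next model, wrapping around."""
    global current
    current = (current + 1) % len(models)
    return models[current]

print([next_model() for _ in range(4)])
# -> ['model-key-2', 'model-key-3', 'model-key-1', 'model-key-2']
```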
@@ -32,7 +32,7 @@ async def generic_rerank_api(
    documents: List[str],
    model: str,
    base_url: str,
-    api_key: str,
+    api_key: Optional[str],
    top_n: Optional[int] = None,
    return_documents: Optional[bool] = None,
    extra_body: Optional[Dict[str, Any]] = None,
@@ -56,13 +56,12 @@ async def generic_rerank_api(
    Returns:
        List of dictionary of ["index": int, "relevance_score": float]
    """
-    if not api_key:
-        raise ValueError("API key is required")
    if not base_url:
        raise ValueError("Base URL is required")

-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {api_key}",
-    }
+    headers = {"Content-Type": "application/json"}
+    if api_key is not None:
+        headers["Authorization"] = f"Bearer {api_key}"

    # Build request payload based on request format
    if request_format == "aliyun":
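A runnable sketch of the new header logic, which is what makes `api_key` genuinely optional: the Authorization header is attached only when a key is configured, so keyless self-hosted rerankers no longer trip the old hard check.

```python
from typing import Optional

def build_headers(api_key: Optional[str]) -> dict:
    # Always send JSON; attach a bearer token only when one is configured.
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers

print(build_headers(None))       # {'Content-Type': 'application/json'}
print(build_headers("sk-demo"))  # adds 'Authorization': 'Bearer sk-demo'
```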
@@ -119,7 +118,6 @@ async def generic_rerank_api(
                error_text.strip().startswith("<!DOCTYPE html>")
                or "text/html" in content_type
            )

            if is_html_error:
                if response.status == 502:
                    clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
@@ -131,7 +129,6 @@ async def generic_rerank_api(
                    clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
                else:
                    clean_error = error_text

                logger.error(f"Rerank API error {response.status}: {clean_error}")
                raise aiohttp.ClientResponseError(
                    request_info=response.request_info,
@@ -142,17 +139,25 @@ async def generic_rerank_api(

            response_json = await response.json()

            # Handle different response formats
            if response_format == "aliyun":
                # Aliyun format: {"output": {"results": [...]}}
-                output = response_json.get("output", {})
-                results = output.get("results", [])
+                results = response_json.get("output", {}).get("results", [])
+                if not isinstance(results, list):
+                    logger.warning(
+                        f"Expected 'output.results' to be list, got {type(results)}: {results}"
+                    )
+                    results = []
            elif response_format == "standard":
                # Standard format: {"results": [...]}
                results = response_json.get("results", [])
+                if not isinstance(results, list):
+                    logger.warning(
+                        f"Expected 'results' to be list, got {type(results)}: {results}"
+                    )
+                    results = []
            else:
                raise ValueError(f"Unsupported response format: {response_format}")

            if not results:
                logger.warning("Rerank API returned empty results")
                return []
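A self-contained sketch of the validation added here, assuming the same two payload shapes the diff names (Aliyun nests results under `output`; the standard format does not):

```python
def extract_results(response_json: dict, response_format: str) -> list:
    if response_format == "aliyun":
        # Aliyun format: {"output": {"results": [...]}}
        results = response_json.get("output", {}).get("results", [])
    elif response_format == "standard":
        # Standard format: {"results": [...]}
        results = response_json.get("results", [])
    else:
        raise ValueError(f"Unsupported response format: {response_format}")
    if not isinstance(results, list):
        return []  # malformed payload; the real code logs a warning here
    return results

print(extract_results({"output": {"results": [{"index": 0}]}}, "aliyun"))
print(extract_results({"results": "oops"}, "standard"))  # -> []
```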
@@ -170,7 +175,7 @@ async def cohere_rerank(
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "rerank-v3.5",
-    base_url: str = "https://ai.znipower.com:5017/rerank",
+    base_url: str = "https://api.cohere.com/v2/rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
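For orientation, a hedged sketch of what a request against the new default endpoint looks like; the payload fields (model, query, documents, top_n) follow Cohere's public v2 rerank API, and the key below is a placeholder:

```python
import aiohttp
import asyncio

async def demo():
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.cohere.com/v2/rerank",
            headers={"Authorization": "Bearer <COHERE_API_KEY>"},  # placeholder key
            json={
                "model": "rerank-v3.5",
                "query": "what is lightrag",
                "documents": ["LightRAG is a RAG framework", "unrelated text"],
                "top_n": 1,
            },
        ) as resp:
            print(await resp.json())

# asyncio.run(demo())  # requires a real API key
```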
@@ -1995,7 +1995,7 @@ async def apply_rerank_if_enabled(
        rerank_results = await rerank_func(
            query=query,
            documents=document_texts,
-            top_n=top_n or len(retrieved_docs),
+            top_n=top_n,
        )

        # Process rerank results based on return format
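A hypothetical comparison showing why the `or` expression was dropped: it coerced both `None` and `0` to the full document count and always overrode the binding's own default, whereas forwarding `top_n` unchanged defers that decision to the rerank function:

```python
docs = ["a", "b", "c"]

def old_top_n(top_n):
    return top_n or len(docs)  # None -> 3, 0 -> 3, 2 -> 2 (always forced)

def new_top_n(top_n):
    return top_n               # passed through; binding applies its default

print(old_top_n(None), new_top_n(None))  # 3 None
```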