fix: Resolve default rerank config problem when env var missing
- Read config from selected_rerank_func when env var missing
- Make api_key optional for rerank function
- Add response format validation with proper error handling
- Update Cohere rerank default to official API endpoint
This commit is contained in:
parent
580cb7906c
commit
bf43e1b8c1
8 changed files with 50 additions and 142 deletions
|
|
@ -35,6 +35,7 @@ from lightrag.constants import (
|
||||||
DEFAULT_EMBEDDING_BATCH_NUM,
|
DEFAULT_EMBEDDING_BATCH_NUM,
|
||||||
DEFAULT_OLLAMA_MODEL_NAME,
|
DEFAULT_OLLAMA_MODEL_NAME,
|
||||||
DEFAULT_OLLAMA_MODEL_TAG,
|
DEFAULT_OLLAMA_MODEL_TAG,
|
||||||
|
DEFAULT_RERANK_BINDING,
|
||||||
)
|
)
|
||||||
|
|
||||||
# use the .env that is inside the current folder
|
# use the .env that is inside the current folder
|
||||||
|
|
@ -76,9 +77,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
argparse.Namespace: Parsed arguments
|
argparse.Namespace: Parsed arguments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="LightRAG API Server")
|
||||||
description="LightRAG FastAPI Server with separate working and input directories"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Server configuration
|
# Server configuration
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -228,15 +227,15 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rerank-binding",
|
"--rerank-binding",
|
||||||
type=str,
|
type=str,
|
||||||
default=get_env_value("RERANK_BINDING", "cohere"),
|
default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
|
||||||
choices=["cohere", "jina", "aliyun"],
|
choices=["cohere", "jina", "aliyun"],
|
||||||
help="Rerank binding type (default: from env or cohere)",
|
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-rerank",
|
"--enable-rerank",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=get_env_value("ENABLE_RERANK", True, bool),
|
default=get_env_value("ENABLE_RERANK", False, bool),
|
||||||
help="Enable rerank functionality (default: from env or True)",
|
help="Enable rerank functionality (default: from env or disabled)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Conditionally add binding options defined in binding_options module
|
# Conditionally add binding options defined in binding_options module
|
||||||
|
|
@ -350,7 +349,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
|
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
|
||||||
|
|
||||||
# Rerank model configuration
|
# Rerank model configuration
|
||||||
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
args.rerank_model = get_env_value("RERANK_MODEL", None)
|
||||||
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
|
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
|
||||||
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
|
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
|
||||||
# Note: rerank_binding is already set by argparse, no need to override from env
|
# Note: rerank_binding is already set by argparse, no need to override from env
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ import signal
|
||||||
import sys
|
import sys
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
import inspect
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -408,6 +409,22 @@ def create_app(args):
|
||||||
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
|
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
|
||||||
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
|
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
|
||||||
|
|
||||||
|
# Get default values from selected_rerank_func if args values are None
|
||||||
|
if args.rerank_model is None or args.rerank_binding_host is None:
|
||||||
|
sig = inspect.signature(selected_rerank_func)
|
||||||
|
|
||||||
|
# Set default model if args.rerank_model is None
|
||||||
|
if args.rerank_model is None and "model" in sig.parameters:
|
||||||
|
default_model = sig.parameters["model"].default
|
||||||
|
if default_model != inspect.Parameter.empty:
|
||||||
|
args.rerank_model = default_model
|
||||||
|
|
||||||
|
# Set default base_url if args.rerank_binding_host is None
|
||||||
|
if args.rerank_binding_host is None and "base_url" in sig.parameters:
|
||||||
|
default_base_url = sig.parameters["base_url"].default
|
||||||
|
if default_base_url != inspect.Parameter.empty:
|
||||||
|
args.rerank_binding_host = default_base_url
|
||||||
|
|
||||||
async def server_rerank_func(
|
async def server_rerank_func(
|
||||||
query: str, documents: list, top_n: int = None, extra_body: dict = None
|
query: str, documents: list, top_n: int = None, extra_body: dict = None
|
||||||
):
|
):
|
||||||
|
|
@ -415,19 +432,19 @@ def create_app(args):
|
||||||
return await selected_rerank_func(
|
return await selected_rerank_func(
|
||||||
query=query,
|
query=query,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
|
top_n=top_n,
|
||||||
|
api_key=args.rerank_binding_api_key,
|
||||||
model=args.rerank_model,
|
model=args.rerank_model,
|
||||||
base_url=args.rerank_binding_host,
|
base_url=args.rerank_binding_host,
|
||||||
api_key=args.rerank_binding_api_key,
|
|
||||||
top_n=top_n,
|
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
rerank_model_func = server_rerank_func
|
rerank_model_func = server_rerank_func
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
|
f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Rerank disabled")
|
logger.info("Reranking is disabled")
|
||||||
|
|
||||||
# Create ollama_server_infos from command line arguments
|
# Create ollama_server_infos from command line arguments
|
||||||
from lightrag.api.config import OllamaServerInfos
|
from lightrag.api.config import OllamaServerInfos
|
||||||
|
|
@ -635,7 +652,6 @@ def create_app(args):
|
||||||
"max_graph_nodes": args.max_graph_nodes,
|
"max_graph_nodes": args.max_graph_nodes,
|
||||||
# Rerank configuration
|
# Rerank configuration
|
||||||
"enable_rerank": args.enable_rerank,
|
"enable_rerank": args.enable_rerank,
|
||||||
"rerank_configured": rerank_model_func is not None,
|
|
||||||
"rerank_binding": args.rerank_binding
|
"rerank_binding": args.rerank_binding
|
||||||
if args.enable_rerank
|
if args.enable_rerank
|
||||||
else None,
|
else None,
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ from .constants import (
|
||||||
DEFAULT_MAX_RELATION_TOKENS,
|
DEFAULT_MAX_RELATION_TOKENS,
|
||||||
DEFAULT_MAX_TOTAL_TOKENS,
|
DEFAULT_MAX_TOTAL_TOKENS,
|
||||||
DEFAULT_HISTORY_TURNS,
|
DEFAULT_HISTORY_TURNS,
|
||||||
DEFAULT_ENABLE_RERANK,
|
|
||||||
DEFAULT_OLLAMA_MODEL_NAME,
|
DEFAULT_OLLAMA_MODEL_NAME,
|
||||||
DEFAULT_OLLAMA_MODEL_TAG,
|
DEFAULT_OLLAMA_MODEL_TAG,
|
||||||
DEFAULT_OLLAMA_MODEL_SIZE,
|
DEFAULT_OLLAMA_MODEL_SIZE,
|
||||||
|
|
@ -158,9 +157,7 @@ class QueryParam:
|
||||||
If provided, this will be used instead of the default value from the prompt template.
|
If provided, this will be used instead of the default value from the prompt template.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
enable_rerank: bool = (
|
enable_rerank: bool = os.getenv("ENABLE_RERANK", "false").lower() == "true"
|
||||||
os.getenv("ENABLE_RERANK", str(DEFAULT_ENABLE_RERANK).lower()).lower() == "true"
|
|
||||||
)
|
|
||||||
"""Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
|
"""Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
|
||||||
Default is True to enable reranking when rerank model is available.
|
Default is True to enable reranking when rerank model is available.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,8 @@ DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
|
||||||
DEFAULT_HISTORY_TURNS = 0
|
DEFAULT_HISTORY_TURNS = 0
|
||||||
|
|
||||||
# Rerank configuration defaults
|
# Rerank configuration defaults
|
||||||
DEFAULT_ENABLE_RERANK = True
|
|
||||||
DEFAULT_MIN_RERANK_SCORE = 0.0
|
DEFAULT_MIN_RERANK_SCORE = 0.0
|
||||||
|
DEFAULT_RERANK_BINDING = "cohere"
|
||||||
|
|
||||||
# File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
|
# File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
|
||||||
DEFAULT_MAX_FILE_PATH_LENGTH = 32768
|
DEFAULT_MAX_FILE_PATH_LENGTH = 32768
|
||||||
|
|
|
||||||
|
|
@ -525,14 +525,6 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init Rerank
|
|
||||||
if self.rerank_model_func:
|
|
||||||
logger.info("Rerank model initialized for improved retrieval quality")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._storages_status = StoragesStatus.CREATED
|
self._storages_status = StoragesStatus.CREATED
|
||||||
|
|
||||||
async def initialize_storages(self):
|
async def initialize_storages(self):
|
||||||
|
|
|
||||||
101
lightrag/llm.py
101
lightrag/llm.py
|
|
@ -1,101 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Callable, Any
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class Model(BaseModel):
|
|
||||||
"""
|
|
||||||
This is a Pydantic model class named 'Model' that is used to define a custom language model.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
|
|
||||||
The function should take any argument and return a string.
|
|
||||||
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
|
|
||||||
This could include parameters such as the model name, API key, etc.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
|
|
||||||
|
|
||||||
In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
|
|
||||||
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
gen_func: Callable[[Any], str] = Field(
|
|
||||||
...,
|
|
||||||
description="A function that generates the response from the llm. The response must be a string",
|
|
||||||
)
|
|
||||||
kwargs: dict[str, Any] = Field(
|
|
||||||
...,
|
|
||||||
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModel:
|
|
||||||
"""
|
|
||||||
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
|
|
||||||
Could also be used for spliting across diffrent models or providers.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
models (List[Model]): A list of language models to be used.
|
|
||||||
|
|
||||||
Usage example:
|
|
||||||
```python
|
|
||||||
models = [
|
|
||||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
|
|
||||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
|
|
||||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
|
|
||||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
|
|
||||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
|
|
||||||
]
|
|
||||||
multi_model = MultiModel(models)
|
|
||||||
rag = LightRAG(
|
|
||||||
llm_model_func=multi_model.llm_model_func
|
|
||||||
/ ..other args
|
|
||||||
)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, models: list[Model]):
|
|
||||||
self._models = models
|
|
||||||
self._current_model = 0
|
|
||||||
|
|
||||||
def _next_model(self):
|
|
||||||
self._current_model = (self._current_model + 1) % len(self._models)
|
|
||||||
return self._models[self._current_model]
|
|
||||||
|
|
||||||
async def llm_model_func(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
system_prompt: str | None = None,
|
|
||||||
history_messages: list[dict[str, Any]] = [],
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> str:
|
|
||||||
kwargs.pop("model", None) # stop from overwriting the custom model name
|
|
||||||
kwargs.pop("keyword_extraction", None)
|
|
||||||
kwargs.pop("mode", None)
|
|
||||||
next_model = self._next_model()
|
|
||||||
args = dict(
|
|
||||||
prompt=prompt,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
history_messages=history_messages,
|
|
||||||
**kwargs,
|
|
||||||
**next_model.kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await next_model.gen_func(**args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
from lightrag.llm.openai import gpt_4o_mini_complete
|
|
||||||
|
|
||||||
result = await gpt_4o_mini_complete("How are you?")
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
@ -32,7 +32,7 @@ async def generic_rerank_api(
|
||||||
documents: List[str],
|
documents: List[str],
|
||||||
model: str,
|
model: str,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
api_key: str,
|
api_key: Optional[str],
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
return_documents: Optional[bool] = None,
|
return_documents: Optional[bool] = None,
|
||||||
extra_body: Optional[Dict[str, Any]] = None,
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -56,13 +56,12 @@ async def generic_rerank_api(
|
||||||
Returns:
|
Returns:
|
||||||
List of dictionary of ["index": int, "relevance_score": float]
|
List of dictionary of ["index": int, "relevance_score": float]
|
||||||
"""
|
"""
|
||||||
if not api_key:
|
if not base_url:
|
||||||
raise ValueError("API key is required")
|
raise ValueError("Base URL is required")
|
||||||
|
|
||||||
headers = {
|
headers = {"Content-Type": "application/json"}
|
||||||
"Content-Type": "application/json",
|
if api_key is not None:
|
||||||
"Authorization": f"Bearer {api_key}",
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
}
|
|
||||||
|
|
||||||
# Build request payload based on request format
|
# Build request payload based on request format
|
||||||
if request_format == "aliyun":
|
if request_format == "aliyun":
|
||||||
|
|
@ -119,7 +118,6 @@ async def generic_rerank_api(
|
||||||
error_text.strip().startswith("<!DOCTYPE html>")
|
error_text.strip().startswith("<!DOCTYPE html>")
|
||||||
or "text/html" in content_type
|
or "text/html" in content_type
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_html_error:
|
if is_html_error:
|
||||||
if response.status == 502:
|
if response.status == 502:
|
||||||
clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
|
clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
|
||||||
|
|
@ -131,7 +129,6 @@ async def generic_rerank_api(
|
||||||
clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
|
clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
|
||||||
else:
|
else:
|
||||||
clean_error = error_text
|
clean_error = error_text
|
||||||
|
|
||||||
logger.error(f"Rerank API error {response.status}: {clean_error}")
|
logger.error(f"Rerank API error {response.status}: {clean_error}")
|
||||||
raise aiohttp.ClientResponseError(
|
raise aiohttp.ClientResponseError(
|
||||||
request_info=response.request_info,
|
request_info=response.request_info,
|
||||||
|
|
@ -142,17 +139,25 @@ async def generic_rerank_api(
|
||||||
|
|
||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
|
|
||||||
# Handle different response formats
|
|
||||||
if response_format == "aliyun":
|
if response_format == "aliyun":
|
||||||
# Aliyun format: {"output": {"results": [...]}}
|
# Aliyun format: {"output": {"results": [...]}}
|
||||||
output = response_json.get("output", {})
|
results = response_json.get("output", {}).get("results", [])
|
||||||
results = output.get("results", [])
|
if not isinstance(results, list):
|
||||||
|
logger.warning(
|
||||||
|
f"Expected 'output.results' to be list, got {type(results)}: {results}"
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
|
||||||
elif response_format == "standard":
|
elif response_format == "standard":
|
||||||
# Standard format: {"results": [...]}
|
# Standard format: {"results": [...]}
|
||||||
results = response_json.get("results", [])
|
results = response_json.get("results", [])
|
||||||
|
if not isinstance(results, list):
|
||||||
|
logger.warning(
|
||||||
|
f"Expected 'results' to be list, got {type(results)}: {results}"
|
||||||
|
)
|
||||||
|
results = []
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported response format: {response_format}")
|
raise ValueError(f"Unsupported response format: {response_format}")
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
logger.warning("Rerank API returned empty results")
|
logger.warning("Rerank API returned empty results")
|
||||||
return []
|
return []
|
||||||
|
|
@ -170,7 +175,7 @@ async def cohere_rerank(
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
model: str = "rerank-v3.5",
|
model: str = "rerank-v3.5",
|
||||||
base_url: str = "https://ai.znipower.com:5017/rerank",
|
base_url: str = "https://api.cohere.com/v2/rerank",
|
||||||
extra_body: Optional[Dict[str, Any]] = None,
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1995,7 +1995,7 @@ async def apply_rerank_if_enabled(
|
||||||
rerank_results = await rerank_func(
|
rerank_results = await rerank_func(
|
||||||
query=query,
|
query=query,
|
||||||
documents=document_texts,
|
documents=document_texts,
|
||||||
top_n=top_n or len(retrieved_docs),
|
top_n=top_n,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process rerank results based on return format
|
# Process rerank results based on return format
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue