diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index bc24cd70..92952ec4 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -35,6 +35,7 @@ from lightrag.constants import (
     DEFAULT_EMBEDDING_BATCH_NUM,
     DEFAULT_OLLAMA_MODEL_NAME,
     DEFAULT_OLLAMA_MODEL_TAG,
+    DEFAULT_RERANK_BINDING,
 )
 
 # use the .env that is inside the current folder
@@ -76,9 +77,7 @@ def parse_args() -> argparse.Namespace:
         argparse.Namespace: Parsed arguments
     """
 
-    parser = argparse.ArgumentParser(
-        description="LightRAG FastAPI Server with separate working and input directories"
-    )
+    parser = argparse.ArgumentParser(description="LightRAG API Server")
 
     # Server configuration
     parser.add_argument(
@@ -228,15 +227,15 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--rerank-binding",
         type=str,
-        default=get_env_value("RERANK_BINDING", "cohere"),
+        default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
         choices=["cohere", "jina", "aliyun"],
-        help="Rerank binding type (default: from env or cohere)",
+        help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
     )
     parser.add_argument(
         "--enable-rerank",
         action="store_true",
-        default=get_env_value("ENABLE_RERANK", True, bool),
-        help="Enable rerank functionality (default: from env or True)",
+        default=get_env_value("ENABLE_RERANK", False, bool),
+        help="Enable rerank functionality (default: from env or disabled)",
     )
 
     # Conditionally add binding options defined in binding_options module
@@ -350,7 +349,7 @@ def parse_args() -> argparse.Namespace:
     args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
 
     # Rerank model configuration
-    args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
+    args.rerank_model = get_env_value("RERANK_MODEL", None)
    args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
     args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
     # Note: rerank_binding is already set by argparse, no need to override from env
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 8e3f9af1..8214d601 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -11,6 +11,7 @@ import signal
 import sys
 import uvicorn
 import pipmaster as pm
+import inspect
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import RedirectResponse
 from pathlib import Path
@@ -408,6 +409,22 @@ def create_app(args):
             logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
             raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
 
+        # Get default values from selected_rerank_func if args values are None
+        if args.rerank_model is None or args.rerank_binding_host is None:
+            sig = inspect.signature(selected_rerank_func)
+
+            # Set default model if args.rerank_model is None
+            if args.rerank_model is None and "model" in sig.parameters:
+                default_model = sig.parameters["model"].default
+                if default_model != inspect.Parameter.empty:
+                    args.rerank_model = default_model
+
+            # Set default base_url if args.rerank_binding_host is None
+            if args.rerank_binding_host is None and "base_url" in sig.parameters:
+                default_base_url = sig.parameters["base_url"].default
+                if default_base_url != inspect.Parameter.empty:
+                    args.rerank_binding_host = default_base_url
+
         async def server_rerank_func(
             query: str, documents: list, top_n: int = None, extra_body: dict = None
         ):
@@ -415,19 +432,19 @@ def create_app(args):
             return await selected_rerank_func(
                 query=query,
                 documents=documents,
+                top_n=top_n,
+                api_key=args.rerank_binding_api_key,
                 model=args.rerank_model,
                 base_url=args.rerank_binding_host,
-                api_key=args.rerank_binding_api_key,
-                top_n=top_n,
                 extra_body=extra_body,
             )
 
         rerank_model_func = server_rerank_func
         logger.info(
-            f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
+            f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
         )
     else:
-        logger.info("Rerank disabled")
+        logger.info("Reranking is disabled")
 
     # Create ollama_server_infos from command line arguments
     from lightrag.api.config import OllamaServerInfos
@@ -635,7 +652,6 @@ def create_app(args):
                 "max_graph_nodes": args.max_graph_nodes,
                 # Rerank configuration
                 "enable_rerank": args.enable_rerank,
-                "rerank_configured": rerank_model_func is not None,
                 "rerank_binding": args.rerank_binding
                 if args.enable_rerank
                 else None,
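Review note: the `inspect.signature` fallback above is the core of this change. When `RERANK_MODEL` / `RERANK_BINDING_HOST` are left unset, the server now inherits the selected binding function's own keyword defaults instead of hard-coding a model name. A minimal, self-contained sketch of the pattern — `demo_rerank` and the `args` namespace are illustrative stand-ins, not LightRAG APIs:

```python
# Sketch: harvest a callable's own keyword defaults when config is None.
import inspect
from types import SimpleNamespace


async def demo_rerank(query, documents, model="rerank-v3.5",
                      base_url="https://api.cohere.com/v2/rerank"):
    """Stand-in for a rerank binding such as cohere_rerank."""


args = SimpleNamespace(rerank_model=None, rerank_binding_host=None)
sig = inspect.signature(demo_rerank)
for attr, param in (("rerank_model", "model"), ("rerank_binding_host", "base_url")):
    if getattr(args, attr) is None and param in sig.parameters:
        default = sig.parameters[param].default
        if default is not inspect.Parameter.empty:
            setattr(args, attr, default)

print(args.rerank_model, args.rerank_binding_host)
# -> rerank-v3.5 https://api.cohere.com/v2/rerank
```

This keeps each binding's defaults in exactly one place (the binding's signature), so adding a new provider needs no extra server-side configuration code.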
diff --git a/lightrag/base.py b/lightrag/base.py
index 9ba34280..cfe48eea 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -22,7 +22,6 @@ from .constants import (
     DEFAULT_MAX_RELATION_TOKENS,
     DEFAULT_MAX_TOTAL_TOKENS,
     DEFAULT_HISTORY_TURNS,
-    DEFAULT_ENABLE_RERANK,
     DEFAULT_OLLAMA_MODEL_NAME,
     DEFAULT_OLLAMA_MODEL_TAG,
     DEFAULT_OLLAMA_MODEL_SIZE,
@@ -158,9 +157,7 @@ class QueryParam:
     If provided, this will be used instead of the default value from the prompt template.
     """
 
-    enable_rerank: bool = (
-        os.getenv("ENABLE_RERANK", str(DEFAULT_ENABLE_RERANK).lower()).lower() == "true"
-    )
+    enable_rerank: bool = os.getenv("ENABLE_RERANK", "false").lower() == "true"
     """Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
     Default is False, so reranking runs only when explicitly enabled and a rerank model is available.
     """
diff --git a/lightrag/constants.py b/lightrag/constants.py
index d9ab121d..6aa845af 100644
--- a/lightrag/constants.py
+++ b/lightrag/constants.py
@@ -32,8 +32,8 @@ DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
 DEFAULT_HISTORY_TURNS = 0
 
 # Rerank configuration defaults
-DEFAULT_ENABLE_RERANK = True
 DEFAULT_MIN_RERANK_SCORE = 0.0
+DEFAULT_RERANK_BINDING = "cohere"
 
 # File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
 DEFAULT_MAX_FILE_PATH_LENGTH = 32768
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 8b214c16..721181d5 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -525,14 +525,6 @@ class LightRAG:
             )
         )
 
-        # Init Rerank
-        if self.rerank_model_func:
-            logger.info("Rerank model initialized for improved retrieval quality")
-        else:
-            logger.warning(
-                "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
-            )
-
         self._storages_status = StoragesStatus.CREATED
 
     async def initialize_storages(self):
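Worth noting for reviewers: with `DEFAULT_ENABLE_RERANK` removed, reranking is opt-in everywhere, and `QueryParam.enable_rerank` turns on only for the literal string `"true"` (case-insensitive). A quick sketch of the resulting semantics (`enable_rerank_default` is an illustrative name, not LightRAG code):

```python
import os


def enable_rerank_default() -> bool:
    # Mirrors the new QueryParam default: opt-in, plain string comparison.
    return os.getenv("ENABLE_RERANK", "false").lower() == "true"


os.environ.pop("ENABLE_RERANK", None)
assert enable_rerank_default() is False   # unset -> disabled (new default)
os.environ["ENABLE_RERANK"] = "TRUE"
assert enable_rerank_default() is True    # explicit opt-in
os.environ["ENABLE_RERANK"] = "1"
assert enable_rerank_default() is False   # "1"/"yes" do NOT enable it
```

The API server path parses the same variable via `get_env_value("ENABLE_RERANK", False, bool)`, which may accept other truthy spellings; it would be worth confirming the two code paths interpret `ENABLE_RERANK` consistently.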
diff --git a/lightrag/llm.py b/lightrag/llm.py
deleted file mode 100644
index e5f98cf8..00000000
--- a/lightrag/llm.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from __future__ import annotations
-
-from typing import Callable, Any
-from pydantic import BaseModel, Field
-
-
-class Model(BaseModel):
-    """
-    This is a Pydantic model class named 'Model' that is used to define a custom language model.
-
-    Attributes:
-        gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
-            The function should take any argument and return a string.
-        kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
-            This could include parameters such as the model name, API key, etc.
-
-    Example usage:
-        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
-
-    In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
-    The 'kwargs' dictionary contains the model name and API key to be passed to the function.
-    """
-
-    gen_func: Callable[[Any], str] = Field(
-        ...,
-        description="A function that generates the response from the llm. The response must be a string",
-    )
-    kwargs: dict[str, Any] = Field(
-        ...,
-        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
-    )
-
-    class Config:
-        arbitrary_types_allowed = True
-
-
-class MultiModel:
-    """
-    Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
-    Could also be used for spliting across diffrent models or providers.
-
-    Attributes:
-        models (List[Model]): A list of language models to be used.
-
-    Usage example:
-        ```python
-        models = [
-            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
-            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
-            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
-            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
-            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
-        ]
-        multi_model = MultiModel(models)
-        rag = LightRAG(
-            llm_model_func=multi_model.llm_model_func
-            / ..other args
-        )
-        ```
-    """
-
-    def __init__(self, models: list[Model]):
-        self._models = models
-        self._current_model = 0
-
-    def _next_model(self):
-        self._current_model = (self._current_model + 1) % len(self._models)
-        return self._models[self._current_model]
-
-    async def llm_model_func(
-        self,
-        prompt: str,
-        system_prompt: str | None = None,
-        history_messages: list[dict[str, Any]] = [],
-        **kwargs: Any,
-    ) -> str:
-        kwargs.pop("model", None)  # stop from overwriting the custom model name
-        kwargs.pop("keyword_extraction", None)
-        kwargs.pop("mode", None)
-        next_model = self._next_model()
-        args = dict(
-            prompt=prompt,
-            system_prompt=system_prompt,
-            history_messages=history_messages,
-            **kwargs,
-            **next_model.kwargs,
-        )
-
-        return await next_model.gen_func(**args)
-
-
-if __name__ == "__main__":
-    import asyncio
-
-    async def main():
-        from lightrag.llm.openai import gpt_4o_mini_complete
-
-        result = await gpt_4o_mini_complete("How are you?")
-        print(result)
-
-    asyncio.run(main())
diff --git a/lightrag/rerank.py b/lightrag/rerank.py
index dbac1098..35551f5a 100644
--- a/lightrag/rerank.py
+++ b/lightrag/rerank.py
@@ -32,7 +32,7 @@ async def generic_rerank_api(
     documents: List[str],
     model: str,
     base_url: str,
-    api_key: str,
+    api_key: Optional[str],
     top_n: Optional[int] = None,
     return_documents: Optional[bool] = None,
     extra_body: Optional[Dict[str, Any]] = None,
@@ -56,13 +56,12 @@ async def generic_rerank_api(
     Returns:
         List of dictionaries of ["index": int, "relevance_score": float]
     """
-    if not api_key:
-        raise ValueError("API key is required")
+    if not base_url:
+        raise ValueError("Base URL is required")
 
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {api_key}",
-    }
+    headers = {"Content-Type": "application/json"}
+    if api_key is not None:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     # Build request payload based on request format
     if request_format == "aliyun":
@@ -119,7 +118,6 @@
                     error_text.strip().startswith("<html>")
                     or "text/html" in content_type
                 )
-
                 if is_html_error:
                     if response.status == 502:
                         clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
@@ -131,7 +129,6 @@
                         clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
                     else:
                         clean_error = error_text
-
                     logger.error(f"Rerank API error {response.status}: {clean_error}")
                     raise aiohttp.ClientResponseError(
                         request_info=response.request_info,
@@ -142,17 +139,25 @@
 
         response_json = await response.json()
 
-        # Handle different response formats
         if response_format == "aliyun":
             # Aliyun format: {"output": {"results": [...]}}
-            output = response_json.get("output", {})
-            results = output.get("results", [])
+            results = response_json.get("output", {}).get("results", [])
+            if not isinstance(results, list):
+                logger.warning(
+                    f"Expected 'output.results' to be list, got {type(results)}: {results}"
+                )
+                results = []
+
         elif response_format == "standard":
             # Standard format: {"results": [...]}
             results = response_json.get("results", [])
+            if not isinstance(results, list):
+                logger.warning(
+                    f"Expected 'results' to be list, got {type(results)}: {results}"
+                )
+                results = []
         else:
             raise ValueError(f"Unsupported response format: {response_format}")
-
         if not results:
             logger.warning("Rerank API returned empty results")
             return []
@@ -170,7 +175,7 @@ async def cohere_rerank(
     top_n: Optional[int] = None,
     api_key: Optional[str] = None,
     model: str = "rerank-v3.5",
-    base_url: str = "https://ai.znipower.com:5017/rerank",
+    base_url: str = "https://api.cohere.com/v2/rerank",
     extra_body: Optional[Dict[str, Any]] = None,
 ) -> List[Dict[str, Any]]:
     """
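The header change above is what lets self-hosted rerankers without authentication work: `Authorization` is attached only when an API key is actually supplied, while `base_url` becomes the hard requirement. A minimal reproduction of the new policy (`build_rerank_headers` is an illustrative helper, not part of the module):

```python
from typing import Dict, Optional


def build_rerank_headers(api_key: Optional[str]) -> Dict[str, str]:
    # Mirrors generic_rerank_api: bearer auth only when a key is supplied.
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers


assert "Authorization" not in build_rerank_headers(None)
assert build_rerank_headers("sk-test")["Authorization"] == "Bearer sk-test"
```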
"application/json"} + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" # Build request payload based on request format if request_format == "aliyun": @@ -119,7 +118,6 @@ async def generic_rerank_api( error_text.strip().startswith("") or "text/html" in content_type ) - if is_html_error: if response.status == 502: clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes." @@ -131,7 +129,6 @@ async def generic_rerank_api( clean_error = f"HTTP {response.status} - Rerank service error. Please try again later." else: clean_error = error_text - logger.error(f"Rerank API error {response.status}: {clean_error}") raise aiohttp.ClientResponseError( request_info=response.request_info, @@ -142,17 +139,25 @@ async def generic_rerank_api( response_json = await response.json() - # Handle different response formats if response_format == "aliyun": # Aliyun format: {"output": {"results": [...]}} - output = response_json.get("output", {}) - results = output.get("results", []) + results = response_json.get("output", {}).get("results", []) + if not isinstance(results, list): + logger.warning( + f"Expected 'output.results' to be list, got {type(results)}: {results}" + ) + results = [] + elif response_format == "standard": # Standard format: {"results": [...]} results = response_json.get("results", []) + if not isinstance(results, list): + logger.warning( + f"Expected 'results' to be list, got {type(results)}: {results}" + ) + results = [] else: raise ValueError(f"Unsupported response format: {response_format}") - if not results: logger.warning("Rerank API returned empty results") return [] @@ -170,7 +175,7 @@ async def cohere_rerank( top_n: Optional[int] = None, api_key: Optional[str] = None, model: str = "rerank-v3.5", - base_url: str = "https://ai.znipower.com:5017/rerank", + base_url: str = "https://api.cohere.com/v2/rerank", extra_body: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ diff --git a/lightrag/utils.py b/lightrag/utils.py index 65e22d1a..2d3d485b 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1995,7 +1995,7 @@ async def apply_rerank_if_enabled( rerank_results = await rerank_func( query=query, documents=document_texts, - top_n=top_n or len(retrieved_docs), + top_n=top_n, ) # Process rerank results based on return format