From d4651d59c13d3ff75f203f145219ef9331c89338 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Tue, 8 Jul 2025 21:44:20 +0800
Subject: [PATCH] Add rerank to server

---
 env.example                          | 14 ++++++++---
 lightrag/api/config.py               | 23 ++++++++++++++++++
 lightrag/api/lightrag_server.py      | 36 ++++++++++++++++++++++++++++
 lightrag/api/routers/query_routes.py | 12 ++++++++++
 4 files changed, 82 insertions(+), 3 deletions(-)

diff --git a/env.example b/env.example
index 4447f5f0..874ebf3c 100644
--- a/env.example
+++ b/env.example
@@ -46,8 +46,19 @@ OLLAMA_EMULATING_MODEL_TAG=latest
 # HISTORY_TURNS=3
 # COSINE_THRESHOLD=0.2
 # TOP_K=60
+### Number of text chunks to retrieve initially from vector search
 # CHUNK_TOP_K=5
+
+### Rerank Configuration
+### Enable rerank functionality to improve retrieval quality
+# ENABLE_RERANK=False
+### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K)
 # CHUNK_RERANK_TOP_K=5
+### Rerank model configuration (required when ENABLE_RERANK=True)
+# RERANK_MODEL=BAAI/bge-reranker-v2-m3
+# RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
+# RERANK_BINDING_API_KEY=your_rerank_api_key_here
+
 # MAX_TOKEN_TEXT_CHUNK=6000
 # MAX_TOKEN_RELATION_DESC=4000
 # MAX_TOKEN_ENTITY_DESC=4000
@@ -181,6 +192,3 @@ QDRANT_URL=http://localhost:6333
 ### Redis
 REDIS_URI=redis://localhost:6379
 # REDIS_WORKSPACE=forced_workspace_name
-
-# Rerank Configuration
-ENABLE_RERANK=False
diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index ad0e670b..8c3fbff4 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace:
         default=get_env_value("TOP_K", 60, int),
         help="Number of most similar results to return (default: from env or 60)",
     )
+    parser.add_argument(
+        "--chunk-top-k",
+        type=int,
+        default=get_env_value("CHUNK_TOP_K", 5, int),
+        help="Number of text chunks to retrieve initially from vector search (default: from env or 5)",
+    )
+    parser.add_argument(
+        "--chunk-rerank-top-k",
+        type=int,
+        default=get_env_value("CHUNK_RERANK_TOP_K", 5, int),
+        help="Number of text chunks to keep after reranking (default: from env or 5)",
+    )
+    parser.add_argument(
+        "--enable-rerank",
+        action="store_true",
+        default=get_env_value("ENABLE_RERANK", False, bool),
+        help="Enable rerank functionality (default: from env or False)",
+    )
     parser.add_argument(
         "--cosine-threshold",
         type=float,
@@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace:
     args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
     args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
 
+    # Rerank model configuration
+    args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
+    args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
+    args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
+
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
 
     return args
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index cd87af22..b43c66d9 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -291,6 +291,32 @@ def create_app(args):
         ),
     )
 
+    # Configure rerank function if enabled
+    rerank_model_func = None
+    if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host:
+        from lightrag.rerank import custom_rerank
+
+        async def server_rerank_func(
+            query: str, documents: list, top_k: int = None, **kwargs
+        ):
+            """Server rerank function with configuration from environment variables"""
+            return await custom_rerank(
+                query=query,
+                documents=documents,
+                model=args.rerank_model,
+                base_url=args.rerank_binding_host,
+                api_key=args.rerank_binding_api_key,
+                top_k=top_k,
+                **kwargs,
+            )
+
+        rerank_model_func = server_rerank_func
+        logger.info(f"Rerank enabled with model: {args.rerank_model}")
+    elif args.enable_rerank:
+        logger.warning(
+            "Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled."
+        )
+
     # Initialize RAG
     if args.llm_binding in ["lollms", "ollama", "openai"]:
         rag = LightRAG(
@@ -324,6 +350,8 @@ def create_app(args):
             },
             enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             enable_llm_cache=args.enable_llm_cache,
+            enable_rerank=args.enable_rerank,
+            rerank_model_func=rerank_model_func,
             auto_manage_storages_states=False,
             max_parallel_insert=args.max_parallel_insert,
             max_graph_nodes=args.max_graph_nodes,
@@ -352,6 +380,8 @@ def create_app(args):
             },
             enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             enable_llm_cache=args.enable_llm_cache,
+            enable_rerank=args.enable_rerank,
+            rerank_model_func=rerank_model_func,
             auto_manage_storages_states=False,
             max_parallel_insert=args.max_parallel_insert,
             max_graph_nodes=args.max_graph_nodes,
@@ -478,6 +508,12 @@ def create_app(args):
                 "enable_llm_cache": args.enable_llm_cache,
                 "workspace": args.workspace,
                 "max_graph_nodes": args.max_graph_nodes,
+                # Rerank configuration
+                "enable_rerank": args.enable_rerank,
+                "rerank_model": args.rerank_model if args.enable_rerank else None,
+                "rerank_binding_host": args.rerank_binding_host
+                if args.enable_rerank
+                else None,
             },
             "auth_mode": auth_mode,
             "pipeline_busy": pipeline_status.get("busy", False),
diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py
index 69aa32d8..0a0c6227 100644
--- a/lightrag/api/routers/query_routes.py
+++ b/lightrag/api/routers/query_routes.py
@@ -49,6 +49,18 @@ class QueryRequest(BaseModel):
         description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
     )
 
+    chunk_top_k: Optional[int] = Field(
+        ge=1,
+        default=None,
+        description="Number of text chunks to retrieve initially from vector search.",
+    )
+
+    chunk_rerank_top_k: Optional[int] = Field(
+        ge=1,
+        default=None,
+        description="Number of text chunks to keep after reranking.",
+    )
+
     max_token_for_text_unit: Optional[int] = Field(
         gt=1,
         default=None,
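
Usage note (not part of the patch): a minimal sketch of how the new options
could be exercised once this change lands. The server startup line mirrors
the flags added in config.py; the query text, port, and rerank provider
values below are placeholders, while /query is the existing query route.

    # Start the server with rerank enabled (env vars as in env.example):
    #   export RERANK_MODEL=BAAI/bge-reranker-v2-m3
    #   export RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
    #   export RERANK_BINDING_API_KEY=your_rerank_api_key_here
    #   lightrag-server --enable-rerank --chunk-top-k 10 --chunk-rerank-top-k 5
    import requests

    payload = {
        "query": "What is LightRAG?",  # placeholder query
        "mode": "hybrid",
        "chunk_top_k": 10,        # chunks fetched from vector search
        "chunk_rerank_top_k": 5,  # chunks kept after reranking (<= chunk_top_k)
    }
    # 9621 is assumed to be the server's default port
    resp = requests.post("http://localhost:9621/query", json=payload)
    resp.raise_for_status()
    print(resp.json()["response"])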