Add rerank to server

This commit is contained in:
zrguo 2025-07-08 21:44:20 +08:00
parent cf26e52d89
commit d4651d59c1
4 changed files with 82 additions and 3 deletions

View file

@ -46,8 +46,19 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# HISTORY_TURNS=3
# COSINE_THRESHOLD=0.2
# TOP_K=60
### Number of text chunks to retrieve initially from vector search
# CHUNK_TOP_K=5
### Rerank Configuration
### Enable rerank functionality to improve retrieval quality
# ENABLE_RERANK=False
### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K)
# CHUNK_RERANK_TOP_K=5
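### Rerank model configuration (RERANK_BINDING_HOST and RERANK_BINDING_API_KEY are required when ENABLE_RERANK=True; RERANK_MODEL defaults to BAAI/bge-reranker-v2-m3)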
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
# RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
# MAX_TOKEN_TEXT_CHUNK=6000
# MAX_TOKEN_RELATION_DESC=4000
# MAX_TOKEN_ENTITY_DESC=4000
@ -181,6 +192,3 @@ QDRANT_URL=http://localhost:6333
### Redis
REDIS_URI=redis://localhost:6379
# REDIS_WORKSPACE=forced_workspace_name
# Rerank Configuration
ENABLE_RERANK=False

View file

@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace:
default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--chunk-top-k",
type=int,
default=get_env_value("CHUNK_TOP_K", 5, int),
help="Number of text chunks to retrieve initially from vector search (default: from env or 5)",
)
parser.add_argument(
"--chunk-rerank-top-k",
type=int,
default=get_env_value("CHUNK_RERANK_TOP_K", 5, int),
help="Number of text chunks to keep after reranking (default: from env or 5)",
)
parser.add_argument(
"--enable-rerank",
action="store_true",
default=get_env_value("ENABLE_RERANK", False, bool),
help="Enable rerank functionality (default: from env or False)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace:
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
# Rerank model configuration
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
return args

View file

@ -291,6 +291,32 @@ def create_app(args):
),
)
# Configure the optional rerank function. It is only built when rerank is
# enabled AND both the rerank endpoint and API key are configured.
rerank_model_func = None
if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host:
    from typing import Optional

    from lightrag.rerank import custom_rerank

    async def server_rerank_func(
        query: str, documents: list, top_k: Optional[int] = None, **kwargs
    ):
        """Call ``custom_rerank`` with the server's rerank settings.

        The model name, endpoint and API key are taken from the parsed
        server arguments; ``query``, ``documents``, ``top_k`` and any extra
        keyword arguments are forwarded unchanged.

        Returns:
            Whatever ``custom_rerank`` returns for the given inputs.
        """
        return await custom_rerank(
            query=query,
            documents=documents,
            model=args.rerank_model,
            base_url=args.rerank_binding_host,
            api_key=args.rerank_binding_api_key,
            top_k=top_k,
            **kwargs,
        )

    rerank_model_func = server_rerank_func
    logger.info(f"Rerank enabled with model: {args.rerank_model}")
elif args.enable_rerank:
    logger.warning(
        "Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled."
    )
    # Make the flag agree with the warning: later reads of args.enable_rerank
    # (the LightRAG constructor arguments and the status payload) would
    # otherwise still see rerank as enabled although no rerank function exists.
    args.enable_rerank = False
# Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai"]:
rag = LightRAG(
@ -324,6 +350,8 @@ def create_app(args):
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
enable_rerank=args.enable_rerank,
rerank_model_func=rerank_model_func,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
@ -352,6 +380,8 @@ def create_app(args):
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
enable_rerank=args.enable_rerank,
rerank_model_func=rerank_model_func,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
@ -478,6 +508,12 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": args.enable_rerank,
"rerank_model": args.rerank_model if args.enable_rerank else None,
"rerank_binding_host": args.rerank_binding_host
if args.enable_rerank
else None,
},
"auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False),

View file

@ -49,6 +49,18 @@ class QueryRequest(BaseModel):
description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
)
chunk_top_k: Optional[int] = Field(
ge=1,
default=None,
description="Number of text chunks to retrieve initially from vector search.",
)
chunk_rerank_top_k: Optional[int] = Field(
ge=1,
default=None,
description="Number of text chunks to keep after reranking.",
)
max_token_for_text_unit: Optional[int] = Field(
gt=1,
default=None,