Add rerank to server
This commit is contained in:
parent
cf26e52d89
commit
d4651d59c1
4 changed files with 82 additions and 3 deletions
14
env.example
14
env.example
|
|
@ -46,8 +46,19 @@ OLLAMA_EMULATING_MODEL_TAG=latest
|
|||
# HISTORY_TURNS=3
|
||||
# COSINE_THRESHOLD=0.2
|
||||
# TOP_K=60
|
||||
### Number of text chunks to retrieve initially from vector search
|
||||
# CHUNK_TOP_K=5
|
||||
|
||||
### Rerank Configuration
|
||||
### Enable rerank functionality to improve retrieval quality
|
||||
# ENABLE_RERANK=False
|
||||
### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K)
|
||||
# CHUNK_RERANK_TOP_K=5
|
||||
### Rerank model configuration (required when ENABLE_RERANK=True)
|
||||
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
||||
# RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
|
||||
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||
|
||||
# MAX_TOKEN_TEXT_CHUNK=6000
|
||||
# MAX_TOKEN_RELATION_DESC=4000
|
||||
# MAX_TOKEN_ENTITY_DESC=4000
|
||||
|
|
@ -181,6 +192,3 @@ QDRANT_URL=http://localhost:6333
|
|||
### Redis
|
||||
REDIS_URI=redis://localhost:6379
|
||||
# REDIS_WORKSPACE=forced_workspace_name
|
||||
|
||||
# Rerank Configuration
|
||||
ENABLE_RERANK=False
|
||||
|
|
|
|||
|
|
@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace:
|
|||
default=get_env_value("TOP_K", 60, int),
|
||||
help="Number of most similar results to return (default: from env or 60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-top-k",
|
||||
type=int,
|
||||
default=get_env_value("CHUNK_TOP_K", 5, int),
|
||||
help="Number of text chunks to retrieve initially from vector search (default: from env or 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-rerank-top-k",
|
||||
type=int,
|
||||
default=get_env_value("CHUNK_RERANK_TOP_K", 5, int),
|
||||
help="Number of text chunks to keep after reranking (default: from env or 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-rerank",
|
||||
action="store_true",
|
||||
default=get_env_value("ENABLE_RERANK", False, bool),
|
||||
help="Enable rerank functionality (default: from env or False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cosine-threshold",
|
||||
type=float,
|
||||
|
|
@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace:
|
|||
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
|
||||
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
|
||||
|
||||
# Rerank model configuration
|
||||
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
||||
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
|
||||
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
|
||||
|
||||
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
||||
|
||||
return args
|
||||
|
|
|
|||
|
|
@ -291,6 +291,32 @@ def create_app(args):
|
|||
),
|
||||
)
|
||||
|
||||
# Configure rerank function if enabled
|
||||
rerank_model_func = None
|
||||
if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host:
|
||||
from lightrag.rerank import custom_rerank
|
||||
|
||||
async def server_rerank_func(
|
||||
query: str, documents: list, top_k: int = None, **kwargs
|
||||
):
|
||||
"""Server rerank function with configuration from environment variables"""
|
||||
return await custom_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
model=args.rerank_model,
|
||||
base_url=args.rerank_binding_host,
|
||||
api_key=args.rerank_binding_api_key,
|
||||
top_k=top_k,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
rerank_model_func = server_rerank_func
|
||||
logger.info(f"Rerank enabled with model: {args.rerank_model}")
|
||||
elif args.enable_rerank:
|
||||
logger.warning(
|
||||
"Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled."
|
||||
)
|
||||
|
||||
# Initialize RAG
|
||||
if args.llm_binding in ["lollms", "ollama", "openai"]:
|
||||
rag = LightRAG(
|
||||
|
|
@ -324,6 +350,8 @@ def create_app(args):
|
|||
},
|
||||
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
|
||||
enable_llm_cache=args.enable_llm_cache,
|
||||
enable_rerank=args.enable_rerank,
|
||||
rerank_model_func=rerank_model_func,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
|
|
@ -352,6 +380,8 @@ def create_app(args):
|
|||
},
|
||||
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
|
||||
enable_llm_cache=args.enable_llm_cache,
|
||||
enable_rerank=args.enable_rerank,
|
||||
rerank_model_func=rerank_model_func,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
|
|
@ -478,6 +508,12 @@ def create_app(args):
|
|||
"enable_llm_cache": args.enable_llm_cache,
|
||||
"workspace": args.workspace,
|
||||
"max_graph_nodes": args.max_graph_nodes,
|
||||
# Rerank configuration
|
||||
"enable_rerank": args.enable_rerank,
|
||||
"rerank_model": args.rerank_model if args.enable_rerank else None,
|
||||
"rerank_binding_host": args.rerank_binding_host
|
||||
if args.enable_rerank
|
||||
else None,
|
||||
},
|
||||
"auth_mode": auth_mode,
|
||||
"pipeline_busy": pipeline_status.get("busy", False),
|
||||
|
|
|
|||
|
|
@ -49,6 +49,18 @@ class QueryRequest(BaseModel):
|
|||
description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
|
||||
)
|
||||
|
||||
chunk_top_k: Optional[int] = Field(
|
||||
ge=1,
|
||||
default=None,
|
||||
description="Number of text chunks to retrieve initially from vector search.",
|
||||
)
|
||||
|
||||
chunk_rerank_top_k: Optional[int] = Field(
|
||||
ge=1,
|
||||
default=None,
|
||||
description="Number of text chunks to keep after reranking.",
|
||||
)
|
||||
|
||||
max_token_for_text_unit: Optional[int] = Field(
|
||||
gt=1,
|
||||
default=None,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue