Fix: rename rerank parameter from top_k to top_n

This change aligns the rerank parameter name with the `top_n` request field used by the Jina and Cohere rerank APIs, so the same name is used consistently from the function signatures through to the API payload.
This commit is contained in:
yangdx 2025-07-20 00:26:27 +08:00
parent 4d8eda5ce3
commit cb3bf3291c
4 changed files with 31 additions and 31 deletions

View file

@ -57,7 +57,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
) )
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
"""Custom rerank function with all settings included""" """Custom rerank function with all settings included"""
return await custom_rerank( return await custom_rerank(
query=query, query=query,
@ -65,7 +65,7 @@ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwarg
model="BAAI/bge-reranker-v2-m3", model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-rerank-provider.com/v1/rerank", base_url="https://api.your-rerank-provider.com/v1/rerank",
api_key="your_rerank_api_key_here", api_key="your_rerank_api_key_here",
top_k=top_k or 10, # Default top_k if not provided top_n=top_n or 10,
**kwargs, **kwargs,
) )
@ -217,7 +217,7 @@ async def test_direct_rerank():
model="BAAI/bge-reranker-v2-m3", model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-rerank-provider.com/v1/rerank", base_url="https://api.your-rerank-provider.com/v1/rerank",
api_key="your_rerank_api_key_here", api_key="your_rerank_api_key_here",
top_k=3, top_n=3,
) )
print("\n✅ Rerank Results:") print("\n✅ Rerank Results:")

View file

@ -298,7 +298,7 @@ def create_app(args):
from lightrag.rerank import custom_rerank from lightrag.rerank import custom_rerank
async def server_rerank_func( async def server_rerank_func(
query: str, documents: list, top_k: int = None, **kwargs query: str, documents: list, top_n: int = None, **kwargs
): ):
"""Server rerank function with configuration from environment variables""" """Server rerank function with configuration from environment variables"""
return await custom_rerank( return await custom_rerank(
@ -307,7 +307,7 @@ def create_app(args):
model=args.rerank_model, model=args.rerank_model,
base_url=args.rerank_binding_host, base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key, api_key=args.rerank_binding_api_key,
top_k=top_k, top_n=top_n,
**kwargs, **kwargs,
) )

View file

@ -3165,7 +3165,7 @@ async def apply_rerank_if_enabled(
retrieved_docs: list[dict], retrieved_docs: list[dict],
global_config: dict, global_config: dict,
enable_rerank: bool = True, enable_rerank: bool = True,
top_k: int = None, top_n: int = None,
) -> list[dict]: ) -> list[dict]:
""" """
Apply reranking to retrieved documents if rerank is enabled. Apply reranking to retrieved documents if rerank is enabled.
@ -3175,7 +3175,7 @@ async def apply_rerank_if_enabled(
retrieved_docs: List of retrieved documents retrieved_docs: List of retrieved documents
global_config: Global configuration containing rerank settings global_config: Global configuration containing rerank settings
enable_rerank: Whether to enable reranking from query parameter enable_rerank: Whether to enable reranking from query parameter
top_k: Number of top documents to return after reranking top_n: Number of top documents to return after reranking
Returns: Returns:
Reranked documents if rerank is enabled, otherwise original documents Reranked documents if rerank is enabled, otherwise original documents
@ -3192,18 +3192,18 @@ async def apply_rerank_if_enabled(
try: try:
logger.debug( logger.debug(
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}" f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_n}"
) )
# Apply reranking - let rerank_model_func handle top_k internally # Apply reranking - let rerank_model_func handle top_n internally
reranked_docs = await rerank_func( reranked_docs = await rerank_func(
query=query, query=query,
documents=retrieved_docs, documents=retrieved_docs,
top_k=top_k, top_n=top_n,
) )
if reranked_docs and len(reranked_docs) > 0: if reranked_docs and len(reranked_docs) > 0:
if len(reranked_docs) > top_k: if len(reranked_docs) > top_n:
reranked_docs = reranked_docs[:top_k] reranked_docs = reranked_docs[:top_n]
logger.info( logger.info(
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}" f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
) )
@ -3263,7 +3263,7 @@ async def process_chunks_unified(
retrieved_docs=unique_chunks, retrieved_docs=unique_chunks,
global_config=global_config, global_config=global_config,
enable_rerank=query_param.enable_rerank, enable_rerank=query_param.enable_rerank,
top_k=rerank_top_k, top_n=rerank_top_k,
) )
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")

View file

@ -41,13 +41,13 @@ class RerankModel(BaseModel):
Or define a custom function directly: Or define a custom function directly:
```python ```python
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
return await jina_rerank( return await jina_rerank(
query=query, query=query,
documents=documents, documents=documents,
model="BAAI/bge-reranker-v2-m3", model="BAAI/bge-reranker-v2-m3",
api_key="your_api_key_here", api_key="your_api_key_here",
top_k=top_k or 10, top_n=top_n or 10,
**kwargs **kwargs
) )
@ -71,14 +71,14 @@ class RerankModel(BaseModel):
self, self,
query: str, query: str,
documents: List[Dict[str, Any]], documents: List[Dict[str, Any]],
top_k: Optional[int] = None, top_n: Optional[int] = None,
**extra_kwargs, **extra_kwargs,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Rerank documents using the configured model function.""" """Rerank documents using the configured model function."""
# Merge extra kwargs with model kwargs # Merge extra kwargs with model kwargs
kwargs = {**self.kwargs, **extra_kwargs} kwargs = {**self.kwargs, **extra_kwargs}
return await self.rerank_func( return await self.rerank_func(
query=query, documents=documents, top_k=top_k, **kwargs query=query, documents=documents, top_n=top_n, **kwargs
) )
@ -98,7 +98,7 @@ class MultiRerankModel(BaseModel):
query: str, query: str,
documents: List[Dict[str, Any]], documents: List[Dict[str, Any]],
mode: str = "default", mode: str = "default",
top_k: Optional[int] = None, top_n: Optional[int] = None,
**kwargs, **kwargs,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Rerank using the appropriate model based on mode.""" """Rerank using the appropriate model based on mode."""
@ -116,7 +116,7 @@ class MultiRerankModel(BaseModel):
logger.warning(f"No rerank model available for mode: {mode}") logger.warning(f"No rerank model available for mode: {mode}")
return documents return documents
return await model.rerank(query, documents, top_k, **kwargs) return await model.rerank(query, documents, top_n, **kwargs)
async def generic_rerank_api( async def generic_rerank_api(
@ -125,7 +125,7 @@ async def generic_rerank_api(
model: str, model: str,
base_url: str, base_url: str,
api_key: str, api_key: str,
top_k: Optional[int] = None, top_n: Optional[int] = None,
**kwargs, **kwargs,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@ -137,7 +137,7 @@ async def generic_rerank_api(
model: Model identifier model: Model identifier
base_url: API endpoint URL base_url: API endpoint URL
api_key: API authentication key api_key: API authentication key
top_k: Number of top results to return top_n: Number of top results to return
**kwargs: Additional API-specific parameters **kwargs: Additional API-specific parameters
Returns: Returns:
@ -165,8 +165,8 @@ async def generic_rerank_api(
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs} data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
if top_k is not None: if top_n is not None:
data["top_k"] = min(top_k, len(prepared_docs)) data["top_n"] = min(top_n, len(prepared_docs))
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@ -206,7 +206,7 @@ async def jina_rerank(
query: str, query: str,
documents: List[Dict[str, Any]], documents: List[Dict[str, Any]],
model: str = "BAAI/bge-reranker-v2-m3", model: str = "BAAI/bge-reranker-v2-m3",
top_k: Optional[int] = None, top_n: Optional[int] = None,
base_url: str = "https://api.jina.ai/v1/rerank", base_url: str = "https://api.jina.ai/v1/rerank",
api_key: Optional[str] = None, api_key: Optional[str] = None,
**kwargs, **kwargs,
@ -218,7 +218,7 @@ async def jina_rerank(
query: The search query query: The search query
documents: List of documents to rerank documents: List of documents to rerank
model: Jina rerank model name model: Jina rerank model name
top_k: Number of top results to return top_n: Number of top results to return
base_url: Jina API endpoint base_url: Jina API endpoint
api_key: Jina API key api_key: Jina API key
**kwargs: Additional parameters **kwargs: Additional parameters
@ -235,7 +235,7 @@ async def jina_rerank(
model=model, model=model,
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
top_k=top_k, top_n=top_n,
**kwargs, **kwargs,
) )
@ -244,7 +244,7 @@ async def cohere_rerank(
query: str, query: str,
documents: List[Dict[str, Any]], documents: List[Dict[str, Any]],
model: str = "rerank-english-v2.0", model: str = "rerank-english-v2.0",
top_k: Optional[int] = None, top_n: Optional[int] = None,
base_url: str = "https://api.cohere.ai/v1/rerank", base_url: str = "https://api.cohere.ai/v1/rerank",
api_key: Optional[str] = None, api_key: Optional[str] = None,
**kwargs, **kwargs,
@ -256,7 +256,7 @@ async def cohere_rerank(
query: The search query query: The search query
documents: List of documents to rerank documents: List of documents to rerank
model: Cohere rerank model name model: Cohere rerank model name
top_k: Number of top results to return top_n: Number of top results to return
base_url: Cohere API endpoint base_url: Cohere API endpoint
api_key: Cohere API key api_key: Cohere API key
**kwargs: Additional parameters **kwargs: Additional parameters
@ -273,7 +273,7 @@ async def cohere_rerank(
model=model, model=model,
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
top_k=top_k, top_n=top_n,
**kwargs, **kwargs,
) )
@ -285,7 +285,7 @@ async def custom_rerank(
model: str, model: str,
base_url: str, base_url: str,
api_key: str, api_key: str,
top_k: Optional[int] = None, top_n: Optional[int] = None,
**kwargs, **kwargs,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@ -298,7 +298,7 @@ async def custom_rerank(
model=model, model=model,
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
top_k=top_k, top_n=top_n,
**kwargs, **kwargs,
) )
@ -317,7 +317,7 @@ if __name__ == "__main__":
query = "What is the capital of France?" query = "What is the capital of France?"
result = await jina_rerank( result = await jina_rerank(
query=query, documents=docs, top_k=2, api_key="your-api-key-here" query=query, documents=docs, top_n=2, api_key="your-api-key-here"
) )
print(result) print(result)