diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md index 647c0f91..f216a8c8 100644 --- a/docs/rerank_integration.md +++ b/docs/rerank_integration.md @@ -2,24 +2,15 @@ This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality. -## ⚠️ Important: Parameter Priority - -**QueryParam.top_k has higher priority than rerank_top_k configuration:** - -- When you set `QueryParam(top_k=5)`, it will override the `rerank_top_k=10` setting in LightRAG configuration -- This means the actual number of documents sent to rerank will be determined by QueryParam.top_k -- For optimal rerank performance, always consider the top_k value in your QueryParam calls -- Example: `rag.aquery(query, param=QueryParam(mode="naive", top_k=20))` will use 20, not rerank_top_k - ## Overview Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix). ## Architecture -The rerank integration follows the same design pattern as the LLM integration: +The rerank integration follows a simplified design pattern: -- **Configurable Models**: Support for multiple rerank providers through a generic API +- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function - **Async Processing**: Non-blocking rerank operations - **Error Handling**: Graceful fallback to original results - **Optional Feature**: Can be enabled/disabled via configuration @@ -29,24 +20,11 @@ The rerank integration follows the same design pattern as the LLM integration: ### Environment Variables -Set these variables in your `.env` file or environment: +Set this variable in your `.env` file or environment: ```bash # Enable/disable reranking ENABLE_RERANK=True - -# Rerank model configuration -RERANK_MODEL=BAAI/bge-reranker-v2-m3 -RERANK_MAX_ASYNC=4 -RERANK_TOP_K=10 - -# API configuration -RERANK_API_KEY=your_rerank_api_key_here -RERANK_BASE_URL=https://api.your-provider.com/v1/rerank - -# Provider-specific keys (optional alternatives) -JINA_API_KEY=your_jina_api_key_here -COHERE_API_KEY=your_cohere_api_key_here ``` ### Programmatic Configuration @@ -55,15 +33,27 @@ COHERE_API_KEY=your_cohere_api_key_here from lightrag import LightRAG from lightrag.rerank import custom_rerank, RerankModel -# Method 1: Using environment variables (recommended) +# Method 1: Using a custom rerank function with all settings included +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=top_k or 10, # Handle top_k within the function + **kwargs + ) + rag = LightRAG( working_dir="./rag_storage", llm_model_func=your_llm_func, embedding_func=your_embedding_func, - # Rerank automatically configured from environment variables + enable_rerank=True, + rerank_model_func=my_rerank_func, ) -# Method 2: Explicit configuration +# Method 2: Using RerankModel wrapper rerank_model = RerankModel( rerank_func=custom_rerank, kwargs={ @@ -79,7 +69,6 @@ rag = LightRAG( embedding_func=your_embedding_func, enable_rerank=True, rerank_model_func=rerank_model.rerank, - rerank_top_k=10, ) ``` @@ -112,7 +101,8 @@ result = await jina_rerank( query="your query", documents=documents, model="BAAI/bge-reranker-v2-m3", - api_key="your_jina_api_key" + api_key="your_jina_api_key", + top_k=10 ) ``` @@ -125,7 +115,8 @@ result = await cohere_rerank( query="your query", documents=documents, model="rerank-english-v2.0", - api_key="your_cohere_api_key" + api_key="your_cohere_api_key", + top_k=10 ) ``` @@ -143,11 +134,7 @@ Reranking is automatically applied at these key retrieval stages: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `enable_rerank` | bool | False | Enable/disable reranking | -| `rerank_model_name` | str | "BAAI/bge-reranker-v2-m3" | Model identifier | -| `rerank_model_max_async` | int | 4 | Max concurrent rerank calls | -| `rerank_top_k` | int | 10 | Number of top results to return ⚠️ **Overridden by QueryParam.top_k** | -| `rerank_model_func` | callable | None | Custom rerank function | -| `rerank_model_kwargs` | dict | {} | Additional rerank parameters | +| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_k, etc.) | ## Example Usage @@ -157,6 +144,18 @@ Reranking is automatically applied at these key retrieval stages: import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding +from lightrag.rerank import jina_rerank + +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + """Custom rerank function with all settings included""" + return await jina_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_jina_api_key_here", + top_k=top_k or 10, # Default top_k if not provided + **kwargs + ) async def main(): # Initialize with rerank enabled @@ -165,20 +164,21 @@ async def main(): llm_model_func=gpt_4o_mini_complete, embedding_func=openai_embedding, enable_rerank=True, + rerank_model_func=my_rerank_func, ) - + # Insert documents await rag.ainsert([ "Document 1 content...", "Document 2 content...", ]) - + # Query with rerank (automatically applied) result = await rag.aquery( "Your question here", - param=QueryParam(mode="hybrid", top_k=5) # ⚠️ This top_k=5 overrides rerank_top_k + param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function ) - + print(result) asyncio.run(main()) @@ -195,7 +195,7 @@ async def test_rerank(): {"content": "Text about topic B"}, {"content": "Text about topic C"}, ] - + reranked = await custom_rerank( query="Tell me about topic A", documents=documents, @@ -204,26 +204,26 @@ async def test_rerank(): api_key="your_api_key_here", top_k=2 ) - + for doc in reranked: print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}") ``` ## Best Practices -1. **Parameter Priority Awareness**: Remember that QueryParam.top_k always overrides rerank_top_k configuration +1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_k handling) within your rerank function 2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff -3. **API Limits**: Monitor API usage and implement rate limiting if needed +3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function 4. **Fallback**: Always handle rerank failures gracefully (returns original results) -5. **Top-k Selection**: Choose appropriate `top_k` values in QueryParam based on your use case +5. **Top-k Handling**: Handle top_k parameter appropriately within your rerank function 6. **Cost Management**: Consider rerank API costs in your budget planning ## Troubleshooting ### Common Issues -1. **API Key Missing**: Ensure `RERANK_API_KEY` or provider-specific keys are set -2. **Network Issues**: Check `RERANK_BASE_URL` and network connectivity +1. **API Key Missing**: Ensure API keys are properly configured within your rerank function +2. **Network Issues**: Check API endpoints and network connectivity 3. **Model Errors**: Verify the rerank model name is supported by your API 4. **Document Format**: Ensure documents have `content` or `text` fields @@ -268,4 +268,4 @@ The generic rerank API expects this response format: This is compatible with: - Jina AI Rerank API - Cohere Rerank API -- Custom APIs following the same format \ No newline at end of file +- Custom APIs following the same format diff --git a/env.example b/env.example index 49546343..c4a09cad 100644 --- a/env.example +++ b/env.example @@ -182,11 +182,3 @@ REDIS_URI=redis://localhost:6379 # Rerank Configuration ENABLE_RERANK=False -RERANK_MODEL=BAAI/bge-reranker-v2-m3 -RERANK_MAX_ASYNC=4 -RERANK_TOP_K=10 -# Note: QueryParam.top_k in your code will override RERANK_TOP_K setting - -# Rerank API Configuration -RERANK_API_KEY=your_rerank_api_key_here -RERANK_BASE_URL=https://api.your-provider.com/v1/rerank diff --git a/examples/rerank_example.py b/examples/rerank_example.py index 30ad794d..74ec85bc 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -4,19 +4,12 @@ LightRAG Rerank Integration Example This example demonstrates how to use rerank functionality with LightRAG to improve retrieval quality across different query modes. -IMPORTANT: Parameter Priority -- QueryParam(top_k=N) has higher priority than rerank_top_k in LightRAG configuration -- If you set QueryParam(top_k=5), it will override rerank_top_k setting -- For optimal rerank performance, use appropriate top_k values in QueryParam - Configuration Required: 1. Set your LLM API key and base URL in llm_model_func() -2. Set your embedding API key and base URL in embedding_func() +2. Set your embedding API key and base URL in embedding_func() 3. Set your rerank API key and base URL in the rerank configuration 4. Or use environment variables (.env file): - - RERANK_API_KEY=your_actual_rerank_api_key - - RERANK_BASE_URL=https://your-actual-rerank-endpoint/v1/rerank - - RERANK_MODEL=your_rerank_model_name + - ENABLE_RERANK=True """ import asyncio @@ -35,6 +28,7 @@ setup_logger("test_rerank") if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -48,6 +42,7 @@ async def llm_model_func( **kwargs, ) + async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embed( texts, @@ -56,25 +51,29 @@ async def embedding_func(texts: list[str]) -> np.ndarray: base_url="https://api.your-embedding-provider.com/v1", ) + +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + """Custom rerank function with all settings included""" + return await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-rerank-provider.com/v1/rerank", + api_key="your_rerank_api_key_here", + top_k=top_k or 10, # Default top_k if not provided + **kwargs, + ) + + async def create_rag_with_rerank(): """Create LightRAG instance with rerank configuration""" - + # Get embedding dimension test_embedding = await embedding_func(["test"]) embedding_dim = test_embedding.shape[1] print(f"Detected embedding dimension: {embedding_dim}") - # Create rerank model - rerank_model = RerankModel( - rerank_func=custom_rerank, - kwargs={ - "model": "BAAI/bge-reranker-v2-m3", - "base_url": "https://api.your-rerank-provider.com/v1/rerank", - "api_key": "your_rerank_api_key_here", - } - ) - - # Initialize LightRAG with rerank + # Method 1: Using custom rerank function rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=llm_model_func, @@ -83,69 +82,100 @@ async def create_rag_with_rerank(): max_token_size=8192, func=embedding_func, ), - # Rerank Configuration + # Simplified Rerank Configuration enable_rerank=True, - rerank_model_func=rerank_model.rerank, - rerank_top_k=10, # Note: QueryParam.top_k will override this + rerank_model_func=my_rerank_func, ) return rag + +async def create_rag_with_rerank_model(): + """Alternative: Create LightRAG instance using RerankModel wrapper""" + + # Get embedding dimension + test_embedding = await embedding_func(["test"]) + embedding_dim = test_embedding.shape[1] + print(f"Detected embedding dimension: {embedding_dim}") + + # Method 2: Using RerankModel wrapper + rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-rerank-provider.com/v1/rerank", + "api_key": "your_rerank_api_key_here", + }, + ) + + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dim, + max_token_size=8192, + func=embedding_func, + ), + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + ) + + return rag + + async def test_rerank_with_different_topk(): """ - Test rerank functionality with different top_k settings to demonstrate parameter priority + Test rerank functionality with different top_k settings """ print("🚀 Setting up LightRAG with Rerank functionality...") - + rag = await create_rag_with_rerank() - + # Insert sample documents sample_docs = [ "Reranking improves retrieval quality by re-ordering documents based on relevance.", "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.", "Vector databases enable efficient similarity search in high-dimensional embedding spaces.", "Natural language processing has evolved with large language models and transformers.", - "Machine learning algorithms can learn patterns from data without explicit programming." + "Machine learning algorithms can learn patterns from data without explicit programming.", ] - + print("📄 Inserting sample documents...") await rag.ainsert(sample_docs) - + query = "How does reranking improve retrieval quality?" print(f"\n🔍 Testing query: '{query}'") print("=" * 80) - + # Test different top_k values to show parameter priority top_k_values = [2, 5, 10] - + for top_k in top_k_values: - print(f"\n📊 Testing with QueryParam(top_k={top_k}) - overrides rerank_top_k=10:") - + print(f"\n📊 Testing with QueryParam(top_k={top_k}):") + # Test naive mode with specific top_k - result = await rag.aquery( - query, - param=QueryParam(mode="naive", top_k=top_k) - ) + result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k)) print(f" Result length: {len(result)} characters") print(f" Preview: {result[:100]}...") + async def test_direct_rerank(): """Test rerank function directly""" print("\n🔧 Direct Rerank API Test") print("=" * 40) - + documents = [ {"content": "Reranking significantly improves retrieval quality"}, {"content": "LightRAG supports advanced reranking capabilities"}, {"content": "Vector search finds semantically similar documents"}, {"content": "Natural language processing with modern transformers"}, - {"content": "The quick brown fox jumps over the lazy dog"} + {"content": "The quick brown fox jumps over the lazy dog"}, ] - + query = "rerank improve quality" print(f"Query: '{query}'") print(f"Documents: {len(documents)}") - + try: reranked_docs = await custom_rerank( query=query, @@ -153,41 +183,44 @@ async def test_direct_rerank(): model="BAAI/bge-reranker-v2-m3", base_url="https://api.your-rerank-provider.com/v1/rerank", api_key="your_rerank_api_key_here", - top_k=3 + top_k=3, ) - + print("\n✅ Rerank Results:") for i, doc in enumerate(reranked_docs): score = doc.get("rerank_score", "N/A") content = doc.get("content", "")[:60] print(f" {i+1}. Score: {score:.4f} | {content}...") - + except Exception as e: print(f"❌ Rerank failed: {e}") + async def main(): """Main example function""" print("🎯 LightRAG Rerank Integration Example") print("=" * 60) - + try: # Test rerank with different top_k values await test_rerank_with_different_topk() - + # Test direct rerank await test_direct_rerank() - + print("\n✅ Example completed successfully!") print("\n💡 Key Points:") - print(" ✓ QueryParam.top_k has higher priority than rerank_top_k") + print(" ✓ All rerank configurations are contained within rerank_model_func") print(" ✓ Rerank improves document relevance ordering") - print(" ✓ Configure API keys in your .env file for production") + print(" ✓ Configure API keys within your rerank function") print(" ✓ Monitor API usage and costs when using rerank services") - + except Exception as e: print(f"\n❌ Example failed: {e}") import traceback + traceback.print_exc() + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cee08373..63a2f531 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -249,25 +249,7 @@ class LightRAG: """Enable reranking for improved retrieval quality. Defaults to False.""" rerank_model_func: Callable[..., object] | None = field(default=None) - """Function for reranking retrieved documents. Optional.""" - - rerank_model_name: str = field( - default=os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") - ) - """Name of the rerank model used for reranking documents.""" - - rerank_model_max_async: int = field(default=int(os.getenv("RERANK_MAX_ASYNC", 4))) - """Maximum number of concurrent rerank calls.""" - - rerank_model_kwargs: dict[str, Any] = field(default_factory=dict) - """Additional keyword arguments passed to the rerank model function.""" - - rerank_top_k: int = field(default=int(os.getenv("RERANK_TOP_K", 10))) - """Number of top documents to return after reranking. - - Note: This value will be overridden by QueryParam.top_k in query calls. - Example: QueryParam(top_k=5) will override rerank_top_k=10 setting. - """ + """Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional.""" # Storage # --- @@ -475,14 +457,6 @@ class LightRAG: # Init Rerank if self.enable_rerank and self.rerank_model_func: - self.rerank_model_func = priority_limit_async_func_call( - self.rerank_model_max_async - )( - partial( - self.rerank_model_func, # type: ignore - **self.rerank_model_kwargs, - ) - ) logger.info("Rerank model initialized for improved retrieval quality") elif self.enable_rerank and not self.rerank_model_func: logger.warning( diff --git a/lightrag/operate.py b/lightrag/operate.py index b5d74c55..645c1e85 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2864,19 +2864,15 @@ async def apply_rerank_if_enabled( return retrieved_docs try: - # Determine top_k for reranking - rerank_top_k = top_k or global_config.get("rerank_top_k", 10) - rerank_top_k = min(rerank_top_k, len(retrieved_docs)) - logger.debug( - f"Applying rerank to {len(retrieved_docs)} documents, returning top {rerank_top_k}" + f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}" ) - # Apply reranking + # Apply reranking - let rerank_model_func handle top_k internally reranked_docs = await rerank_func( query=query, documents=retrieved_docs, - top_k=rerank_top_k, + top_k=top_k, ) if reranked_docs and len(reranked_docs) > 0: @@ -2886,7 +2882,7 @@ async def apply_rerank_if_enabled( return reranked_docs else: logger.warning("Rerank returned empty results, using original documents") - return retrieved_docs[:rerank_top_k] if rerank_top_k else retrieved_docs + return retrieved_docs except Exception as e: logger.error(f"Error during reranking: {e}, using original documents") diff --git a/lightrag/rerank.py b/lightrag/rerank.py index d25a8485..59719bc9 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -1,12 +1,9 @@ from __future__ import annotations import os -import json import aiohttp -import numpy as np from typing import Callable, Any, List, Dict, Optional from pydantic import BaseModel, Field -from dataclasses import asdict from .utils import logger @@ -15,14 +12,17 @@ class RerankModel(BaseModel): """ Pydantic model class for defining a custom rerank model. + This class provides a convenient wrapper for rerank functions, allowing you to + encapsulate all rerank configurations (API keys, model settings, etc.) in one place. + Attributes: rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents. The function should take query and documents as input and return reranked results. kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. - This could include parameters such as the model name, API key, etc. + This should include all necessary configurations such as model name, API key, base_url, etc. Example usage: - Rerank model example from jina: + Rerank model example with Jina: ```python rerank_model = RerankModel( rerank_func=jina_rerank, @@ -32,6 +32,32 @@ class RerankModel(BaseModel): "base_url": "https://api.jina.ai/v1/rerank" } ) + + # Use in LightRAG + rag = LightRAG( + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + # ... other configurations + ) + ``` + + Or define a custom function directly: + ```python + async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await jina_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_api_key_here", + top_k=top_k or 10, + **kwargs + ) + + rag = LightRAG( + enable_rerank=True, + rerank_model_func=my_rerank_func, + # ... other configurations + ) ``` """ @@ -43,25 +69,22 @@ class RerankModel(BaseModel): query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None, - **extra_kwargs + **extra_kwargs, ) -> List[Dict[str, Any]]: """Rerank documents using the configured model function.""" # Merge extra kwargs with model kwargs kwargs = {**self.kwargs, **extra_kwargs} return await self.rerank_func( - query=query, - documents=documents, - top_k=top_k, - **kwargs + query=query, documents=documents, top_k=top_k, **kwargs ) class MultiRerankModel(BaseModel): """Multiple rerank models for different modes/scenarios.""" - + # Primary rerank model (used if mode-specific models are not defined) rerank_model: Optional[RerankModel] = None - + # Mode-specific rerank models entity_rerank_model: Optional[RerankModel] = None relation_rerank_model: Optional[RerankModel] = None @@ -73,10 +96,10 @@ class MultiRerankModel(BaseModel): documents: List[Dict[str, Any]], mode: str = "default", top_k: Optional[int] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """Rerank using the appropriate model based on mode.""" - + # Select model based on mode if mode == "entity" and self.entity_rerank_model: model = self.entity_rerank_model @@ -89,7 +112,7 @@ class MultiRerankModel(BaseModel): else: logger.warning(f"No rerank model available for mode: {mode}") return documents - + return await model.rerank(query, documents, top_k, **kwargs) @@ -100,11 +123,11 @@ async def generic_rerank_api( base_url: str, api_key: str, top_k: Optional[int] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Generic rerank function that works with Jina/Cohere compatible APIs. - + Args: query: The search query documents: List of documents to rerank @@ -113,43 +136,35 @@ async def generic_rerank_api( api_key: API authentication key top_k: Number of top results to return **kwargs: Additional API-specific parameters - + Returns: List of reranked documents with relevance scores """ if not api_key: logger.warning("No API key provided for rerank service") return documents - + if not documents: return documents - + # Prepare documents for reranking - handle both text and dict formats prepared_docs = [] for doc in documents: if isinstance(doc, dict): # Use 'content' field if available, otherwise use 'text' or convert to string - text = doc.get('content') or doc.get('text') or str(doc) + text = doc.get("content") or doc.get("text") or str(doc) else: text = str(doc) prepared_docs.append(text) - + # Prepare request - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } - - data = { - "model": model, - "query": query, - "documents": prepared_docs, - **kwargs - } - + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + + data = {"model": model, "query": query, "documents": prepared_docs, **kwargs} + if top_k is not None: data["top_k"] = min(top_k, len(prepared_docs)) - + try: async with aiohttp.ClientSession() as session: async with session.post(base_url, headers=headers, json=data) as response: @@ -157,9 +172,9 @@ async def generic_rerank_api( error_text = await response.text() logger.error(f"Rerank API error {response.status}: {error_text}") return documents - + result = await response.json() - + # Extract reranked results if "results" in result: # Standard format: results contain index and relevance_score @@ -170,13 +185,15 @@ async def generic_rerank_api( if 0 <= doc_idx < len(documents): reranked_doc = documents[doc_idx].copy() if "relevance_score" in item: - reranked_doc["rerank_score"] = item["relevance_score"] + reranked_doc["rerank_score"] = item[ + "relevance_score" + ] reranked_docs.append(reranked_doc) return reranked_docs else: logger.warning("Unexpected rerank API response format") return documents - + except Exception as e: logger.error(f"Error during reranking: {e}") return documents @@ -189,11 +206,11 @@ async def jina_rerank( top_k: Optional[int] = None, base_url: str = "https://api.jina.ai/v1/rerank", api_key: Optional[str] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Rerank documents using Jina AI API. - + Args: query: The search query documents: List of documents to rerank @@ -202,13 +219,13 @@ async def jina_rerank( base_url: Jina API endpoint api_key: Jina API key **kwargs: Additional parameters - + Returns: List of reranked documents with relevance scores """ if api_key is None: api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY") - + return await generic_rerank_api( query=query, documents=documents, @@ -216,7 +233,7 @@ async def jina_rerank( base_url=base_url, api_key=api_key, top_k=top_k, - **kwargs + **kwargs, ) @@ -227,11 +244,11 @@ async def cohere_rerank( top_k: Optional[int] = None, base_url: str = "https://api.cohere.ai/v1/rerank", api_key: Optional[str] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Rerank documents using Cohere API. - + Args: query: The search query documents: List of documents to rerank @@ -240,13 +257,13 @@ async def cohere_rerank( base_url: Cohere API endpoint api_key: Cohere API key **kwargs: Additional parameters - + Returns: List of reranked documents with relevance scores """ if api_key is None: api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") - + return await generic_rerank_api( query=query, documents=documents, @@ -254,7 +271,7 @@ async def cohere_rerank( base_url=base_url, api_key=api_key, top_k=top_k, - **kwargs + **kwargs, ) @@ -266,7 +283,7 @@ async def custom_rerank( base_url: str, api_key: str, top_k: Optional[int] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Rerank documents using a custom API endpoint. @@ -279,7 +296,7 @@ async def custom_rerank( base_url=base_url, api_key=api_key, top_k=top_k, - **kwargs + **kwargs, ) @@ -293,15 +310,12 @@ if __name__ == "__main__": {"content": "Tokyo is the capital of Japan."}, {"content": "London is the capital of England."}, ] - + query = "What is the capital of France?" - + result = await jina_rerank( - query=query, - documents=docs, - top_k=2, - api_key="your-api-key-here" + query=query, documents=docs, top_k=2, api_key="your-api-key-here" ) print(result) - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())