from __future__ import annotations import os import json import aiohttp import numpy as np from typing import Callable, Any, List, Dict, Optional from pydantic import BaseModel, Field from dataclasses import asdict from .utils import logger class RerankModel(BaseModel): """ Pydantic model class for defining a custom rerank model. Attributes: rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents. The function should take query and documents as input and return reranked results. kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. This could include parameters such as the model name, API key, etc. Example usage: Rerank model example from jina: ```python rerank_model = RerankModel( rerank_func=jina_rerank, kwargs={ "model": "BAAI/bge-reranker-v2-m3", "api_key": "your_api_key_here", "base_url": "https://api.jina.ai/v1/rerank" } ) ``` """ rerank_func: Callable[[Any], List[Dict]] kwargs: Dict[str, Any] = Field(default_factory=dict) async def rerank( self, query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None, **extra_kwargs ) -> List[Dict[str, Any]]: """Rerank documents using the configured model function.""" # Merge extra kwargs with model kwargs kwargs = {**self.kwargs, **extra_kwargs} return await self.rerank_func( query=query, documents=documents, top_k=top_k, **kwargs ) class MultiRerankModel(BaseModel): """Multiple rerank models for different modes/scenarios.""" # Primary rerank model (used if mode-specific models are not defined) rerank_model: Optional[RerankModel] = None # Mode-specific rerank models entity_rerank_model: Optional[RerankModel] = None relation_rerank_model: Optional[RerankModel] = None chunk_rerank_model: Optional[RerankModel] = None async def rerank( self, query: str, documents: List[Dict[str, Any]], mode: str = "default", top_k: Optional[int] = None, **kwargs ) -> List[Dict[str, Any]]: """Rerank using the appropriate model based on mode.""" # Select model based on mode if mode == "entity" and self.entity_rerank_model: model = self.entity_rerank_model elif mode == "relation" and self.relation_rerank_model: model = self.relation_rerank_model elif mode == "chunk" and self.chunk_rerank_model: model = self.chunk_rerank_model elif self.rerank_model: model = self.rerank_model else: logger.warning(f"No rerank model available for mode: {mode}") return documents return await model.rerank(query, documents, top_k, **kwargs) async def generic_rerank_api( query: str, documents: List[Dict[str, Any]], model: str, base_url: str, api_key: str, top_k: Optional[int] = None, **kwargs ) -> List[Dict[str, Any]]: """ Generic rerank function that works with Jina/Cohere compatible APIs. Args: query: The search query documents: List of documents to rerank model: Model identifier base_url: API endpoint URL api_key: API authentication key top_k: Number of top results to return **kwargs: Additional API-specific parameters Returns: List of reranked documents with relevance scores """ if not api_key: logger.warning("No API key provided for rerank service") return documents if not documents: return documents # Prepare documents for reranking - handle both text and dict formats prepared_docs = [] for doc in documents: if isinstance(doc, dict): # Use 'content' field if available, otherwise use 'text' or convert to string text = doc.get('content') or doc.get('text') or str(doc) else: text = str(doc) prepared_docs.append(text) # Prepare request headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}" } data = { "model": model, "query": query, "documents": prepared_docs, **kwargs } if top_k is not None: data["top_k"] = min(top_k, len(prepared_docs)) try: async with aiohttp.ClientSession() as session: async with session.post(base_url, headers=headers, json=data) as response: if response.status != 200: error_text = await response.text() logger.error(f"Rerank API error {response.status}: {error_text}") return documents result = await response.json() # Extract reranked results if "results" in result: # Standard format: results contain index and relevance_score reranked_docs = [] for item in result["results"]: if "index" in item: doc_idx = item["index"] if 0 <= doc_idx < len(documents): reranked_doc = documents[doc_idx].copy() if "relevance_score" in item: reranked_doc["rerank_score"] = item["relevance_score"] reranked_docs.append(reranked_doc) return reranked_docs else: logger.warning("Unexpected rerank API response format") return documents except Exception as e: logger.error(f"Error during reranking: {e}") return documents async def jina_rerank( query: str, documents: List[Dict[str, Any]], model: str = "BAAI/bge-reranker-v2-m3", top_k: Optional[int] = None, base_url: str = "https://api.jina.ai/v1/rerank", api_key: Optional[str] = None, **kwargs ) -> List[Dict[str, Any]]: """ Rerank documents using Jina AI API. Args: query: The search query documents: List of documents to rerank model: Jina rerank model name top_k: Number of top results to return base_url: Jina API endpoint api_key: Jina API key **kwargs: Additional parameters Returns: List of reranked documents with relevance scores """ if api_key is None: api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY") return await generic_rerank_api( query=query, documents=documents, model=model, base_url=base_url, api_key=api_key, top_k=top_k, **kwargs ) async def cohere_rerank( query: str, documents: List[Dict[str, Any]], model: str = "rerank-english-v2.0", top_k: Optional[int] = None, base_url: str = "https://api.cohere.ai/v1/rerank", api_key: Optional[str] = None, **kwargs ) -> List[Dict[str, Any]]: """ Rerank documents using Cohere API. Args: query: The search query documents: List of documents to rerank model: Cohere rerank model name top_k: Number of top results to return base_url: Cohere API endpoint api_key: Cohere API key **kwargs: Additional parameters Returns: List of reranked documents with relevance scores """ if api_key is None: api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") return await generic_rerank_api( query=query, documents=documents, model=model, base_url=base_url, api_key=api_key, top_k=top_k, **kwargs ) # Convenience function for custom API endpoints async def custom_rerank( query: str, documents: List[Dict[str, Any]], model: str, base_url: str, api_key: str, top_k: Optional[int] = None, **kwargs ) -> List[Dict[str, Any]]: """ Rerank documents using a custom API endpoint. This is useful for self-hosted or custom rerank services. """ return await generic_rerank_api( query=query, documents=documents, model=model, base_url=base_url, api_key=api_key, top_k=top_k, **kwargs ) if __name__ == "__main__": import asyncio async def main(): # Example usage docs = [ {"content": "The capital of France is Paris."}, {"content": "Tokyo is the capital of Japan."}, {"content": "London is the capital of England."}, ] query = "What is the capital of France?" result = await jina_rerank( query=query, documents=docs, top_k=2, api_key="your-api-key-here" ) print(result) asyncio.run(main())