LightRAG/lightrag/rerank.py

"""
Local reranker using sentence-transformers CrossEncoder.
Uses mixedbread-ai/mxbai-rerank-xsmall-v1 by default - a small but effective
cross-encoder model that runs entirely locally without API calls.
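
Example (the first call downloads the model; requires sentence-transformers):
    >>> import asyncio
    >>> results = asyncio.run(local_rerank(
    ...     query='What is machine learning?',
    ...     documents=['ML is a subset of AI.', 'The weather is nice.'],
    ... ))
    >>> # results are sorted by relevance_score, highest first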
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable, Sequence
import os
from typing import Protocol, SupportsFloat, TypedDict, runtime_checkable
from .utils import logger
# Global model cache to avoid reloading on every call
_reranker_model: RerankerModel | None = None
_reranker_model_name: str | None = None
# Default model - best quality/size tradeoff from benchmarks
DEFAULT_RERANK_MODEL = 'mixedbread-ai/mxbai-rerank-xsmall-v1'
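# Override at runtime with the RERANK_MODEL environment variable, which is
# read in get_reranker_model() when no explicit model_name is passed.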
class RerankResult(TypedDict):
index: int
relevance_score: float
@runtime_checkable
class SupportsToList(Protocol):
def tolist(self) -> list[float]: ...
ScoreLike = Sequence[SupportsFloat] | SupportsToList
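# predict() may return a plain sequence of floats or an array-like object
# (e.g. a numpy array or torch tensor) exposing tolist(); ScoreLike covers both.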
@runtime_checkable
class RerankerModel(Protocol):
def predict(
self,
sentences: list[list[str]],
batch_size: int = ...,
) -> ScoreLike: ...
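# Any object with a compatible predict() satisfies RerankerModel, so a stub can
# stand in for CrossEncoder in tests. A hypothetical sketch (not part of this API):
#
#     class ConstantReranker:
#         def predict(self, sentences: list[list[str]], batch_size: int = 32) -> list[float]:
#             return [0.0] * len(sentences)
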
def get_reranker_model(model_name: str | None = None) -> RerankerModel:
"""
Get or initialize the reranker model (cached).
Args:
model_name: HuggingFace model name. Defaults to mxbai-rerank-xsmall-v1
Returns:
CrossEncoder-like model instance implementing predict(pairs)->list[float]
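
    Example:
        >>> model = get_reranker_model()          # first call loads and caches the model
        >>> assert model is get_reranker_model()  # repeat calls return the cached instance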
"""
global _reranker_model, _reranker_model_name
if model_name is None:
model_name = os.getenv('RERANK_MODEL', DEFAULT_RERANK_MODEL)
# Return cached model if same name
if _reranker_model is not None and _reranker_model_name == model_name:
return _reranker_model
try:
from sentence_transformers import CrossEncoder
logger.info(f'Loading reranker model: {model_name}')
_reranker_model = CrossEncoder(model_name, trust_remote_code=True)
_reranker_model_name = model_name
logger.info(f'Reranker model loaded: {model_name}')
return _reranker_model
except ImportError as err:
raise ImportError(
'sentence-transformers is required for local reranking. Install with: pip install sentence-transformers'
) from err
except Exception as e:
logger.error(f'Failed to load reranker model {model_name}: {e}')
raise
async def local_rerank(
query: str,
documents: list[str],
top_n: int | None = None,
model_name: str | None = None,
) -> list[RerankResult]:
"""
Rerank documents using local CrossEncoder model.
Args:
query: The search query
documents: List of document strings to rerank
top_n: Number of top results to return (None = all)
model_name: HuggingFace model name (default: mxbai-rerank-xsmall-v1)
Returns:
List of dicts with 'index' and 'relevance_score', sorted by score descending
Example:
>>> results = await local_rerank(
... query="What is machine learning?",
... documents=["ML is a subset of AI...", "The weather is nice..."],
... top_n=5
... )
>>> print(results[0])
{'index': 0, 'relevance_score': 0.95}
"""
if not documents:
return []
model = get_reranker_model(model_name)
# Create query-document pairs
pairs = [[query, doc] for doc in documents]
    # Score each (query, document) pair with the cross-encoder;
    # get_reranker_model() either returns a loaded model or raises, so no None guard is needed
    raw_scores = model.predict(pairs)
# Normalize to a list[float] regardless of backend (list, numpy array, tensor)
if isinstance(raw_scores, SupportsToList):
raw_scores = raw_scores.tolist()
scores = [float(score) for score in raw_scores]
    # Build results pairing each original index with its score
    results: list[RerankResult] = [
        RerankResult(index=i, relevance_score=score) for i, score in enumerate(scores)
    ]
# Sort by score descending
results.sort(key=lambda x: x['relevance_score'], reverse=True)
# Apply top_n limit if specified
if top_n is not None and top_n < len(results):
results = results[:top_n]
return results
def create_local_rerank_func(
model_name: str | None = None,
) -> Callable[..., Awaitable[list[RerankResult]]]:
"""
Create a rerank function with pre-configured model.
This is used by lightrag_server to create a rerank function
that can be passed to LightRAG initialization.
Args:
model_name: HuggingFace model name (default: mxbai-rerank-xsmall-v1)
Returns:
Async rerank function
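
    Example:
        >>> rerank_func = create_local_rerank_func()  # fails fast if the model can't load
        >>> # later: await rerank_func(query='...', documents=[...], top_n=5)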
"""
# Pre-load model to fail fast if there's an issue
get_reranker_model(model_name)
async def rerank_func(
query: str,
documents: list[str],
top_n: int | None = None,
        **kwargs,  # accepted for interface compatibility with other rerank providers; ignored
) -> list[RerankResult]:
return await local_rerank(
query=query,
documents=documents,
top_n=top_n,
model_name=model_name,
)
return rerank_func
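
# Sketch of wiring this into LightRAG (the keyword is assumed here to be
# rerank_model_func; the exact parameter name depends on the LightRAG version):
#
#     rag = LightRAG(
#         working_dir='./rag_storage',
#         rerank_model_func=create_local_rerank_func(),
#     )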
# For backwards compatibility - alias to local_rerank
rerank = local_rerank
if __name__ == '__main__':
import asyncio
async def main():
docs = [
'The capital of France is Paris.',
'Tokyo is the capital of Japan.',
'London is the capital of England.',
'Python is a programming language.',
]
query = 'What is the capital of France?'
print('=== Local Reranker Test ===')
print(f'Model: {os.getenv("RERANK_MODEL", DEFAULT_RERANK_MODEL)}')
print(f'Query: {query}')
print()
results = await local_rerank(query=query, documents=docs, top_n=3)
print('Results (top 3):')
for item in results:
idx = item['index']
score = item['relevance_score']
print(f' [{idx}] Score: {score:.4f} - {docs[idx]}')
asyncio.run(main())