Refactored rerank_example file to utilize the updated rerank function.
This commit is contained in:
parent
9bc349ddd6
commit
3d5e6226a9
3 changed files with 53 additions and 85 deletions
|
|
@ -96,14 +96,14 @@ RERANK_BINDING=null
|
||||||
### rerank score chunk filter (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
|
### rerank score chunk filter (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
|
||||||
# MIN_RERANK_SCORE=0.0
|
# MIN_RERANK_SCORE=0.0
|
||||||
|
|
||||||
### For local deployment
|
### For local deployment with vLLM
|
||||||
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
||||||
# RERANK_BINDING_HOST=http://localhost:8000
|
# RERANK_BINDING_HOST=http://localhost:8000/v1/rerank
|
||||||
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||||
|
|
||||||
### Default value for Cohere AI
|
### Default value for Cohere AI
|
||||||
# RERANK_MODEL=rerank-v3.5
|
# RERANK_MODEL=rerank-v3.5
|
||||||
# RERANK_BINDING_HOST=https://ai.znipower.com:5017/rerank
|
# RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank
|
||||||
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||||
|
|
||||||
### Default value for Jina AI
|
### Default value for Jina AI
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,21 @@ This example demonstrates how to use rerank functionality with LightRAG
|
||||||
to improve retrieval quality across different query modes.
|
to improve retrieval quality across different query modes.
|
||||||
|
|
||||||
Configuration Required:
|
Configuration Required:
|
||||||
1. Set your LLM API key and base URL in llm_model_func()
|
1. Set your OpenAI LLM API key and base URL with env vars:
|
||||||
2. Set your embedding API key and base URL in embedding_func()
|
LLM_MODEL
|
||||||
3. Set your rerank API key and base URL in the rerank configuration
|
LLM_BINDING_HOST
|
||||||
4. Or use environment variables (.env file):
|
LLM_BINDING_API_KEY
|
||||||
- RERANK_MODEL=your_rerank_model
|
2. Set your OpenAI embedding API key and base URL with env vars:
|
||||||
- RERANK_BINDING_HOST=your_rerank_endpoint
|
EMBEDDING_MODEL
|
||||||
- RERANK_BINDING_API_KEY=your_rerank_api_key
|
EMBEDDING_DIM
|
||||||
|
EMBEDDING_BINDING_HOST
|
||||||
|
EMBEDDING_BINDING_API_KEY
|
||||||
|
3. Set your vLLM deployed AI rerank model setting with env vars:
|
||||||
|
RERANK_MODEL
|
||||||
|
RERANK_BINDING_HOST
|
||||||
|
RERANK_BINDING_API_KEY
|
||||||
|
|
||||||
Note: Rerank is now controlled per query via the 'enable_rerank' parameter (default: True)
|
Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
@ -21,11 +27,13 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.rerank import custom_rerank, RerankModel
|
|
||||||
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
||||||
from lightrag.utils import EmbeddingFunc, setup_logger
|
from lightrag.utils import EmbeddingFunc, setup_logger
|
||||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from lightrag.rerank import cohere_rerank
|
||||||
|
|
||||||
# Set up your working directory
|
# Set up your working directory
|
||||||
WORKING_DIR = "./test_rerank"
|
WORKING_DIR = "./test_rerank"
|
||||||
setup_logger("test_rerank")
|
setup_logger("test_rerank")
|
||||||
|
|
@ -38,12 +46,12 @@ async def llm_model_func(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
return await openai_complete_if_cache(
|
return await openai_complete_if_cache(
|
||||||
"gpt-4o-mini",
|
os.getenv("LLM_MODEL"),
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
api_key="your_llm_api_key_here",
|
api_key=os.getenv("LLM_BINDING_API_KEY"),
|
||||||
base_url="https://api.your-llm-provider.com/v1",
|
base_url=os.getenv("LLM_BINDING_HOST"),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -51,23 +59,18 @@ async def llm_model_func(
|
||||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
return await openai_embed(
|
return await openai_embed(
|
||||||
texts,
|
texts,
|
||||||
model="text-embedding-3-large",
|
model=os.getenv("EMBEDDING_MODEL"),
|
||||||
api_key="your_embedding_api_key_here",
|
api_key=os.getenv("EMBEDDING_BINDING_API_KEY"),
|
||||||
base_url="https://api.your-embedding-provider.com/v1",
|
base_url=os.getenv("EMBEDDING_BINDING_HOST"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
rerank_model_func = partial(
|
||||||
"""Custom rerank function with all settings included"""
|
cohere_rerank,
|
||||||
return await custom_rerank(
|
model=os.getenv("RERANK_MODEL"),
|
||||||
query=query,
|
api_key=os.getenv("RERANK_BINDING_API_KEY"),
|
||||||
documents=documents,
|
base_url=os.getenv("RERANK_BINDING_HOST"),
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
)
|
||||||
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
|
||||||
api_key="your_rerank_api_key_here",
|
|
||||||
top_n=top_n or 10,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_rag_with_rerank():
|
async def create_rag_with_rerank():
|
||||||
|
|
@ -88,42 +91,7 @@ async def create_rag_with_rerank():
|
||||||
func=embedding_func,
|
func=embedding_func,
|
||||||
),
|
),
|
||||||
# Rerank Configuration - provide the rerank function
|
# Rerank Configuration - provide the rerank function
|
||||||
rerank_model_func=my_rerank_func,
|
rerank_model_func=rerank_model_func,
|
||||||
)
|
|
||||||
|
|
||||||
await rag.initialize_storages()
|
|
||||||
await initialize_pipeline_status()
|
|
||||||
|
|
||||||
return rag
|
|
||||||
|
|
||||||
|
|
||||||
async def create_rag_with_rerank_model():
|
|
||||||
"""Alternative: Create LightRAG instance using RerankModel wrapper"""
|
|
||||||
|
|
||||||
# Get embedding dimension
|
|
||||||
test_embedding = await embedding_func(["test"])
|
|
||||||
embedding_dim = test_embedding.shape[1]
|
|
||||||
print(f"Detected embedding dimension: {embedding_dim}")
|
|
||||||
|
|
||||||
# Method 2: Using RerankModel wrapper
|
|
||||||
rerank_model = RerankModel(
|
|
||||||
rerank_func=custom_rerank,
|
|
||||||
kwargs={
|
|
||||||
"model": "BAAI/bge-reranker-v2-m3",
|
|
||||||
"base_url": "https://api.your-rerank-provider.com/v1/rerank",
|
|
||||||
"api_key": "your_rerank_api_key_here",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
rag = LightRAG(
|
|
||||||
working_dir=WORKING_DIR,
|
|
||||||
llm_model_func=llm_model_func,
|
|
||||||
embedding_func=EmbeddingFunc(
|
|
||||||
embedding_dim=embedding_dim,
|
|
||||||
max_token_size=8192,
|
|
||||||
func=embedding_func,
|
|
||||||
),
|
|
||||||
rerank_model_func=rerank_model.rerank,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await rag.initialize_storages()
|
await rag.initialize_storages()
|
||||||
|
|
@ -136,7 +104,7 @@ async def test_rerank_with_different_settings():
|
||||||
"""
|
"""
|
||||||
Test rerank functionality with different enable_rerank settings
|
Test rerank functionality with different enable_rerank settings
|
||||||
"""
|
"""
|
||||||
print("🚀 Setting up LightRAG with Rerank functionality...")
|
print("\n\n🚀 Setting up LightRAG with Rerank functionality...")
|
||||||
|
|
||||||
rag = await create_rag_with_rerank()
|
rag = await create_rag_with_rerank()
|
||||||
|
|
||||||
|
|
@ -199,11 +167,11 @@ async def test_direct_rerank():
|
||||||
print("=" * 40)
|
print("=" * 40)
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
{"content": "Reranking significantly improves retrieval quality"},
|
"Vector search finds semantically similar documents",
|
||||||
{"content": "LightRAG supports advanced reranking capabilities"},
|
"LightRAG supports advanced reranking capabilities",
|
||||||
{"content": "Vector search finds semantically similar documents"},
|
"Reranking significantly improves retrieval quality",
|
||||||
{"content": "Natural language processing with modern transformers"},
|
"Natural language processing with modern transformers",
|
||||||
{"content": "The quick brown fox jumps over the lazy dog"},
|
"The quick brown fox jumps over the lazy dog",
|
||||||
]
|
]
|
||||||
|
|
||||||
query = "rerank improve quality"
|
query = "rerank improve quality"
|
||||||
|
|
@ -211,20 +179,20 @@ async def test_direct_rerank():
|
||||||
print(f"Documents: {len(documents)}")
|
print(f"Documents: {len(documents)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reranked_docs = await custom_rerank(
|
reranked_results = await rerank_model_func(
|
||||||
query=query,
|
query=query,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
top_n=4,
|
||||||
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
|
||||||
api_key="your_rerank_api_key_here",
|
|
||||||
top_n=3,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n✅ Rerank Results:")
|
print("\n✅ Rerank Results:")
|
||||||
for i, doc in enumerate(reranked_docs):
|
i = 0
|
||||||
score = doc.get("rerank_score", "N/A")
|
for result in reranked_results:
|
||||||
content = doc.get("content", "")[:60]
|
index = result["index"]
|
||||||
print(f" {i+1}. Score: {score:.4f} | {content}...")
|
score = result["relevance_score"]
|
||||||
|
content = documents[index]
|
||||||
|
print(f" {index}. Score: {score:.4f} | {content}...")
|
||||||
|
i += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Rerank failed: {e}")
|
print(f"❌ Rerank failed: {e}")
|
||||||
|
|
@ -236,12 +204,12 @@ async def main():
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test rerank with different enable_rerank settings
|
|
||||||
await test_rerank_with_different_settings()
|
|
||||||
|
|
||||||
# Test direct rerank
|
# Test direct rerank
|
||||||
await test_direct_rerank()
|
await test_direct_rerank()
|
||||||
|
|
||||||
|
# Test rerank with different enable_rerank settings
|
||||||
|
await test_rerank_with_different_settings()
|
||||||
|
|
||||||
print("\n✅ Example completed successfully!")
|
print("\n✅ Example completed successfully!")
|
||||||
print("\n💡 Key Points:")
|
print("\n💡 Key Points:")
|
||||||
print(" ✓ Rerank is now controlled per query via 'enable_rerank' parameter")
|
print(" ✓ Rerank is now controlled per query via 'enable_rerank' parameter")
|
||||||
|
|
|
||||||
|
|
@ -469,8 +469,8 @@ class OllamaAPI:
|
||||||
"/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
|
"/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
|
||||||
)
|
)
|
||||||
async def chat(raw_request: Request):
|
async def chat(raw_request: Request):
|
||||||
"""Process chat completion requests acting as an Ollama model
|
"""Process chat completion requests by acting as an Ollama model.
|
||||||
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
Routes user queries through LightRAG by selecting query mode based on query prefix.
|
||||||
Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to LLM.
|
Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to LLM.
|
||||||
Supports both application/json and application/octet-stream Content-Types.
|
Supports both application/json and application/octet-stream Content-Types.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue