Add embeddings & reranking via Sentence Transformers

This commit is contained in:
Tom Aarsen 2025-11-18 12:18:56 +01:00
parent dfbc97363c
commit 835edda6fc
6 changed files with 177 additions and 10 deletions

View file

@ -453,7 +453,7 @@ async def initialize_rag():
* 如果您想使用Hugging Face模型只需要按如下方式设置LightRAG * 如果您想使用Hugging Face模型只需要按如下方式设置LightRAG
参见`lightrag_hf_demo.py` 参见`lightrag_hf_demo.py`, `lightrag_sentence_transformers_demo.py`等示例代码。
```python ```python
# 使用Hugging Face模型初始化LightRAG # 使用Hugging Face模型初始化LightRAG
@ -464,10 +464,9 @@ rag = LightRAG(
# 使用Hugging Face嵌入函数 # 使用Hugging Face嵌入函数
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=384, embedding_dim=384,
func=lambda texts: hf_embed( func=lambda texts: sentence_transformers_embed(
texts, texts,
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"), model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
) )
), ),
) )
@ -635,6 +634,7 @@ if __name__ == "__main__":
* **Cohere / vLLM**: `cohere_rerank` * **Cohere / vLLM**: `cohere_rerank`
* **Jina AI**: `jina_rerank` * **Jina AI**: `jina_rerank`
* **Aliyun阿里云**: `ali_rerank` * **Aliyun阿里云**: `ali_rerank`
* **Sentence Transformers**: `sentence_transformers_rerank`
您可以将这些函数之一注入到LightRAG对象的`rerank_model_func`属性中。这将使LightRAG的查询功能能够使用注入的函数对检索到的文本块进行重新排序。有关详细用法请参阅`examples/rerank_example.py`文件。 您可以将这些函数之一注入到LightRAG对象的`rerank_model_func`属性中。这将使LightRAG的查询功能能够使用注入的函数对检索到的文本块进行重新排序。有关详细用法请参阅`examples/rerank_example.py`文件。

View file

@ -449,7 +449,7 @@ async def initialize_rag():
* If you want to use Hugging Face models, you only need to set LightRAG as follows: * If you want to use Hugging Face models, you only need to set LightRAG as follows:
See `lightrag_hf_demo.py` See `lightrag_hf_demo.py` & `lightrag_sentence_transformers_demo.py` for complete examples.
```python ```python
# Initialize LightRAG with Hugging Face model # Initialize LightRAG with Hugging Face model
@ -457,13 +457,12 @@ rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=hf_model_complete, # Use Hugging Face model for text generation llm_model_func=hf_model_complete, # Use Hugging Face model for text generation
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face
# Use Hugging Face embedding function # Use Hugging Face Sentence Transformers embedding function
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=384, embedding_dim=384,
func=lambda texts: hf_embed( func=lambda texts: sentence_transformers_embed(
texts, texts,
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"), model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
) )
), ),
) )
@ -633,6 +632,7 @@ To enhance retrieval quality, documents can be re-ranked based on a more effecti
* **Cohere / vLLM**: `cohere_rerank` * **Cohere / vLLM**: `cohere_rerank`
* **Jina AI**: `jina_rerank` * **Jina AI**: `jina_rerank`
* **Aliyun**: `ali_rerank` * **Aliyun**: `ali_rerank`
* **Sentence Transformers**: `sentence_transformers_rerank`
You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file. You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file.

View file

@ -0,0 +1,75 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.hf import hf_model_complete
from lightrag.llm.sentence_transformers import sentence_transformers_embed
from lightrag.utils import EmbeddingFunc
from sentence_transformers import SentenceTransformer
import asyncio
import nest_asyncio
# Patch asyncio so that an event loop can be re-entered while one is running.
# NOTE(review): presumably required because the synchronous LightRAG calls
# used later in this script drive their own event loop — confirm.
nest_asyncio.apply()
# Directory handed to LightRAG as its working_dir for this demo.
WORKING_DIR = "./dickens"

# Create the directory idempotently: exist_ok avoids the check-then-create
# race of a separate os.path.exists() test, and makedirs also creates any
# missing parents should WORKING_DIR be changed to a nested path.
os.makedirs(WORKING_DIR, exist_ok=True)
async def initialize_rag():
    """Build the demo LightRAG instance and initialize its storages.

    Text generation goes through ``hf_model_complete`` with the
    ``meta-llama/Llama-3.1-8B-Instruct`` model name; embeddings are produced
    by ``sentence_transformers_embed`` using the
    ``sentence-transformers/all-MiniLM-L6-v2`` model.

    Returns:
        The LightRAG instance, after ``initialize_storages()`` has completed.
    """
    # Load the embedding model exactly once. Constructing SentenceTransformer
    # inside the lambda would re-load the model on every embedding call.
    embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=hf_model_complete,
        llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
        embedding_func=EmbeddingFunc(
            embedding_dim=384,
            max_token_size=512,
            func=lambda texts: sentence_transformers_embed(
                texts,
                model=embedding_model,
            ),
        ),
    )

    await rag.initialize_storages()  # Auto-initializes pipeline_status
    return rag
def main():
    """Index ``./book.txt`` and print the answer to one question per search mode."""
    rag = asyncio.run(initialize_rag())

    with open("./book.txt", "r", encoding="utf-8") as book:
        rag.insert(book.read())

    question = "What are the top themes in this story?"
    # Same question through naive, local, global and hybrid search, in that order.
    for mode in ("naive", "local", "global", "hybrid"):
        print(rag.query(question, param=QueryParam(mode=mode)))


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()

View file

@ -781,6 +781,17 @@ def create_app(args):
base_url=host, base_url=host,
api_key=api_key, api_key=api_key,
) )
elif binding == "sentence_transformers":
from lightrag.llm.sentence_transformers import (
sentence_transformers_embed,
)
actual_func = (
sentence_transformers_embed.func
if isinstance(sentence_transformers_embed, EmbeddingFunc)
else sentence_transformers_embed
)
return await actual_func(texts, embedding_dim=embedding_dim)
elif binding == "gemini": elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed from lightrag.llm.gemini import gemini_embed
@ -932,13 +943,19 @@ def create_app(args):
# Configure rerank function based on args.rerank_binding parameter # Configure rerank function based on args.rerank_binding parameter
rerank_model_func = None rerank_model_func = None
if args.rerank_binding != "null": if args.rerank_binding != "null":
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank from lightrag.rerank import (
cohere_rerank,
jina_rerank,
ali_rerank,
sentence_transformers_rerank,
)
# Map rerank binding to corresponding function # Map rerank binding to corresponding function
rerank_functions = { rerank_functions = {
"cohere": cohere_rerank, "cohere": cohere_rerank,
"jina": jina_rerank, "jina": jina_rerank,
"aliyun": ali_rerank, "aliyun": ali_rerank,
"sentence_transformers": sentence_transformers_rerank,
} }
# Select the appropriate rerank function based on binding # Select the appropriate rerank function based on binding

View file

@ -0,0 +1,26 @@
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("sentence_transformers"):
pm.install("sentence_transformers")
if not pm.is_installed("numpy"):
pm.install("numpy")
import numpy as np
from lightrag.utils import EmbeddingFunc
from sentence_transformers import SentenceTransformer
async def sentence_transformers_embed(
    texts: list[str], model: SentenceTransformer
) -> np.ndarray:
    """Embed ``texts`` with a Sentence Transformers model.

    The encode call is wrapped in an ``EmbeddingFunc`` whose embedding
    dimension and maximum token size are read from ``model`` itself, and the
    texts are then embedded through that wrapper.

    Args:
        texts: The strings to embed.
        model: A loaded ``SentenceTransformer`` used for encoding.

    Returns:
        The result of the ``EmbeddingFunc`` wrapper around ``model.encode``,
        which is called with ``convert_to_numpy=True``.
    """
    from typing import Optional

    async def inner_encode(
        texts: list[str],
        model: SentenceTransformer,
        embedding_dim: Optional[int] = None,
    ):
        # truncate_dim=None leaves the model's own output size untouched.
        # A fixed default such as 1024 would silently cut the embeddings of
        # wider models below the dimension declared to EmbeddingFunc below.
        return model.encode(
            texts,
            truncate_dim=embedding_dim,
            convert_to_numpy=True,
            convert_to_tensor=False,
            show_progress_bar=False,
        )

    embedding_func = EmbeddingFunc(
        embedding_dim=model.get_sentence_embedding_dimension(),
        func=inner_encode,
        max_token_size=model.get_max_seq_length(),
    )
    return await embedding_func(texts, model=model)

View file

@ -290,6 +290,40 @@ async def ali_rerank(
) )
# Loaded CrossEncoder instances keyed by model name, so that each model is
# loaded once per process instead of on every rerank call.
_SENTENCE_TRANSFORMERS_CROSS_ENCODERS: Dict[str, Any] = {}


async def sentence_transformers_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "BAAI/bge-reranker-v2-m3",
    base_url: Optional[str] = None,
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using CrossEncoder from Sentence Transformers.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return
        api_key: Unused
        model: Rerank model name
        base_url: Unused
        extra_body: Unused

    Returns:
        A list of dictionaries of the form
        {"index": int, "relevance_score": float}, in the order produced by
        ``CrossEncoder.rank``. An empty list is returned when ``documents``
        is empty.
    """
    # Nothing to rank: return early without importing or loading any model.
    if not documents:
        return []

    cross_encoder = _SENTENCE_TRANSFORMERS_CROSS_ENCODERS.get(model)
    if cross_encoder is None:
        # Imported lazily so the dependency is only needed when this
        # reranker is actually used.
        from sentence_transformers import CrossEncoder

        cross_encoder = CrossEncoder(model)
        _SENTENCE_TRANSFORMERS_CROSS_ENCODERS[model] = cross_encoder

    rankings = cross_encoder.rank(query=query, documents=documents, top_k=top_n)

    # Convert to built-in int/float so the result matches the documented
    # types and stays serializable even if the scores are NumPy scalars.
    return [
        {
            "index": int(result["corpus_id"]),
            "relevance_score": float(result["score"]),
        }
        for result in rankings
    ]
"""Please run this test as a module: """Please run this test as a module:
python -m lightrag.rerank python -m lightrag.rerank
""" """
@ -350,5 +384,20 @@ if __name__ == "__main__":
print(f"Document: {docs[item['index']]}") print(f"Document: {docs[item['index']]}")
except Exception as e: except Exception as e:
print(f"Aliyun Error: {e}") print(f"Aliyun Error: {e}")
# Test Sentence Transformers rerank
try:
print("\n=== Sentence Transformers Rerank ===")
result = await sentence_transformers_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Sentence Transformers Error: {e}")
asyncio.run(main()) asyncio.run(main())