Add embeddings & reranking via Sentence Transformers
This commit is contained in:
parent
dfbc97363c
commit
835edda6fc
6 changed files with 177 additions and 10 deletions
|
|
@ -453,7 +453,7 @@ async def initialize_rag():
|
||||||
|
|
||||||
* 如果您想使用Hugging Face模型,只需要按如下方式设置LightRAG:
|
* 如果您想使用Hugging Face模型,只需要按如下方式设置LightRAG:
|
||||||
|
|
||||||
参见`lightrag_hf_demo.py`
|
参见`lightrag_hf_demo.py`, `lightrag_sentence_transformers_demo.py`等示例代码。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 使用Hugging Face模型初始化LightRAG
|
# 使用Hugging Face模型初始化LightRAG
|
||||||
|
|
@ -464,10 +464,9 @@ rag = LightRAG(
|
||||||
# 使用Hugging Face嵌入函数
|
# 使用Hugging Face Sentence Transformers嵌入函数
|
||||||
embedding_func=EmbeddingFunc(
|
embedding_func=EmbeddingFunc(
|
||||||
embedding_dim=384,
|
embedding_dim=384,
|
||||||
func=lambda texts: hf_embed(
|
func=lambda texts: sentence_transformers_embed(
|
||||||
texts,
|
texts,
|
||||||
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -635,6 +634,7 @@ if __name__ == "__main__":
|
||||||
* **Cohere / vLLM**: `cohere_rerank`
|
* **Cohere / vLLM**: `cohere_rerank`
|
||||||
* **Jina AI**: `jina_rerank`
|
* **Jina AI**: `jina_rerank`
|
||||||
* **Aliyun阿里云**: `ali_rerank`
|
* **Aliyun阿里云**: `ali_rerank`
|
||||||
|
* **Sentence Transformers**: `sentence_transformers_rerank`
|
||||||
|
|
||||||
您可以将这些函数之一注入到LightRAG对象的`rerank_model_func`属性中。这将使LightRAG的查询功能能够使用注入的函数对检索到的文本块进行重新排序。有关详细用法,请参阅`examples/rerank_example.py`文件。
|
您可以将这些函数之一注入到LightRAG对象的`rerank_model_func`属性中。这将使LightRAG的查询功能能够使用注入的函数对检索到的文本块进行重新排序。有关详细用法,请参阅`examples/rerank_example.py`文件。
|
||||||
|
|
||||||
|
|
|
||||||
10
README.md
10
README.md
|
|
@ -449,7 +449,7 @@ async def initialize_rag():
|
||||||
|
|
||||||
* If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
* If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
||||||
|
|
||||||
See `lightrag_hf_demo.py`
|
See `lightrag_hf_demo.py` & `lightrag_sentence_transformers_demo.py` for complete examples.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Initialize LightRAG with Hugging Face model
|
# Initialize LightRAG with Hugging Face model
|
||||||
|
|
@ -457,13 +457,12 @@ rag = LightRAG(
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
llm_model_func=hf_model_complete, # Use Hugging Face model for text generation
|
llm_model_func=hf_model_complete, # Use Hugging Face model for text generation
|
||||||
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face
|
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face
|
||||||
# Use Hugging Face embedding function
|
# Use Hugging Face Sentence Transformers embedding function
|
||||||
embedding_func=EmbeddingFunc(
|
embedding_func=EmbeddingFunc(
|
||||||
embedding_dim=384,
|
embedding_dim=384,
|
||||||
func=lambda texts: hf_embed(
|
func=lambda texts: sentence_transformers_embed(
|
||||||
texts,
|
texts,
|
||||||
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -633,6 +632,7 @@ To enhance retrieval quality, documents can be re-ranked based on a more effecti
|
||||||
* **Cohere / vLLM**: `cohere_rerank`
|
* **Cohere / vLLM**: `cohere_rerank`
|
||||||
* **Jina AI**: `jina_rerank`
|
* **Jina AI**: `jina_rerank`
|
||||||
* **Aliyun**: `ali_rerank`
|
* **Aliyun**: `ali_rerank`
|
||||||
|
* **Sentence Transformers**: `sentence_transformers_rerank`
|
||||||
|
|
||||||
You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file.
|
You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.llm.hf import hf_model_complete
|
||||||
|
from lightrag.llm.sentence_transformers import sentence_transformers_embed
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import nest_asyncio
|
||||||
|
|
||||||
|
nest_asyncio.apply()
|
||||||
|
|
||||||
|
# Directory where LightRAG keeps its working files for this demo.
WORKING_DIR = "./dickens"

# Create the directory in one step instead of "check, then mkdir": this is
# free of the race between the check and the creation, and it fails loudly
# if the path already exists but is not a directory.
os.makedirs(WORKING_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_rag():
    """Build a LightRAG instance backed by a Sentence Transformers embedder.

    Returns:
        The LightRAG object after its storages have been initialized.
    """
    # Load the embedding model exactly once. Constructing SentenceTransformer
    # inside the lambda would reload the model on every embedding call.
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=hf_model_complete,
        llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
        embedding_func=EmbeddingFunc(
            embedding_dim=384,
            max_token_size=512,
            func=lambda texts: sentence_transformers_embed(
                texts,
                model=embed_model,
            ),
        ),
    )

    await rag.initialize_storages()  # Auto-initializes pipeline_status
    return rag
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Index ./book.txt and print the answer to one question in every query mode."""
    rag = asyncio.run(initialize_rag())

    with open("./book.txt", "r", encoding="utf-8") as book:
        rag.insert(book.read())

    question = "What are the top themes in this story?"

    # Run the same question through naive, local, global and hybrid search,
    # in that order, printing each answer as it arrives.
    for mode in ("naive", "local", "global", "hybrid"):
        print(rag.query(question, param=QueryParam(mode=mode)))


if __name__ == "__main__":
    main()
|
||||||
|
|
@ -781,6 +781,17 @@ def create_app(args):
|
||||||
base_url=host,
|
base_url=host,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
|
elif binding == "sentence_transformers":
|
||||||
|
from lightrag.llm.sentence_transformers import (
|
||||||
|
sentence_transformers_embed,
|
||||||
|
)
|
||||||
|
|
||||||
|
actual_func = (
|
||||||
|
sentence_transformers_embed.func
|
||||||
|
if isinstance(sentence_transformers_embed, EmbeddingFunc)
|
||||||
|
else sentence_transformers_embed
|
||||||
|
)
|
||||||
|
return await actual_func(texts, embedding_dim=embedding_dim)
|
||||||
elif binding == "gemini":
|
elif binding == "gemini":
|
||||||
from lightrag.llm.gemini import gemini_embed
|
from lightrag.llm.gemini import gemini_embed
|
||||||
|
|
||||||
|
|
@ -932,13 +943,19 @@ def create_app(args):
|
||||||
# Configure rerank function based on the args.rerank_binding parameter
|
# Configure rerank function based on the args.rerank_binding parameter
|
||||||
rerank_model_func = None
|
rerank_model_func = None
|
||||||
if args.rerank_binding != "null":
|
if args.rerank_binding != "null":
|
||||||
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
|
from lightrag.rerank import (
|
||||||
|
cohere_rerank,
|
||||||
|
jina_rerank,
|
||||||
|
ali_rerank,
|
||||||
|
sentence_transformers_rerank,
|
||||||
|
)
|
||||||
|
|
||||||
# Map rerank binding to corresponding function
|
# Map rerank binding to corresponding function
|
||||||
rerank_functions = {
|
rerank_functions = {
|
||||||
"cohere": cohere_rerank,
|
"cohere": cohere_rerank,
|
||||||
"jina": jina_rerank,
|
"jina": jina_rerank,
|
||||||
"aliyun": ali_rerank,
|
"aliyun": ali_rerank,
|
||||||
|
"sentence_transformers": sentence_transformers_rerank,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Select the appropriate rerank function based on binding
|
# Select the appropriate rerank function based on binding
|
||||||
|
|
|
||||||
26
lightrag/llm/sentence_transformers.py
Normal file
26
lightrag/llm/sentence_transformers.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
import pipmaster as pm # Pipmaster for dynamic library install
|
||||||
|
|
||||||
|
if not pm.is_installed("sentence_transformers"):
|
||||||
|
pm.install("sentence_transformers")
|
||||||
|
if not pm.is_installed("numpy"):
|
||||||
|
pm.install("numpy")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
|
||||||
|
async def sentence_transformers_embed(
    texts: list[str],
    model: SentenceTransformer,
    embedding_dim: "int | None" = None,
) -> np.ndarray:
    """Embed texts with a Sentence Transformers model.

    Args:
        texts: Strings to embed.
        model: An already-loaded SentenceTransformer. Load it once and reuse
            it across calls; this function does not cache it.
        embedding_dim: Optional size to truncate the embeddings to. When
            omitted, the dimension reported by the model is used. A value
            larger than the model's dimension is clamped to it, because
            truncation cannot add dimensions.

    Returns:
        The embeddings produced by ``model.encode`` as a NumPy array, passed
        through ``EmbeddingFunc``.
    """
    native_dim = model.get_sentence_embedding_dimension()
    if embedding_dim is None:
        target_dim = native_dim
    elif native_dim is None:
        target_dim = embedding_dim
    else:
        target_dim = min(embedding_dim, native_dim)

    # The default is bound to the resolved size so the dimension declared to
    # EmbeddingFunc below and the one actually produced always agree. A
    # hard-coded default (previously 1024) silently truncated wider models
    # while EmbeddingFunc was still told the untruncated size.
    async def inner_encode(
        texts: list[str],
        model: SentenceTransformer,
        embedding_dim: "int | None" = target_dim,
    ):
        # NOTE: model.encode is a synchronous call; it runs inline here.
        return model.encode(
            texts,
            truncate_dim=embedding_dim,
            convert_to_numpy=True,
            convert_to_tensor=False,
            show_progress_bar=False,
        )

    embedding_func = EmbeddingFunc(
        embedding_dim=target_dim,
        func=inner_encode,
        max_token_size=model.get_max_seq_length(),
    )
    return await embedding_func(texts, model=model)
|
||||||
|
|
@ -290,6 +290,40 @@ async def ali_rerank(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Loaded CrossEncoder instances keyed by model name, so the model is loaded
# once per process instead of once per rerank call.
_CROSS_ENCODER_CACHE: Dict[str, Any] = {}


def _get_cross_encoder(model_name: str) -> Any:
    """Return the cached CrossEncoder for ``model_name``, loading it on first use."""
    encoder = _CROSS_ENCODER_CACHE.get(model_name)
    if encoder is None:
        # Imported lazily so this module can be imported without
        # sentence-transformers installed.
        from sentence_transformers import CrossEncoder

        encoder = CrossEncoder(model_name)
        _CROSS_ENCODER_CACHE[model_name] = encoder
    return encoder


async def sentence_transformers_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "BAAI/bge-reranker-v2-m3",
    base_url: Optional[str] = None,
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using CrossEncoder from Sentence Transformers.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return (passed to the encoder as ``top_k``)
        api_key: Unused; accepted for signature parity with the other rerankers
        model: Rerank model name
        base_url: Unused; accepted for signature parity with the other rerankers
        extra_body: Unused; accepted for signature parity with the other rerankers

    Returns:
        List of dictionaries of the form
        ``{"index": int, "relevance_score": float}``, in the order returned
        by ``CrossEncoder.rank``. An empty list when ``documents`` is empty.
    """
    # Nothing to rank: skip loading the model altogether.
    if not documents:
        return []

    cross_encoder = _get_cross_encoder(model)
    # NOTE: rank() is a synchronous call; it runs inline here.
    rankings = cross_encoder.rank(query=query, documents=documents, top_k=top_n)
    return [
        {
            "index": int(result["corpus_id"]),
            # Convert to a plain float so the result can be JSON-serialized.
            "relevance_score": float(result["score"]),
        }
        for result in rankings
    ]
|
||||||
|
|
||||||
|
|
||||||
"""Please run this test as a module:
|
"""Please run this test as a module:
|
||||||
python -m lightrag.rerank
|
python -m lightrag.rerank
|
||||||
"""
|
"""
|
||||||
|
|
@ -351,4 +385,19 @@ if __name__ == "__main__":
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Aliyun Error: {e}")
|
print(f"Aliyun Error: {e}")
|
||||||
|
|
||||||
|
# Test Sentence Transformers rerank
|
||||||
|
try:
|
||||||
|
print("\n=== Sentence Transformers Rerank ===")
|
||||||
|
result = await sentence_transformers_rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
top_n=2,
|
||||||
|
)
|
||||||
|
print("Results:")
|
||||||
|
for item in result:
|
||||||
|
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
|
||||||
|
print(f"Document: {docs[item['index']]}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Sentence Transformers Error: {e}")
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue