Merge 23e7ffbe1c into 1b02684e2f
This commit is contained in:
commit
94a50e6afc
6 changed files with 183 additions and 10 deletions
|
|
@ -456,7 +456,7 @@ async def initialize_rag():
|
|||
|
||||
* 如果您想使用Hugging Face模型,只需要按如下方式设置LightRAG:
|
||||
|
||||
参见`lightrag_hf_demo.py`
|
||||
参见`lightrag_hf_demo.py`, `lightrag_sentence_transformers_demo.py`等示例代码。
|
||||
|
||||
```python
|
||||
# 使用Hugging Face模型初始化LightRAG
|
||||
|
|
@ -467,10 +467,9 @@ rag = LightRAG(
|
|||
# 使用Hugging Face嵌入函数
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=384,
|
||||
func=lambda texts: hf_embed(
|
||||
func=lambda texts: sentence_transformers_embed(
|
||||
texts,
|
||||
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
||||
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
||||
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -644,6 +643,7 @@ if __name__ == "__main__":
|
|||
* **Cohere / vLLM**: `cohere_rerank`
|
||||
* **Jina AI**: `jina_rerank`
|
||||
* **Aliyun阿里云**: `ali_rerank`
|
||||
* **Sentence Transformers**: `sentence_transformers_rerank`
|
||||
|
||||
您可以将这些函数之一注入到LightRAG对象的`rerank_model_func`属性中。这将使LightRAG的查询功能能够使用注入的函数对检索到的文本块进行重新排序。有关详细用法,请参阅`examples/rerank_example.py`文件。
|
||||
|
||||
|
|
|
|||
10
README.md
10
README.md
|
|
@ -452,7 +452,7 @@ async def initialize_rag():
|
|||
|
||||
* If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
||||
|
||||
See `lightrag_hf_demo.py`
|
||||
See `lightrag_hf_demo.py` & `lightrag_sentence_transformers_demo.py` for complete examples.
|
||||
|
||||
```python
|
||||
# Initialize LightRAG with Hugging Face model
|
||||
|
|
@ -460,13 +460,12 @@ rag = LightRAG(
|
|||
working_dir=WORKING_DIR,
|
||||
llm_model_func=hf_model_complete, # Use Hugging Face model for text generation
|
||||
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face
|
||||
# Use Hugging Face embedding function
|
||||
# Use Hugging Face Sentence Transformers embedding function
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=384,
|
||||
func=lambda texts: hf_embed(
|
||||
func=lambda texts: sentence_transformers_embed(
|
||||
texts,
|
||||
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
||||
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
||||
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -642,6 +641,7 @@ To enhance retrieval quality, documents can be re-ranked based on a more effecti
|
|||
* **Cohere / vLLM**: `cohere_rerank`
|
||||
* **Jina AI**: `jina_rerank`
|
||||
* **Aliyun**: `ali_rerank`
|
||||
* **Sentence Transformers**: `sentence_transformers_rerank`
|
||||
|
||||
You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
import os
|
||||
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm.hf import hf_model_complete
|
||||
from lightrag.llm.sentence_transformers import sentence_transformers_embed
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=hf_model_complete,
|
||||
llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=384,
|
||||
max_token_size=512,
|
||||
func=lambda texts: sentence_transformers_embed(
|
||||
texts,
|
||||
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages() # Auto-initializes pipeline_status
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
rag = asyncio.run(initialize_rag())
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="global")
|
||||
)
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -784,6 +784,17 @@ def create_app(args):
|
|||
base_url=host,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif binding == "sentence_transformers":
|
||||
from lightrag.llm.sentence_transformers import (
|
||||
sentence_transformers_embed,
|
||||
)
|
||||
|
||||
actual_func = (
|
||||
sentence_transformers_embed.func
|
||||
if isinstance(sentence_transformers_embed, EmbeddingFunc)
|
||||
else sentence_transformers_embed
|
||||
)
|
||||
return await actual_func(texts, embedding_dim=embedding_dim)
|
||||
elif binding == "gemini":
|
||||
from lightrag.llm.gemini import gemini_embed
|
||||
|
||||
|
|
@ -935,13 +946,19 @@ def create_app(args):
|
|||
# Configure rerank function based on args.rerank_bindingparameter
|
||||
rerank_model_func = None
|
||||
if args.rerank_binding != "null":
|
||||
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
|
||||
from lightrag.rerank import (
|
||||
cohere_rerank,
|
||||
jina_rerank,
|
||||
ali_rerank,
|
||||
sentence_transformers_rerank,
|
||||
)
|
||||
|
||||
# Map rerank binding to corresponding function
|
||||
rerank_functions = {
|
||||
"cohere": cohere_rerank,
|
||||
"jina": jina_rerank,
|
||||
"aliyun": ali_rerank,
|
||||
"sentence_transformers": sentence_transformers_rerank,
|
||||
}
|
||||
|
||||
# Select the appropriate rerank function based on binding
|
||||
|
|
|
|||
32
lightrag/llm/sentence_transformers.py
Normal file
32
lightrag/llm/sentence_transformers.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
if not pm.is_installed("sentence_transformers"):
|
||||
pm.install("sentence_transformers")
|
||||
if not pm.is_installed("numpy"):
|
||||
pm.install("numpy")
|
||||
|
||||
import numpy as np
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
async def sentence_transformers_embed(
|
||||
texts: list[str], model: SentenceTransformer
|
||||
) -> np.ndarray:
|
||||
async def inner_encode(
|
||||
texts: list[str], model: SentenceTransformer, embedding_dim: int = 1024
|
||||
):
|
||||
return model.encode(
|
||||
texts,
|
||||
truncate_dim=embedding_dim,
|
||||
convert_to_numpy=True,
|
||||
convert_to_tensor=False,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
embedding_func = EmbeddingFunc(
|
||||
embedding_dim=model.get_sentence_embedding_dimension(),
|
||||
func=inner_encode,
|
||||
max_token_size=model.get_max_seq_length(),
|
||||
)
|
||||
return await embedding_func(texts, model=model)
|
||||
|
|
@ -290,6 +290,40 @@ async def ali_rerank(
|
|||
)
|
||||
|
||||
|
||||
async def sentence_transformers_rerank(
|
||||
query: str,
|
||||
documents: List[str],
|
||||
top_n: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "BAAI/bge-reranker-v2-m3",
|
||||
base_url: Optional[str] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Rerank documents using CrossEncoder from Sentence Transformers.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
documents: List of strings to rerank
|
||||
top_n: Number of top results to return
|
||||
api_key: Unused
|
||||
model: rerank model name
|
||||
base_url: Unused
|
||||
extra_body: Unused
|
||||
|
||||
Returns:
|
||||
List of dictionary of ["index": int, "relevance_score": float]
|
||||
"""
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
cross_encoder = CrossEncoder(model)
|
||||
rankings = cross_encoder.rank(query=query, documents=documents, top_k=top_n)
|
||||
return [
|
||||
{"index": result["corpus_id"], "relevance_score": result["score"]}
|
||||
for result in rankings
|
||||
]
|
||||
|
||||
|
||||
"""Please run this test as a module:
|
||||
python -m lightrag.rerank
|
||||
"""
|
||||
|
|
@ -351,4 +385,19 @@ if __name__ == "__main__":
|
|||
except Exception as e:
|
||||
print(f"Aliyun Error: {e}")
|
||||
|
||||
# Test Sentence Transformers rerank
|
||||
try:
|
||||
print("\n=== Sentence Transformers Rerank ===")
|
||||
result = await sentence_transformers_rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
top_n=2,
|
||||
)
|
||||
print("Results:")
|
||||
for item in result:
|
||||
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
|
||||
print(f"Document: {docs[item['index']]}")
|
||||
except Exception as e:
|
||||
print(f"Sentence Transformers Error: {e}")
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue