This commit is contained in:
hzywhite 2025-09-04 10:27:38 +08:00
parent bd533783e1
commit e27031587d
34 changed files with 2906 additions and 1768 deletions


@@ -357,7 +357,7 @@ The API server can be configured in three ways (priority from high to low):
LightRAG supports binding to various LLM/Embedding backends:
* ollama
-* openai & openai compatible
+* openai (including openai compatible)
* azure_openai
* lollms
* aws_bedrock
@@ -372,7 +372,10 @@ lightrag-server --llm-binding ollama --help
lightrag-server --embedding-binding ollama --help
```
+> Please use the OpenAI-compatible method to access LLMs deployed via OpenRouter or vLLM. Additional parameters can be passed to OpenRouter or vLLM through the `OPENAI_LLM_EXTRA_BODY` environment variable, e.g. to disable reasoning mode or apply other custom controls.
### Entity Extraction Configuration
* ENABLE_LLM_CACHE_FOR_EXTRACT: Enable LLM cache for entity extraction (default: true)
Setting `ENABLE_LLM_CACHE_FOR_EXTRACT` to true in test environments is a common practice to reduce LLM call costs.
@@ -386,51 +389,9 @@ LightRAG uses 4 types of storage for different purposes:
* GRAPH_STORAGE: entity relation graph
* DOC_STATUS_STORAGE: document indexing status
-Each storage type has several implementations:
-* KV_STORAGE supported implementations:
-```
-JsonKVStorage          JsonFile (default)
-PGKVStorage            Postgres
-RedisKVStorage         Redis
-MongoKVStorage         MongoDB
-```
-* GRAPH_STORAGE supported implementations:
-```
-NetworkXStorage        NetworkX (default)
-Neo4JStorage           Neo4J
-PGGraphStorage         PostgreSQL with AGE plugin
-```
-> In testing, the Neo4j graph database showed better performance than PostgreSQL AGE.
-* VECTOR_STORAGE supported implementations:
-```
-NanoVectorDBStorage    NanoVector (default)
-PGVectorStorage        Postgres
-MilvusVectorDBStorage  Milvus
-FaissVectorDBStorage   Faiss
-QdrantVectorDBStorage  Qdrant
-MongoVectorDBStorage   MongoDB
-```
-* DOC_STATUS_STORAGE supported implementations:
-```
-JsonDocStatusStorage   JsonFile (default)
-PGDocStatusStorage     Postgres
-MongoDocStatusStorage  MongoDB
-```
-Example connection configurations for each storage type can be found in the `env.example` file. The database instance in the connection string must be created in advance on the database server; LightRAG only creates the tables within the database instance, not the instance itself. If using Redis as storage, remember to configure automatic persistence rules for Redis, otherwise data will be lost after the Redis service restarts. If using PostgreSQL, version 16.6 or above is recommended.
-### How to Select Storage Implementation
-You can select the storage implementation via environment variables. Before the first start of the API server, set the following environment variables to specific storage implementation names:
+Each storage type has multiple implementations. The LightRAG Server defaults to in-memory databases whose data is persisted to files in the WORKING_DIR directory. LightRAG also supports PostgreSQL, MongoDB, FAISS, Milvus, Qdrant, Neo4j, Memgraph, and Redis as storage implementations. For details on supported storage options, see the storage section of the `README.md` file in the root directory.
+You can select the storage implementation via environment variables. For example, before the first start of the API server, set the following environment variables to specific storage implementation names:
```
LIGHTRAG_KV_STORAGE=PGKVStorage
@@ -439,7 +400,7 @@ LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
```
-You cannot change the storage implementation selection after adding documents to LightRAG. Migration from one storage implementation to another is not yet supported. For more information, read the sample env file or the config.ini file.
+You cannot change the storage implementation selection after adding documents to LightRAG. Migration from one storage implementation to another is not yet supported. For more configuration details, read the sample `env.example` file.
### LightRAG API Server Command Line Options
@@ -450,20 +411,54 @@ LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
| --working-dir | ./rag_storage | Working directory for RAG storage |
| --input-dir | ./inputs | Directory containing input documents |
| --max-async | 4 | Maximum number of async operations |
-| --max-tokens | 32768 | Maximum token size |
-| --timeout | 150 | Timeout in seconds. None for infinite timeout (not recommended) |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --verbose | - | Verbose debug output (True, False) |
| --key | None | API key for authentication; protects the LightRAG server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
-| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and to relations in "global" mode |
-| --cosine-threshold | 0.4 | Cosine threshold for node and relation retrieval; works with top-k to control the retrieval of nodes and relations |
-| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
-| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
+| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai, aws_bedrock) |
+| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai, aws_bedrock) |
| --auto-scan-at-startup | - | Scan input directory for new files and start indexing |
+### Reranking Configuration
+Reranking the chunks recalled by a query can significantly improve retrieval quality by re-ordering documents with an optimized relevance scoring model. LightRAG currently supports the following rerank providers:
+- **Cohere / vLLM**: Offers full API integration with Cohere AI's `v2/rerank` endpoint. Since vLLM provides a Cohere-compatible reranker API, all reranker models deployed via vLLM are supported as well.
+- **Jina AI**: Provides full implementation compatibility with all Jina rerank models.
+- **Aliyun**: Features a custom implementation designed to support Aliyun's rerank API format.
+The rerank provider is configured via the `.env` file. Below is an example configuration for a rerank model deployed locally with vLLM:
+```
+RERANK_BINDING=cohere
+RERANK_MODEL=BAAI/bge-reranker-v2-m3
+RERANK_BINDING_HOST=http://localhost:8000/v1/rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+Here is an example configuration for using the Reranker service provided by Aliyun:
+```
+RERANK_BINDING=aliyun
+RERANK_MODEL=gte-rerank-v2
+RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+For complete reranker configuration examples, refer to the `env.example` file.
+### Enable Reranking
+Reranking can be enabled or disabled on a per-query basis.
+The `/query` and `/query/stream` API endpoints include an `enable_rerank` parameter, set to `true` by default, which controls whether reranking is active for the current query. To change the default value of `enable_rerank` to `false`, set the following environment variable:
+```
+RERANK_BY_DEFAULT=False
+```
### .env File Example
```bash
@@ -478,7 +473,7 @@ SUMMARY_LANGUAGE=Chinese
MAX_PARALLEL_INSERT=2
### LLM Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
-TIMEOUT=200
+TIMEOUT=150
MAX_ASYNC=4
LLM_BINDING=openai


@@ -360,7 +360,7 @@ Most of the configurations come with default settings; check out the details in
LightRAG supports binding to various LLM/Embedding backends:
* ollama
-* openai & openai compatible
+* openai (including openai compatible)
* azure_openai
* lollms
* aws_bedrock
@@ -374,6 +374,8 @@ lightrag-server --llm-binding ollama --help
lightrag-server --embedding-binding ollama --help
```
+> Please use the OpenAI-compatible method to access LLMs deployed by OpenRouter or vLLM. You can pass additional parameters to OpenRouter or vLLM through the `OPENAI_LLM_EXTRA_BODY` environment variable to disable reasoning mode or apply other custom controls.
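As an illustrative sketch of such an override (the JSON keys here are backend-specific assumptions, not LightRAG-defined settings; `chat_template_kwargs.enable_thinking` is the switch used by some vLLM-served reasoning models), the `.env` entry might look like:

```
OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'
```

The value must be a JSON object; it is merged into the request body of each chat-completion call, so any field the target backend accepts can be passed this way.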
### Entity Extraction Configuration
* ENABLE_LLM_CACHE_FOR_EXTRACT: Enable LLM cache for entity extraction (default: true)
@@ -388,52 +390,9 @@ LightRAG uses 4 types of storage for different purposes:
* GRAPH_STORAGE: entity relation graph
* DOC_STATUS_STORAGE: document indexing status
-Each storage type has several implementations:
-* KV_STORAGE supported implementations:
-```
-JsonKVStorage          JsonFile (default)
-PGKVStorage            Postgres
-RedisKVStorage         Redis
-MongoKVStorage         MongoDB
-```
-* GRAPH_STORAGE supported implementations:
-```
-NetworkXStorage        NetworkX (default)
-Neo4JStorage           Neo4J
-PGGraphStorage         PostgreSQL with AGE plugin
-MemgraphStorage        Memgraph
-```
-> Testing has shown that Neo4J delivers superior performance in production environments compared to PostgreSQL with the AGE plugin.
-* VECTOR_STORAGE supported implementations:
-```
-NanoVectorDBStorage    NanoVector (default)
-PGVectorStorage        Postgres
-MilvusVectorDBStorage  Milvus
-FaissVectorDBStorage   Faiss
-QdrantVectorDBStorage  Qdrant
-MongoVectorDBStorage   MongoDB
-```
-* DOC_STATUS_STORAGE supported implementations:
-```
-JsonDocStatusStorage   JsonFile (default)
-PGDocStatusStorage     Postgres
-MongoDocStatusStorage  MongoDB
-```
-Example connection configurations for each storage type can be found in the `env.example` file. The database instance in the connection string needs to be created by you on the database server beforehand. LightRAG is only responsible for creating tables within the database instance, not for creating the database instance itself. If using Redis as storage, remember to configure automatic data persistence rules for Redis, otherwise data will be lost after the Redis service restarts. If using PostgreSQL, it is recommended to use version 16.6 or above.
-### How to Select Storage Implementation
-You can select storage implementation by environment variables. You can set the following environment variables to a specific storage implementation name before the first start of the API Server:
+LightRAG Server offers various storage implementations, with the default being an in-memory database that persists data to the WORKING_DIR directory. Additionally, LightRAG supports a wide range of storage solutions including PostgreSQL, MongoDB, FAISS, Milvus, Qdrant, Neo4j, Memgraph, and Redis. For detailed information on supported storage options, please refer to the storage section in the `README.md` file located in the root directory.
+You can select the storage implementation by configuring environment variables. For instance, prior to the initial launch of the API server, you can set the following environment variable to specify your desired storage implementation:
```
LIGHTRAG_KV_STORAGE=PGKVStorage
@@ -453,23 +412,53 @@ You cannot change storage implementation selection after adding documents to LightRAG
| --working-dir | ./rag_storage | Working directory for RAG storage |
| --input-dir | ./inputs | Directory containing input documents |
| --max-async | 4 | Maximum number of async operations |
-| --max-tokens | 32768 | Maximum token size |
-| --timeout | 150 | Timeout in seconds. None for infinite timeout (not recommended) |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --verbose | - | Verbose debug output (True, False) |
| --key | None | API key for authentication. Protects the LightRAG server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
-| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
-| --cosine-threshold | 0.4 | The cosine threshold for nodes and relation retrieval, works with top-k to control the retrieval of nodes and relations. |
| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai, aws_bedrock) |
| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai, aws_bedrock) |
| --auto-scan-at-startup | - | Scan input directory for new files and start indexing |
-### Additional Ollama Binding Options
-When using `--llm-binding ollama` or `--embedding-binding ollama`, additional Ollama-specific configuration options are available. To see all available Ollama binding options, add `--help` to the command line when starting the server. These additional options allow for fine-tuning of Ollama model parameters and connection settings.
+### Reranking Configuration
+Reranking query-recalled chunks can significantly enhance retrieval quality by re-ordering documents based on an optimized relevance scoring model. LightRAG currently supports the following rerank providers:
+- **Cohere / vLLM**: Offers full API integration with Cohere AI's `v2/rerank` endpoint. As vLLM provides a Cohere-compatible reranker API, all reranker models deployed via vLLM are also supported.
+- **Jina AI**: Provides complete implementation compatibility with all Jina rerank models.
+- **Aliyun**: Features a custom implementation designed to support Aliyun's rerank API format.
+The rerank provider is configured via the `.env` file. Below is an example configuration for a rerank model deployed locally using vLLM:
+```
+RERANK_BINDING=cohere
+RERANK_MODEL=BAAI/bge-reranker-v2-m3
+RERANK_BINDING_HOST=http://localhost:8000/v1/rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+Here is an example configuration for utilizing the Reranker service provided by Aliyun:
+```
+RERANK_BINDING=aliyun
+RERANK_MODEL=gte-rerank-v2
+RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+For comprehensive reranker configuration examples, please refer to the `env.example` file.
+### Enable Reranking
+Reranking can be enabled or disabled on a per-query basis.
+The `/query` and `/query/stream` API endpoints include an `enable_rerank` parameter, which is set to `true` by default, controlling whether reranking is active for the current query. To change the default value of the `enable_rerank` parameter to `false`, set the following environment variable:
+```
+RERANK_BY_DEFAULT=False
+```
### .env Examples
@@ -485,7 +474,7 @@ SUMMARY_LANGUAGE=Chinese
MAX_PARALLEL_INSERT=2
### LLM Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
-TIMEOUT=200
+TIMEOUT=150
MAX_ASYNC=4
LLM_BINDING=openai


@@ -1 +1 @@
-__api_version__ = "0205"
+__api_version__ = "0213"


@@ -30,12 +30,15 @@ from lightrag.constants import (
    DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
    DEFAULT_MAX_ASYNC,
    DEFAULT_SUMMARY_MAX_TOKENS,
+    DEFAULT_SUMMARY_LENGTH_RECOMMENDED,
+    DEFAULT_SUMMARY_CONTEXT_SIZE,
    DEFAULT_SUMMARY_LANGUAGE,
    DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
    DEFAULT_EMBEDDING_BATCH_NUM,
    DEFAULT_OLLAMA_MODEL_NAME,
    DEFAULT_OLLAMA_MODEL_TAG,
-    DEFAULT_TEMPERATURE,
+    DEFAULT_RERANK_BINDING,
+    DEFAULT_ENTITY_TYPES,
)
# use the .env that is inside the current folder
@@ -77,9 +80,7 @@ def parse_args() -> argparse.Namespace:
        argparse.Namespace: Parsed arguments
    """
-    parser = argparse.ArgumentParser(
-        description="LightRAG FastAPI Server with separate working and input directories"
-    )
+    parser = argparse.ArgumentParser(description="LightRAG API Server")
    # Server configuration
    parser.add_argument(
@@ -121,10 +122,26 @@ def parse_args() -> argparse.Namespace:
        help=f"Maximum async operations (default: from env or {DEFAULT_MAX_ASYNC})",
    )
    parser.add_argument(
-        "--max-tokens",
+        "--summary-max-tokens",
        type=int,
-        default=get_env_value("MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS, int),
-        help=f"Maximum token size (default: from env or {DEFAULT_SUMMARY_MAX_TOKENS})",
+        default=get_env_value("SUMMARY_MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS, int),
+        help=f"Maximum token size for entity/relation summary (default: from env or {DEFAULT_SUMMARY_MAX_TOKENS})",
+    )
+    parser.add_argument(
+        "--summary-context-size",
+        type=int,
+        default=get_env_value(
+            "SUMMARY_CONTEXT_SIZE", DEFAULT_SUMMARY_CONTEXT_SIZE, int
+        ),
+        help=f"LLM summary context size (default: from env or {DEFAULT_SUMMARY_CONTEXT_SIZE})",
+    )
+    parser.add_argument(
+        "--summary-length-recommended",
+        type=int,
+        default=get_env_value(
+            "SUMMARY_LENGTH_RECOMMENDED", DEFAULT_SUMMARY_LENGTH_RECOMMENDED, int
+        ),
+        help=f"Recommended LLM summary length (default: from env or {DEFAULT_SUMMARY_LENGTH_RECOMMENDED})",
    )
    # Logging configuration
@@ -226,6 +243,13 @@ def parse_args() -> argparse.Namespace:
        choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
        help="Embedding binding type (default: from env or ollama)",
    )
+    parser.add_argument(
+        "--rerank-binding",
+        type=str,
+        default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
+        choices=["null", "cohere", "jina", "aliyun"],
+        help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
+    )
    # Conditionally add binding options defined in binding_options module
    # This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
@@ -264,14 +288,6 @@ def parse_args() -> argparse.Namespace:
    elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
        OpenAILLMOptions.add_args(parser)
-    # Add global temperature command line argument
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=get_env_value("TEMPERATURE", DEFAULT_TEMPERATURE, float),
-        help="Global temperature setting for LLM (default: from env TEMPERATURE or 0.1)",
-    )
    args = parser.parse_args()
    # convert relative path to absolute path
@@ -330,38 +346,13 @@ def parse_args() -> argparse.Namespace:
    )
    args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
-    # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
-    if args.llm_binding == "ollama":
-        # Priority order (highest to lowest):
-        # 1. --ollama-llm-temperature command argument
-        # 2. OLLAMA_LLM_TEMPERATURE environment variable
-        # 3. --temperature command argument
-        # 4. TEMPERATURE environment variable
-        # Check if --ollama-llm-temperature was explicitly provided in command line
-        if "--ollama-llm-temperature" not in sys.argv:
-            # Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
-            args.ollama_llm_temperature = args.temperature
-    # Handle OpenAI LLM temperature with priority cascade when llm-binding is openai or azure_openai
-    if args.llm_binding in ["openai", "azure_openai"]:
-        # Priority order (highest to lowest):
-        # 1. --openai-llm-temperature command argument
-        # 2. OPENAI_LLM_TEMPERATURE environment variable
-        # 3. --temperature command argument
-        # 4. TEMPERATURE environment variable
-        # Check if --openai-llm-temperature was explicitly provided in command line
-        if "--openai-llm-temperature" not in sys.argv:
-            # Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
-            args.openai_llm_temperature = args.temperature
    # Select Document loading tool (DOCLING, DEFAULT)
    args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
    # Add environment variables that were previously read directly
    args.cors_origins = get_env_value("CORS_ORIGINS", "*")
    args.summary_language = get_env_value("SUMMARY_LANGUAGE", DEFAULT_SUMMARY_LANGUAGE)
+    args.entity_types = get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES, list)
    args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
    # For JWT Auth
@@ -372,9 +363,10 @@ def parse_args() -> argparse.Namespace:
    args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
    # Rerank model configuration
-    args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
+    args.rerank_model = get_env_value("RERANK_MODEL", None)
    args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
    args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
+    # Note: rerank_binding is already set by argparse, no need to override from env
    # Min rerank score configuration
    args.min_rerank_score = get_env_value(


@@ -2,7 +2,7 @@
LightRAG FastAPI Server
"""
-from fastapi import FastAPI, Depends, HTTPException, status
+from fastapi import FastAPI, Depends, HTTPException
import asyncio
import os
import logging
@@ -11,6 +11,7 @@ import signal
import sys
import uvicorn
import pipmaster as pm
+import inspect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
@@ -38,6 +39,8 @@ from lightrag.constants import (
    DEFAULT_LOG_MAX_BYTES,
    DEFAULT_LOG_BACKUP_COUNT,
    DEFAULT_LOG_FILENAME,
+    DEFAULT_LLM_TIMEOUT,
+    DEFAULT_EMBEDDING_TIMEOUT,
)
from lightrag.api.routers.document_routes import (
    DocumentManager,
@@ -236,25 +239,106 @@ def create_app(args):
    # Create working directory if it doesn't exist
    Path(args.working_dir).mkdir(parents=True, exist_ok=True)
-    if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
-        from lightrag.llm.lollms import lollms_model_complete, lollms_embed
-    if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
-        from lightrag.llm.ollama import ollama_model_complete, ollama_embed
-        from lightrag.llm.binding_options import OllamaLLMOptions
-    if args.llm_binding == "openai" or args.embedding_binding == "openai":
-        from lightrag.llm.openai import openai_complete_if_cache, openai_embed
-        from lightrag.llm.binding_options import OpenAILLMOptions
-    if args.llm_binding == "azure_openai" or args.embedding_binding == "azure_openai":
-        from lightrag.llm.azure_openai import (
-            azure_openai_complete_if_cache,
-            azure_openai_embed,
-        )
-    if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
-        from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
-    if args.embedding_binding == "ollama":
-        from lightrag.llm.binding_options import OllamaEmbeddingOptions
-    if args.embedding_binding == "jina":
-        from lightrag.llm.jina import jina_embed
+    def create_llm_model_func(binding: str):
+        """
+        Create LLM model function based on binding type.
+        Uses lazy import to avoid unnecessary dependencies.
+        """
+        try:
+            if binding == "lollms":
+                from lightrag.llm.lollms import lollms_model_complete
+
+                return lollms_model_complete
+            elif binding == "ollama":
+                from lightrag.llm.ollama import ollama_model_complete
+
+                return ollama_model_complete
+            elif binding == "aws_bedrock":
+                return bedrock_model_complete  # Already defined locally
+            elif binding == "azure_openai":
+                return azure_openai_model_complete  # Already defined locally
+            else:  # openai and compatible
+                return openai_alike_model_complete  # Already defined locally
+        except ImportError as e:
+            raise Exception(f"Failed to import {binding} LLM binding: {e}")
+
+    def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
+        """
+        Create LLM model kwargs based on binding type.
+        Uses lazy import for binding-specific options.
+        """
+        if binding in ["lollms", "ollama"]:
+            try:
+                from lightrag.llm.binding_options import OllamaLLMOptions
+
+                return {
+                    "host": args.llm_binding_host,
+                    "timeout": llm_timeout,
+                    "options": OllamaLLMOptions.options_dict(args),
+                    "api_key": args.llm_binding_api_key,
+                }
+            except ImportError as e:
+                raise Exception(f"Failed to import {binding} options: {e}")
+        return {}
+
+    def create_embedding_function_with_lazy_import(
+        binding, model, host, api_key, dimensions, args
+    ):
+        """
+        Create embedding function with lazy imports for all bindings.
+        Replaces the current create_embedding_function with full lazy import support.
+        """
+
+        async def embedding_function(texts):
+            try:
+                if binding == "lollms":
+                    from lightrag.llm.lollms import lollms_embed
+
+                    return await lollms_embed(
+                        texts, embed_model=model, host=host, api_key=api_key
+                    )
+                elif binding == "ollama":
+                    from lightrag.llm.binding_options import OllamaEmbeddingOptions
+                    from lightrag.llm.ollama import ollama_embed
+
+                    ollama_options = OllamaEmbeddingOptions.options_dict(args)
+                    return await ollama_embed(
+                        texts,
+                        embed_model=model,
+                        host=host,
+                        api_key=api_key,
+                        options=ollama_options,
+                    )
+                elif binding == "azure_openai":
+                    from lightrag.llm.azure_openai import azure_openai_embed
+
+                    return await azure_openai_embed(texts, model=model, api_key=api_key)
+                elif binding == "aws_bedrock":
+                    from lightrag.llm.bedrock import bedrock_embed
+
+                    return await bedrock_embed(texts, model=model)
+                elif binding == "jina":
+                    from lightrag.llm.jina import jina_embed
+
+                    return await jina_embed(
+                        texts, dimensions=dimensions, base_url=host, api_key=api_key
+                    )
+                else:  # openai and compatible
+                    from lightrag.llm.openai import openai_embed
+
+                    return await openai_embed(
+                        texts, model=model, base_url=host, api_key=api_key
+                    )
+            except ImportError as e:
+                raise Exception(f"Failed to import {binding} embedding: {e}")
+
+        return embedding_function
+
+    llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
+    embedding_timeout = get_env_value(
+        "EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
+    )
    async def openai_alike_model_complete(
        prompt,
@@ -263,18 +347,20 @@ def create_app(args):
        keyword_extraction=False,
        **kwargs,
    ) -> str:
+        # Lazy import
+        from lightrag.llm.openai import openai_complete_if_cache
+        from lightrag.llm.binding_options import OpenAILLMOptions
+
        keyword_extraction = kwargs.pop("keyword_extraction", None)
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []
-        # Use OpenAI LLM options if available, otherwise fallback to global temperature
-        if args.llm_binding == "openai":
-            openai_options = OpenAILLMOptions.options_dict(args)
-            kwargs.update(openai_options)
-        else:
-            kwargs["temperature"] = args.temperature
+        # Use OpenAI LLM options if available
+        openai_options = OpenAILLMOptions.options_dict(args)
+        kwargs["timeout"] = llm_timeout
+        kwargs.update(openai_options)
        return await openai_complete_if_cache(
            args.llm_model,
@@ -293,18 +379,20 @@ def create_app(args):
        keyword_extraction=False,
        **kwargs,
    ) -> str:
+        # Lazy import
+        from lightrag.llm.azure_openai import azure_openai_complete_if_cache
+        from lightrag.llm.binding_options import OpenAILLMOptions
+
        keyword_extraction = kwargs.pop("keyword_extraction", None)
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []
-        # Use OpenAI LLM options if available, otherwise fallback to global temperature
-        if args.llm_binding == "azure_openai":
-            openai_options = OpenAILLMOptions.options_dict(args)
-            kwargs.update(openai_options)
-        else:
-            kwargs["temperature"] = args.temperature
+        # Use OpenAI LLM options
+        openai_options = OpenAILLMOptions.options_dict(args)
+        kwargs["timeout"] = llm_timeout
+        kwargs.update(openai_options)
        return await azure_openai_complete_if_cache(
            args.llm_model,
@@ -324,6 +412,9 @@ def create_app(args):
        keyword_extraction=False,
        **kwargs,
    ) -> str:
+        # Lazy import
+        from lightrag.llm.bedrock import bedrock_complete_if_cache
+
        keyword_extraction = kwargs.pop("keyword_extraction", None)
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
@@ -331,7 +422,7 @@ def create_app(args):
            history_messages = []
        # Use global temperature for Bedrock
-        kwargs["temperature"] = args.temperature
+        kwargs["temperature"] = get_env_value("BEDROCK_LLM_TEMPERATURE", 1.0, float)
        return await bedrock_complete_if_cache(
            args.llm_model,
@@ -341,86 +432,73 @@ def create_app(args):
            **kwargs,
        )
+    # Create embedding function with lazy imports
    embedding_func = EmbeddingFunc(
        embedding_dim=args.embedding_dim,
-        func=lambda texts: (
-            lollms_embed(
-                texts,
-                embed_model=args.embedding_model,
-                host=args.embedding_binding_host,
-                api_key=args.embedding_binding_api_key,
-            )
-            if args.embedding_binding == "lollms"
-            else (
-                ollama_embed(
-                    texts,
-                    embed_model=args.embedding_model,
-                    host=args.embedding_binding_host,
-                    api_key=args.embedding_binding_api_key,
-                    options=OllamaEmbeddingOptions.options_dict(args),
-                )
-                if args.embedding_binding == "ollama"
-                else (
-                    azure_openai_embed(
-                        texts,
-                        model=args.embedding_model,  # no host is used for openai,
-                        api_key=args.embedding_binding_api_key,
-                    )
-                    if args.embedding_binding == "azure_openai"
-                    else (
-                        bedrock_embed(
-                            texts,
-                            model=args.embedding_model,
-                        )
-                        if args.embedding_binding == "aws_bedrock"
-                        else (
-                            jina_embed(
-                                texts,
-                                dimensions=args.embedding_dim,
-                                base_url=args.embedding_binding_host,
-                                api_key=args.embedding_binding_api_key,
-                            )
-                            if args.embedding_binding == "jina"
-                            else openai_embed(
-                                texts,
-                                model=args.embedding_model,
-                                base_url=args.embedding_binding_host,
-                                api_key=args.embedding_binding_api_key,
-                            )
-                        )
-                    )
-                )
-            )
-        ),
+        func=create_embedding_function_with_lazy_import(
+            binding=args.embedding_binding,
+            model=args.embedding_model,
+            host=args.embedding_binding_host,
+            api_key=args.embedding_binding_api_key,
+            dimensions=args.embedding_dim,
+            args=args,  # Pass args object for dynamic option generation
+        ),
    )
-    # Configure rerank function if model and API are configured
+    # Configure rerank function based on the args.rerank_binding parameter
    rerank_model_func = None
-    if args.rerank_binding_api_key and args.rerank_binding_host:
-        from lightrag.rerank import custom_rerank
+    if args.rerank_binding != "null":
+        from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
+
+        # Map rerank binding to corresponding function
+        rerank_functions = {
+            "cohere": cohere_rerank,
+            "jina": jina_rerank,
+            "aliyun": ali_rerank,
+        }
+
+        # Select the appropriate rerank function based on binding
+        selected_rerank_func = rerank_functions.get(args.rerank_binding)
+        if not selected_rerank_func:
+            logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
+            raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
+
+        # Get default values from selected_rerank_func if args values are None
+        if args.rerank_model is None or args.rerank_binding_host is None:
+            sig = inspect.signature(selected_rerank_func)
+
+            # Set default model if args.rerank_model is None
+            if args.rerank_model is None and "model" in sig.parameters:
+                default_model = sig.parameters["model"].default
+                if default_model != inspect.Parameter.empty:
+                    args.rerank_model = default_model
+
+            # Set default base_url if args.rerank_binding_host is None
+            if args.rerank_binding_host is None and "base_url" in sig.parameters:
+                default_base_url = sig.parameters["base_url"].default
+                if default_base_url != inspect.Parameter.empty:
+                    args.rerank_binding_host = default_base_url
async def server_rerank_func( async def server_rerank_func(
query: str, documents: list, top_n: int = None, **kwargs query: str, documents: list, top_n: int = None, extra_body: dict = None
): ):
"""Server rerank function with configuration from environment variables""" """Server rerank function with configuration from environment variables"""
return await custom_rerank( return await selected_rerank_func(
query=query, query=query,
documents=documents, documents=documents,
top_n=top_n,
api_key=args.rerank_binding_api_key,
model=args.rerank_model, model=args.rerank_model,
base_url=args.rerank_binding_host, base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key, extra_body=extra_body,
top_n=top_n,
**kwargs,
) )
rerank_model_func = server_rerank_func rerank_model_func = server_rerank_func
logger.info( logger.info(
f"Rerank model configured: {args.rerank_model} (can be enabled per query)" f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
) )
else: else:
logger.info( logger.info("Reranking is disabled")
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
)
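The fallback logic above reads default `model` and `base_url` values out of the selected rerank function's signature. The general `inspect`-based pattern, as a standalone sketch (the `demo_rerank` function below is a made-up stand-in):

```python
import inspect

def default_of(func, param: str):
    """Return a parameter's declared default, or None if the parameter is
    missing or has no default (i.e. is required)."""
    sig = inspect.signature(func)
    p = sig.parameters.get(param)
    if p is None or p.default is inspect.Parameter.empty:
        return None
    return p.default

# Stand-in rerank function with declared defaults (hypothetical values)
def demo_rerank(query, documents, model="rerank-v1", base_url="https://example.invalid/api"):
    ...

model = default_of(demo_rerank, "model")    # "rerank-v1"
missing = default_of(demo_rerank, "top_n")  # None: no such parameter
```

This keeps CLI/env values authoritative while letting each provider function carry its own sensible defaults.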
# Create ollama_server_infos from command line arguments # Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos from lightrag.api.config import OllamaServerInfos
@ -429,38 +507,24 @@ def create_app(args):
name=args.simulated_model_name, tag=args.simulated_model_tag name=args.simulated_model_name, tag=args.simulated_model_tag
) )
# Initialize RAG # Initialize RAG with unified configuration
if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]: try:
rag = LightRAG( rag = LightRAG(
working_dir=args.working_dir, working_dir=args.working_dir,
workspace=args.workspace, workspace=args.workspace,
llm_model_func=( llm_model_func=create_llm_model_func(args.llm_binding),
lollms_model_complete
if args.llm_binding == "lollms"
else (
ollama_model_complete
if args.llm_binding == "ollama"
else bedrock_model_complete
if args.llm_binding == "aws_bedrock"
else openai_alike_model_complete
)
),
llm_model_name=args.llm_model, llm_model_name=args.llm_model,
llm_model_max_async=args.max_async, llm_model_max_async=args.max_async,
summary_max_tokens=args.max_tokens, summary_max_tokens=args.summary_max_tokens,
summary_context_size=args.summary_context_size,
chunk_token_size=int(args.chunk_size), chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size), chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs=( llm_model_kwargs=create_llm_model_kwargs(
{ args.llm_binding, args, llm_timeout
"host": args.llm_binding_host,
"timeout": args.timeout,
"options": OllamaLLMOptions.options_dict(args),
"api_key": args.llm_binding_api_key,
}
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
else {}
), ),
embedding_func=embedding_func, embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage, kv_storage=args.kv_storage,
graph_storage=args.graph_storage, graph_storage=args.graph_storage,
vector_storage=args.vector_storage, vector_storage=args.vector_storage,
@ -473,36 +537,10 @@ def create_app(args):
rerank_model_func=rerank_model_func, rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes, max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language}, addon_params={
ollama_server_infos=ollama_server_infos, "language": args.summary_language,
) "entity_types": args.entity_types,
else: # azure_openai
rag = LightRAG(
working_dir=args.working_dir,
workspace=args.workspace,
llm_model_func=azure_openai_model_complete,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs={
"timeout": args.timeout,
}, },
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
summary_max_tokens=args.max_tokens,
embedding_func=embedding_func,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={"language": args.summary_language},
ollama_server_infos=ollama_server_infos, ollama_server_infos=ollama_server_infos,
) )
@ -709,9 +747,7 @@ def create_app(args):
} }
username = form_data.username username = form_data.username
if auth_handler.accounts.get(username) != form_data.password: if auth_handler.accounts.get(username) != form_data.password:
raise HTTPException( raise HTTPException(status_code=401, detail="Incorrect credentials")
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
)
# Regular user login # Regular user login
user_token = auth_handler.create_token( user_token = auth_handler.create_token(
@ -754,7 +790,8 @@ def create_app(args):
"embedding_binding": args.embedding_binding, "embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host, "embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model, "embedding_model": args.embedding_model,
"max_tokens": args.max_tokens, "summary_max_tokens": args.summary_max_tokens,
"summary_context_size": args.summary_context_size,
"kv_storage": args.kv_storage, "kv_storage": args.kv_storage,
"doc_status_storage": args.doc_status_storage, "doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage, "graph_storage": args.graph_storage,
@ -763,13 +800,12 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache, "enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace, "workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes, "max_graph_nodes": args.max_graph_nodes,
# Rerank configuration (based on whether rerank model is configured) # Rerank configuration
"enable_rerank": rerank_model_func is not None, "enable_rerank": rerank_model_func is not None,
"rerank_model": args.rerank_model "rerank_binding": args.rerank_binding,
if rerank_model_func is not None "rerank_model": args.rerank_model if rerank_model_func else None,
else None,
"rerank_binding_host": args.rerank_binding_host "rerank_binding_host": args.rerank_binding_host
if rerank_model_func is not None if rerank_model_func
else None, else None,
# Environment variable status (requested configuration) # Environment variable status (requested configuration)
"summary_language": args.summary_language, "summary_language": args.summary_language,

View file

@ -66,6 +66,11 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict[str, List[str]]: Knowledge graph for label Dict[str, List[str]]: Knowledge graph for label
""" """
try: try:
# Log the label parameter to check for leading spaces
logger.debug(
f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})"
)
return await rag.get_knowledge_graph( return await rag.get_knowledge_graph(
node_label=label, node_label=label,
max_depth=max_depth, max_depth=max_depth,

View file

@ -469,8 +469,8 @@ class OllamaAPI:
"/chat", dependencies=[Depends(combined_auth)], include_in_schema=True "/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
) )
async def chat(raw_request: Request): async def chat(raw_request: Request):
"""Process chat completion requests acting as an Ollama model """Process chat completion requests by acting as an Ollama model.
Routes user queries through LightRAG by selecting query mode based on prefix indicators. Routes user queries through LightRAG by selecting query mode based on query prefix.
Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to LLM. Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to LLM.
Supports both application/json and application/octet-stream Content-Types. Supports both application/json and application/octet-stream Content-Types.
""" """

View file

@ -153,7 +153,7 @@ def main():
# Timeout configuration prioritizes command line arguments # Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = ( gunicorn_config.timeout = (
global_args.timeout * 2 global_args.timeout + 30
if global_args.timeout is not None if global_args.timeout is not None
else get_env_value( else get_env_value(
"TIMEOUT", DEFAULT_TIMEOUT + 30, int, special_none=True "TIMEOUT", DEFAULT_TIMEOUT + 30, int, special_none=True

View file

@ -201,6 +201,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.port}") ASCIIColors.yellow(f"{args.port}")
ASCIIColors.white(" ├─ Workers: ", end="") ASCIIColors.white(" ├─ Workers: ", end="")
ASCIIColors.yellow(f"{args.workers}") ASCIIColors.yellow(f"{args.workers}")
ASCIIColors.white(" ├─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout}")
ASCIIColors.white(" ├─ CORS Origins: ", end="") ASCIIColors.white(" ├─ CORS Origins: ", end="")
ASCIIColors.yellow(f"{args.cors_origins}") ASCIIColors.yellow(f"{args.cors_origins}")
ASCIIColors.white(" ├─ SSL Enabled: ", end="") ASCIIColors.white(" ├─ SSL Enabled: ", end="")
@ -238,14 +240,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.llm_binding_host}") ASCIIColors.yellow(f"{args.llm_binding_host}")
ASCIIColors.white(" ├─ Model: ", end="") ASCIIColors.white(" ├─ Model: ", end="")
ASCIIColors.yellow(f"{args.llm_model}") ASCIIColors.yellow(f"{args.llm_model}")
ASCIIColors.white(" ├─ Temperature: ", end="")
ASCIIColors.yellow(f"{args.temperature}")
ASCIIColors.white(" ├─ Max Async for LLM: ", end="") ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
ASCIIColors.yellow(f"{args.max_async}") ASCIIColors.yellow(f"{args.max_async}")
ASCIIColors.white(" ├─ Max Tokens: ", end="") ASCIIColors.white(" ├─ Summary Context Size: ", end="")
ASCIIColors.yellow(f"{args.max_tokens}") ASCIIColors.yellow(f"{args.summary_context_size}")
ASCIIColors.white(" ├─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
ASCIIColors.white(" ├─ LLM Cache Enabled: ", end="") ASCIIColors.white(" ├─ LLM Cache Enabled: ", end="")
ASCIIColors.yellow(f"{args.enable_llm_cache}") ASCIIColors.yellow(f"{args.enable_llm_cache}")
ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="") ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
@ -266,6 +264,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.magenta("\n⚙️ RAG Configuration:") ASCIIColors.magenta("\n⚙️ RAG Configuration:")
ASCIIColors.white(" ├─ Summary Language: ", end="") ASCIIColors.white(" ├─ Summary Language: ", end="")
ASCIIColors.yellow(f"{args.summary_language}") ASCIIColors.yellow(f"{args.summary_language}")
ASCIIColors.white(" ├─ Entity Types: ", end="")
ASCIIColors.yellow(f"{args.entity_types}")
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="") ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
ASCIIColors.yellow(f"{args.max_parallel_insert}") ASCIIColors.yellow(f"{args.max_parallel_insert}")
ASCIIColors.white(" ├─ Chunk Size: ", end="") ASCIIColors.white(" ├─ Chunk Size: ", end="")

View file

@ -22,7 +22,6 @@ from .constants import (
DEFAULT_MAX_RELATION_TOKENS, DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS, DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_HISTORY_TURNS, DEFAULT_HISTORY_TURNS,
DEFAULT_ENABLE_RERANK,
DEFAULT_OLLAMA_MODEL_NAME, DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG, DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_OLLAMA_MODEL_SIZE, DEFAULT_OLLAMA_MODEL_SIZE,
@ -143,10 +142,6 @@ class QueryParam:
history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS))) history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS)))
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
# TODO: TODO: Deprecated - ID-based filtering only applies to chunks, not entities or relations, and implemented only in PostgreSQL storage
ids: list[str] | None = None
"""List of doc ids to filter the results."""
model_func: Callable[..., object] | None = None model_func: Callable[..., object] | None = None
"""Optional override for the LLM model function to use for this specific query. """Optional override for the LLM model function to use for this specific query.
If provided, this will be used instead of the global model function. If provided, this will be used instead of the global model function.
@ -158,9 +153,7 @@ class QueryParam:
If provided, this will be used instead of the default value from the prompt template. If provided, this will be used instead of the default value from the prompt template.
""" """
enable_rerank: bool = ( enable_rerank: bool = os.getenv("RERANK_BY_DEFAULT", "true").lower() == "true"
os.getenv("ENABLE_RERANK", str(DEFAULT_ENABLE_RERANK).lower()).lower() == "true"
)
"""Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. """Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
Default is True to enable reranking when rerank model is available. Default is True to enable reranking when rerank model is available.
""" """
@ -219,9 +212,16 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results.""" """Query the vector storage and retrieve top_k results.
Args:
query: The query string to search for
top_k: Number of top results to return
query_embedding: Optional pre-computed embedding for the query.
If provided, skips embedding computation for better performance.
"""
@abstractmethod @abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:

View file

@ -11,10 +11,29 @@ DEFAULT_WOKERS = 2
DEFAULT_MAX_GRAPH_NODES = 1000 DEFAULT_MAX_GRAPH_NODES = 1000
# Default values for extraction settings # Default values for extraction settings
DEFAULT_SUMMARY_LANGUAGE = "English" # Default language for summaries DEFAULT_SUMMARY_LANGUAGE = "English" # Default language for document processing
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 4
DEFAULT_MAX_GLEANING = 1 DEFAULT_MAX_GLEANING = 1
DEFAULT_SUMMARY_MAX_TOKENS = 10000 # Default maximum token size
# Number of description fragments to trigger LLM summary
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 8
# Max description token size to trigger LLM summary
DEFAULT_SUMMARY_MAX_TOKENS = 1200
# Recommended LLM summary output length in tokens
DEFAULT_SUMMARY_LENGTH_RECOMMENDED = 600
# Maximum token size sent to LLM for summary
DEFAULT_SUMMARY_CONTEXT_SIZE = 12000
# Default entities to extract if ENTITY_TYPES is not specified in .env
DEFAULT_ENTITY_TYPES = [
"Organization",
"Person",
"Location",
"Event",
"Technology",
"Equipment",
"Product",
"Document",
"Category",
]
# Separator for graph fields # Separator for graph fields
GRAPH_FIELD_SEP = "<SEP>" GRAPH_FIELD_SEP = "<SEP>"
@ -32,8 +51,8 @@ DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
DEFAULT_HISTORY_TURNS = 0 DEFAULT_HISTORY_TURNS = 0
# Rerank configuration defaults # Rerank configuration defaults
DEFAULT_ENABLE_RERANK = True
DEFAULT_MIN_RERANK_SCORE = 0.0 DEFAULT_MIN_RERANK_SCORE = 0.0
DEFAULT_RERANK_BINDING = "null"
# File path configuration for vector and graph database(Should not be changed, used in Milvus Schema) # File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
DEFAULT_MAX_FILE_PATH_LENGTH = 32768 DEFAULT_MAX_FILE_PATH_LENGTH = 32768
@ -49,8 +68,12 @@ DEFAULT_MAX_PARALLEL_INSERT = 2 # Default maximum parallel insert operations
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding functions DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding functions
DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations
# Ollama Server Timeout in seconds # Gunicorn worker timeout
DEFAULT_TIMEOUT = 150 DEFAULT_TIMEOUT = 210
# Default llm and embedding timeout
DEFAULT_LLM_TIMEOUT = 180
DEFAULT_EMBEDDING_TIMEOUT = 30
# Logging configuration defaults # Logging configuration defaults
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB

View file

@ -58,3 +58,41 @@ class RateLimitError(APIStatusError):
class APITimeoutError(APIConnectionError): class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None: def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request) super().__init__(message="Request timed out.", request=request)
class StorageNotInitializedError(RuntimeError):
"""Raised when storage operations are attempted before initialization."""
def __init__(self, storage_type: str = "Storage"):
super().__init__(
f"{storage_type} not initialized. Please ensure proper initialization:\n"
f"\n"
f" rag = LightRAG(...)\n"
f" await rag.initialize_storages() # Required\n"
f" \n"
f" from lightrag.kg.shared_storage import initialize_pipeline_status\n"
f" await initialize_pipeline_status() # Required for pipeline operations\n"
f"\n"
f"See: https://github.com/HKUDS/LightRAG#important-initialization-requirements"
)
class PipelineNotInitializedError(KeyError):
"""Raised when pipeline status is accessed before initialization."""
def __init__(self, namespace: str = ""):
msg = (
f"Pipeline namespace '{namespace}' not found. "
f"This usually means pipeline status was not initialized.\n"
f"\n"
f"Please call 'await initialize_pipeline_status()' after initializing storages:\n"
f"\n"
f" from lightrag.kg.shared_storage import initialize_pipeline_status\n"
f" await initialize_pipeline_status()\n"
f"\n"
f"Full initialization sequence:\n"
f" rag = LightRAG(...)\n"
f" await rag.initialize_storages()\n"
f" await initialize_pipeline_status()"
)
super().__init__(msg)
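The new exception classes above back the guard clauses added to the JSON storages elsewhere in this commit. A self-contained sketch of that guard pattern (simplified re-declarations for illustration, not the real classes):

```python
import asyncio

class StorageNotInitializedError(RuntimeError):
    """Simplified stand-in for lightrag.exceptions.StorageNotInitializedError."""
    def __init__(self, storage_type: str = "Storage"):
        super().__init__(f"{storage_type} not initialized. Call initialize_storages() first.")

class JsonKVStorageSketch:
    def __init__(self):
        self._storage_lock = None  # set by initialize_storages() in the real class
        self._data = {}

    async def upsert(self, data: dict):
        # Guard clause: fail fast with an actionable message instead of
        # an opaque AttributeError on the missing lock.
        if self._storage_lock is None:
            raise StorageNotInitializedError("JsonKVStorage")
        self._data.update(data)
```

The point of the dedicated exception type is the error message: it tells the user exactly which initialization call was skipped rather than surfacing a generic runtime failure.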

View file

@ -180,16 +180,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
return [m["__id__"] for m in list_data] return [m["__id__"] for m in list_data]
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Search by a textual query; returns top_k results with their metadata + similarity distance. Search by a textual query; returns top_k results with their metadata + similarity distance.
""" """
embedding = await self.embedding_func( if query_embedding is not None:
[query], _priority=5 embedding = np.array([query_embedding], dtype=np.float32)
) # higher priority for query else:
# embedding is shape (1, dim) embedding = await self.embedding_func(
embedding = np.array(embedding, dtype=np.float32) [query], _priority=5
) # higher priority for query
# embedding is shape (1, dim)
embedding = np.array(embedding, dtype=np.float32)
faiss.normalize_L2(embedding) # we do in-place normalization faiss.normalize_L2(embedding) # we do in-place normalization
# Perform the similarity search # Perform the similarity search

View file

@ -13,6 +13,7 @@ from lightrag.utils import (
write_json, write_json,
get_pinyin_sort_key, get_pinyin_sort_key,
) )
from lightrag.exceptions import StorageNotInitializedError
from .shared_storage import ( from .shared_storage import (
get_namespace_data, get_namespace_data,
get_storage_lock, get_storage_lock,
@ -65,11 +66,15 @@ class JsonDocStatusStorage(DocStatusStorage):
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)""" """Return keys that should be processed (not in storage or not successfully processed)"""
if self._storage_lock is None:
raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock: async with self._storage_lock:
return set(keys) - set(self._data.keys()) return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = [] result: list[dict[str, Any]] = []
if self._storage_lock is None:
raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock: async with self._storage_lock:
for id in ids: for id in ids:
data = self._data.get(id, None) data = self._data.get(id, None)
@ -80,6 +85,8 @@ class JsonDocStatusStorage(DocStatusStorage):
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus} counts = {status.value: 0 for status in DocStatus}
if self._storage_lock is None:
raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock: async with self._storage_lock:
for doc in self._data.values(): for doc in self._data.values():
counts[doc["status"]] += 1 counts[doc["status"]] += 1
@ -166,6 +173,8 @@ class JsonDocStatusStorage(DocStatusStorage):
logger.debug( logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}" f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
) )
if self._storage_lock is None:
raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock: async with self._storage_lock:
# Ensure chunks_list field exists for new documents # Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items(): for doc_id, doc_data in data.items():

View file

@ -10,6 +10,7 @@ from lightrag.utils import (
logger, logger,
write_json, write_json,
) )
from lightrag.exceptions import StorageNotInitializedError
from .shared_storage import ( from .shared_storage import (
get_namespace_data, get_namespace_data,
get_storage_lock, get_storage_lock,
@ -154,6 +155,8 @@ class JsonKVStorage(BaseKVStorage):
logger.debug( logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}" f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
) )
if self._storage_lock is None:
raise StorageNotInitializedError("JsonKVStorage")
async with self._storage_lock: async with self._storage_lock:
# Add timestamps to data based on whether key exists # Add timestamps to data based on whether key exists
for k, v in data.items(): for k, v in data.items():

View file

@ -1047,14 +1047,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return results return results
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# Ensure collection is loaded before querying # Ensure collection is loaded before querying
self._ensure_collection_loaded() self._ensure_collection_loaded()
embedding = await self.embedding_func( # Use provided embedding or compute it
[query], _priority=5 if query_embedding is not None:
) # higher priority for query embedding = [query_embedding] # Milvus expects a list of embeddings
else:
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Include all meta_fields (created_at is now always included) # Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields) output_fields = list(self.meta_fields)

View file

@ -280,6 +280,30 @@ class MongoDocStatusStorage(DocStatusStorage):
db: AsyncDatabase = field(default=None) db: AsyncDatabase = field(default=None)
_data: AsyncCollection = field(default=None) _data: AsyncCollection = field(default=None)
def _prepare_doc_status_data(self, doc: dict[str, Any]) -> dict[str, Any]:
"""Normalize and migrate a raw Mongo document to DocProcessingStatus-compatible dict."""
# Make a copy of the data to avoid modifying the original
data = doc.copy()
# Remove deprecated content field if it exists
data.pop("content", None)
# Remove MongoDB _id field if it exists
data.pop("_id", None)
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
# Ensure new fields exist with default values
if "metadata" not in data:
data["metadata"] = {}
if "error_msg" not in data:
data["error_msg"] = None
# Backward compatibility: migrate legacy 'error' field to 'error_msg'
if "error" in data:
if "error_msg" not in data or data["error_msg"] in (None, ""):
data["error_msg"] = data.pop("error")
else:
data.pop("error", None)
return data
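The helper above deduplicates three copies of the same normalization, including the legacy `error` → `error_msg` migration, which keeps an existing non-empty `error_msg` and otherwise adopts the legacy value. A standalone sketch of that normalization (field names taken from the diff; behavior mirrors the helper under those assumptions):

```python
def prepare_doc_status(doc: dict) -> dict:
    """Normalize a raw document dict the way the Mongo helper does."""
    data = doc.copy()
    data.pop("content", None)   # deprecated field
    data.pop("_id", None)       # Mongo-internal field
    data.setdefault("file_path", "no-file-path")
    data.setdefault("metadata", {})
    data.setdefault("error_msg", None)
    # Legacy 'error' field: migrate only if 'error_msg' carries nothing.
    if "error" in data:
        if data.get("error_msg") in (None, ""):
            data["error_msg"] = data.pop("error")
        else:
            data.pop("error")
    return data
```

Extracting the copy-pasted block into one method means any future schema migration is edited in a single place.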
def __init__(self, namespace, global_config, embedding_func, workspace=None): def __init__(self, namespace, global_config, embedding_func, workspace=None):
super().__init__( super().__init__(
namespace=namespace, namespace=namespace,
@ -389,20 +413,7 @@ class MongoDocStatusStorage(DocStatusStorage):
processed_result = {} processed_result = {}
for doc in result: for doc in result:
try: try:
# Make a copy of the data to avoid modifying the original data = self._prepare_doc_status_data(doc)
data = doc.copy()
# Remove deprecated content field if it exists
data.pop("content", None)
# Remove MongoDB _id field if it exists
data.pop("_id", None)
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
# Ensure new fields exist with default values
if "metadata" not in data:
data["metadata"] = {}
if "error_msg" not in data:
data["error_msg"] = None
processed_result[doc["_id"]] = DocProcessingStatus(**data) processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e: except KeyError as e:
logger.error( logger.error(
@ -420,20 +431,7 @@ class MongoDocStatusStorage(DocStatusStorage):
processed_result = {} processed_result = {}
for doc in result: for doc in result:
try: try:
# Make a copy of the data to avoid modifying the original data = self._prepare_doc_status_data(doc)
data = doc.copy()
# Remove deprecated content field if it exists
data.pop("content", None)
# Remove MongoDB _id field if it exists
data.pop("_id", None)
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
# Ensure new fields exist with default values
if "metadata" not in data:
data["metadata"] = {}
if "error_msg" not in data:
data["error_msg"] = None
processed_result[doc["_id"]] = DocProcessingStatus(**data) processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e: except KeyError as e:
logger.error( logger.error(
@ -661,20 +659,7 @@ class MongoDocStatusStorage(DocStatusStorage):
try: try:
doc_id = doc["_id"] doc_id = doc["_id"]
# Make a copy of the data to avoid modifying the original data = self._prepare_doc_status_data(doc)
data = doc.copy()
# Remove deprecated content field if it exists
data.pop("content", None)
# Remove MongoDB _id field if it exists
data.pop("_id", None)
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
# Ensure new fields exist with default values
if "metadata" not in data:
data["metadata"] = {}
if "error_msg" not in data:
data["error_msg"] = None
doc_status = DocProcessingStatus(**data) doc_status = DocProcessingStatus(**data)
documents.append((doc_id, doc_status)) documents.append((doc_id, doc_status))
@ -1825,16 +1810,22 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data return list_data
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search.""" """Queries the vector database using Atlas Vector Search."""
# Generate the embedding if query_embedding is not None:
embedding = await self.embedding_func( # Convert numpy array to list if needed for MongoDB compatibility
[query], _priority=5 if hasattr(query_embedding, "tolist"):
) # higher priority for query query_vector = query_embedding.tolist()
else:
# Convert numpy array to a list to ensure compatibility with MongoDB query_vector = list(query_embedding)
query_vector = embedding[0].tolist() else:
# Generate the embedding
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Convert numpy array to a list to ensure compatibility with MongoDB
query_vector = embedding[0].tolist()
# Define the aggregation pipeline with the converted query vector # Define the aggregation pipeline with the converted query vector
pipeline = [ pipeline = [

View file

@ -137,13 +137,17 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid improve cocurrent # Use provided embedding or compute it
embedding = await self.embedding_func( if query_embedding is not None:
[query], _priority=5 embedding = query_embedding
) # higher priority for query else:
embedding = embedding[0] # Execute embedding outside of lock to improve concurrency
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embedding[0]
client = await self._get_client() client = await self._get_client()
results = client.query( results = client.query(

View file

@@ -2005,18 +2005,21 @@ class PGVectorStorage(BaseVectorStorage):
     #################### query method ###############
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int, query_embedding: list[float] = None
     ) -> list[dict[str, Any]]:
-        embeddings = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
-        embedding = embeddings[0]
+        if query_embedding is not None:
+            embedding = query_embedding
+        else:
+            embeddings = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embeddings[0]
         embedding_string = ",".join(map(str, embedding))
-        # Use parameterized document IDs (None means search across all documents)
         sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
         params = {
             "workspace": self.workspace,
-            "doc_ids": ids,
             "closer_than_threshold": 1 - self.cosine_better_than_threshold,
             "top_k": top_k,
         }
@@ -4582,85 +4585,34 @@
             update_time = EXCLUDED.update_time
     """,
     "relationships": """
-    WITH relevant_chunks AS (SELECT id as chunk_id
-                             FROM LIGHTRAG_VDB_CHUNKS
-                             WHERE $2::varchar[] IS NULL OR full_doc_id = ANY ($2::varchar[])
-    ), rc AS (
-        SELECT array_agg(chunk_id) AS chunk_arr
-        FROM relevant_chunks
-    ), cand AS (
-        SELECT
-            r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
-        FROM LIGHTRAG_VDB_RELATION r
-        WHERE r.workspace = $1
-        ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
-        LIMIT ($4 * 50)
-    )
-    SELECT c.src_id,
-           c.tgt_id,
-           EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
-    FROM cand c
-    JOIN rc ON TRUE
-    WHERE c.dist < $3
-      AND c.chunk_ids && (rc.chunk_arr::varchar[])
-    ORDER BY c.dist, c.id
-    LIMIT $4;
+    SELECT r.source_id AS src_id,
+           r.target_id AS tgt_id,
+           EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
+    FROM LIGHTRAG_VDB_RELATION r
+    WHERE r.workspace = $1
+      AND r.content_vector <=> '[{embedding_string}]'::vector < $2
+    ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
+    LIMIT $3;
     """,
     "entities": """
-    WITH relevant_chunks AS (SELECT id as chunk_id
-                             FROM LIGHTRAG_VDB_CHUNKS
-                             WHERE $2::varchar[] IS NULL OR full_doc_id = ANY ($2::varchar[])
-    ), rc AS (
-        SELECT array_agg(chunk_id) AS chunk_arr
-        FROM relevant_chunks
-    ), cand AS (
-        SELECT
-            e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
-        FROM LIGHTRAG_VDB_ENTITY e
-        WHERE e.workspace = $1
-        ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
-        LIMIT ($4 * 50)
-    )
-    SELECT c.entity_name,
-           EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
-    FROM cand c
-    JOIN rc ON TRUE
-    WHERE c.dist < $3
-      AND c.chunk_ids && (rc.chunk_arr::varchar[])
-    ORDER BY c.dist, c.id
-    LIMIT $4;
+    SELECT e.entity_name,
+           EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
+    FROM LIGHTRAG_VDB_ENTITY e
+    WHERE e.workspace = $1
+      AND e.content_vector <=> '[{embedding_string}]'::vector < $2
+    ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
+    LIMIT $3;
     """,
     "chunks": """
-    WITH relevant_chunks AS (SELECT id as chunk_id
-                             FROM LIGHTRAG_VDB_CHUNKS
-                             WHERE $2::varchar[] IS NULL OR full_doc_id = ANY ($2::varchar[])
-    ), rc AS (
-        SELECT array_agg(chunk_id) AS chunk_arr
-        FROM relevant_chunks
-    ), cand AS (
-        SELECT
-            id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
-        FROM LIGHTRAG_VDB_CHUNKS
-        WHERE workspace = $1
-        ORDER BY content_vector <=> '[{embedding_string}]'::vector
-        LIMIT ($4 * 50)
-    )
     SELECT c.id,
            c.content,
            c.file_path,
-           EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
-    FROM cand c
-    JOIN rc ON TRUE
-    WHERE c.dist < $3
-      AND c.id = ANY (rc.chunk_arr)
-    ORDER BY c.dist, c.id
-    LIMIT $4;
+           EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
+    FROM LIGHTRAG_VDB_CHUNKS c
+    WHERE c.workspace = $1
+      AND c.content_vector <=> '[{embedding_string}]'::vector < $2
+    ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
+    LIMIT $3;
     """,
     # DROP tables
     "drop_specifiy_table_workspace": """

View file

@@ -200,14 +200,19 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         return results

     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int, query_embedding: list[float] = None
     ) -> list[dict[str, Any]]:
-        embedding = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
+        if query_embedding is not None:
+            embedding = query_embedding
+        else:
+            embedding_result = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embedding_result[0]
         results = self._client.search(
             collection_name=self.final_namespace,
-            query_vector=embedding[0],
+            query_vector=embedding,
             limit=top_k,
             with_payload=True,
             score_threshold=self.cosine_better_than_threshold,
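The hunks above all apply the same pattern: each vector backend's `query` gains an optional `query_embedding` parameter so a vector computed once by the caller can be reused across backends instead of re-embedding the query. A minimal standalone sketch of that pattern (names and the fake embedding backend are illustrative, not LightRAG's real API):

```python
import asyncio

calls = {"embed": 0}

async def embedding_func(texts, _priority=0):
    # Hypothetical embedding backend; counts invocations for demonstration
    calls["embed"] += 1
    return [[float(len(t)), 0.0] for t in texts]

async def query(query: str, top_k: int, query_embedding=None):
    # Use the provided embedding or compute it, as in the hunks above
    if query_embedding is not None:
        embedding = query_embedding
    else:
        result = await embedding_func([query], _priority=5)
        embedding = result[0]
    return embedding

v1 = asyncio.run(query("hello", 5))                       # embeds once
v2 = asyncio.run(query("hello", 5, query_embedding=v1))   # reuses the vector
print(calls["embed"])
```

Passing the same vector to several storages means one embedding call serves the entity, relation, and chunk searches of a single query.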

View file

@@ -8,6 +8,8 @@ import time
 import logging
 from typing import Any, Dict, List, Optional, Union, TypeVar, Generic

+from lightrag.exceptions import PipelineNotInitializedError
+

 # Define a direct print function for critical logs that must be visible in all processes
 def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
@@ -1057,7 +1059,7 @@ async def initialize_pipeline_status():
     Initialize pipeline namespace with default values.
     This function is called during FASTAPI lifespan for each worker.
     """
-    pipeline_namespace = await get_namespace_data("pipeline_status")
+    pipeline_namespace = await get_namespace_data("pipeline_status", first_init=True)

     async with get_internal_lock():
         # Check if already initialized by checking for required fields
@@ -1192,8 +1194,16 @@ async def try_initialize_namespace(namespace: str) -> bool:
     return False

-async def get_namespace_data(namespace: str) -> Dict[str, Any]:
-    """get the shared data reference for specific namespace"""
+async def get_namespace_data(
+    namespace: str, first_init: bool = False
+) -> Dict[str, Any]:
+    """get the shared data reference for specific namespace
+
+    Args:
+        namespace: The namespace to retrieve
+        first_init: If True, allows creation of the namespace if it doesn't exist.
+            Used internally by initialize_pipeline_status().
+    """
     if _shared_dicts is None:
         direct_log(
             f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
@@ -1203,6 +1213,13 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
     async with get_internal_lock():
         if namespace not in _shared_dicts:
+            # Special handling for pipeline_status namespace
+            if namespace == "pipeline_status" and not first_init:
+                # Check if pipeline_status should have been initialized but wasn't
+                # This helps users understand they need to call initialize_pipeline_status()
+                raise PipelineNotInitializedError(namespace)
+
+            # For other namespaces or when first_init=True, create them dynamically
             if _is_multiprocess and _manager is not None:
                 _shared_dicts[namespace] = _manager.dict()
             else:
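The guard above turns a silent on-demand creation of `pipeline_status` into an explicit error, so callers learn they must run `initialize_pipeline_status()` first. A hypothetical single-process re-creation of the control flow (the real code lives in LightRAG's shared_storage module and is multiprocess-aware):

```python
class PipelineNotInitializedError(KeyError):
    # Stand-in for lightrag.exceptions.PipelineNotInitializedError
    def __init__(self, namespace: str):
        super().__init__(
            f"Namespace '{namespace}' requested before initialize_pipeline_status() was called"
        )

_shared: dict = {}

def get_namespace_data(namespace: str, first_init: bool = False) -> dict:
    if namespace not in _shared:
        # pipeline_status must be created explicitly, never on demand
        if namespace == "pipeline_status" and not first_init:
            raise PipelineNotInitializedError(namespace)
        _shared[namespace] = {}
    return _shared[namespace]

def initialize_pipeline_status() -> None:
    # Only this entry point may create the namespace
    get_namespace_data("pipeline_status", first_init=True)["busy"] = False

try:
    get_namespace_data("pipeline_status")
    raised = False
except PipelineNotInitializedError:
    raised = True

initialize_pipeline_status()
status = get_namespace_data("pipeline_status")  # now succeeds
```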

View file

@@ -9,7 +9,6 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime, timezone
 from functools import partial
-from pathlib import Path
 from typing import (
     Any,
     AsyncIterator,
@@ -35,9 +34,15 @@ from lightrag.constants import (
     DEFAULT_KG_CHUNK_PICK_METHOD,
     DEFAULT_MIN_RERANK_SCORE,
     DEFAULT_SUMMARY_MAX_TOKENS,
+    DEFAULT_SUMMARY_CONTEXT_SIZE,
+    DEFAULT_SUMMARY_LENGTH_RECOMMENDED,
     DEFAULT_MAX_ASYNC,
     DEFAULT_MAX_PARALLEL_INSERT,
     DEFAULT_MAX_GRAPH_NODES,
+    DEFAULT_ENTITY_TYPES,
+    DEFAULT_SUMMARY_LANGUAGE,
+    DEFAULT_LLM_TIMEOUT,
+    DEFAULT_EMBEDDING_TIMEOUT,
 )
 from lightrag.utils import get_env_value
@@ -278,6 +283,10 @@ class LightRAG:
         - use_llm_check: If True, validates cached embeddings using an LLM.
     """

+    default_embedding_timeout: int = field(
+        default=int(os.getenv("EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT))
+    )
+
     # LLM Configuration
     # ---
@@ -288,10 +297,22 @@ class LightRAG:
     """Name of the LLM model used for generating responses."""

     summary_max_tokens: int = field(
-        default=int(os.getenv("MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS))
+        default=int(os.getenv("SUMMARY_MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS))
+    )
+    """Maximum tokens allowed for entity/relation description."""
+
+    summary_context_size: int = field(
+        default=int(os.getenv("SUMMARY_CONTEXT_SIZE", DEFAULT_SUMMARY_CONTEXT_SIZE))
     )
     """Maximum number of tokens allowed per LLM response."""

+    summary_length_recommended: int = field(
+        default=int(
+            os.getenv("SUMMARY_LENGTH_RECOMMENDED", DEFAULT_SUMMARY_LENGTH_RECOMMENDED)
+        )
+    )
+    """Recommended length of LLM summary output."""
+
     llm_model_max_async: int = field(
         default=int(os.getenv("MAX_ASYNC", DEFAULT_MAX_ASYNC))
     )
@@ -300,6 +321,10 @@ class LightRAG:
     llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional keyword arguments passed to the LLM model function."""

+    default_llm_timeout: int = field(
+        default=int(os.getenv("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT))
+    )
+
     # Rerank Configuration
     # ---
@@ -338,7 +363,10 @@ class LightRAG:
     addon_params: dict[str, Any] = field(
         default_factory=lambda: {
-            "language": get_env_value("SUMMARY_LANGUAGE", "English", str)
+            "language": get_env_value(
+                "SUMMARY_LANGUAGE", DEFAULT_SUMMARY_LANGUAGE, str
+            ),
+            "entity_types": get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES, list),
         }
     )
@@ -421,6 +449,20 @@ class LightRAG:
         if self.ollama_server_infos is None:
             self.ollama_server_infos = OllamaServerInfos()

+        # Validate config
+        if self.force_llm_summary_on_merge < 3:
+            logger.warning(
+                f"force_llm_summary_on_merge should be at least 3, got {self.force_llm_summary_on_merge}"
+            )
+        if self.summary_context_size > self.max_total_tokens:
+            logger.warning(
+                f"summary_context_size({self.summary_context_size}) should no greater than max_total_tokens({self.max_total_tokens})"
+            )
+        if self.summary_length_recommended > self.summary_max_tokens:
+            logger.warning(
+                f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})"
+            )
+
         # Fix global_config now
         global_config = asdict(self)
@@ -429,7 +471,9 @@ class LightRAG:
         # Init Embedding
         self.embedding_func = priority_limit_async_func_call(
-            self.embedding_func_max_async
+            self.embedding_func_max_async,
+            llm_timeout=self.default_embedding_timeout,
+            queue_name="Embedding func:",
         )(self.embedding_func)

         # Initialize all storages
@@ -522,7 +566,12 @@ class LightRAG:
             # Directly use llm_response_cache, don't create a new object
             hashing_kv = self.llm_response_cache

-        self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
+        # Get timeout from LLM model kwargs for dynamic timeout calculation
+        self.llm_model_func = priority_limit_async_func_call(
+            self.llm_model_max_async,
+            llm_timeout=self.default_llm_timeout,
+            queue_name="LLM func:",
+        )(
             partial(
                 self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
@@ -530,14 +579,6 @@ class LightRAG:
             )
         )

-        # Init Rerank
-        if self.rerank_model_func:
-            logger.info("Rerank model initialized for improved retrieval quality")
-        else:
-            logger.warning(
-                "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
-            )
-
         self._storages_status = StoragesStatus.CREATED

     async def initialize_storages(self):
@@ -2573,117 +2614,111 @@ class LightRAG:
         relationships_to_delete = set()
         relationships_to_rebuild = {}  # (src, tgt) -> remaining_chunk_ids

-        # Use graph database lock to ensure atomic merges and updates
-        try:
-            # Get affected entities and relations from full_entities and full_relations storage
-            doc_entities_data = await self.full_entities.get_by_id(doc_id)
-            doc_relations_data = await self.full_relations.get_by_id(doc_id)
-
-            affected_nodes = []
-            affected_edges = []
-
-            # Get entity data from graph storage using entity names from full_entities
-            if doc_entities_data and "entity_names" in doc_entities_data:
-                entity_names = doc_entities_data["entity_names"]
-                # get_nodes_batch returns dict[str, dict], need to convert to list[dict]
-                nodes_dict = await self.chunk_entity_relation_graph.get_nodes_batch(
-                    entity_names
-                )
-                for entity_name in entity_names:
-                    node_data = nodes_dict.get(entity_name)
-                    if node_data:
-                        # Ensure compatibility with existing logic that expects "id" field
-                        if "id" not in node_data:
-                            node_data["id"] = entity_name
-                        affected_nodes.append(node_data)
-
-            # Get relation data from graph storage using relation pairs from full_relations
-            if doc_relations_data and "relation_pairs" in doc_relations_data:
-                relation_pairs = doc_relations_data["relation_pairs"]
-                edge_pairs_dicts = [
-                    {"src": pair[0], "tgt": pair[1]} for pair in relation_pairs
-                ]
-                # get_edges_batch returns dict[tuple[str, str], dict], need to convert to list[dict]
-                edges_dict = await self.chunk_entity_relation_graph.get_edges_batch(
-                    edge_pairs_dicts
-                )
-                for pair in relation_pairs:
-                    src, tgt = pair[0], pair[1]
-                    edge_key = (src, tgt)
-                    edge_data = edges_dict.get(edge_key)
-                    if edge_data:
-                        # Ensure compatibility with existing logic that expects "source" and "target" fields
-                        if "source" not in edge_data:
-                            edge_data["source"] = src
-                        if "target" not in edge_data:
-                            edge_data["target"] = tgt
-                        affected_edges.append(edge_data)
-        except Exception as e:
-            logger.error(f"Failed to analyze affected graph elements: {e}")
-            raise Exception(f"Failed to analyze graph dependencies: {e}") from e
-
-        try:
-            # Process entities
-            for node_data in affected_nodes:
-                node_label = node_data.get("entity_id")
-                if node_label and "source_id" in node_data:
-                    sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
-                    remaining_sources = sources - chunk_ids
-
-                    if not remaining_sources:
-                        entities_to_delete.add(node_label)
-                    elif remaining_sources != sources:
-                        entities_to_rebuild[node_label] = remaining_sources
-
-            async with pipeline_status_lock:
-                log_message = f"Found {len(entities_to_rebuild)} affected entities"
-                logger.info(log_message)
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-            # Process relationships
-            for edge_data in affected_edges:
-                src = edge_data.get("source")
-                tgt = edge_data.get("target")
-
-                if src and tgt and "source_id" in edge_data:
-                    edge_tuple = tuple(sorted((src, tgt)))
-                    if (
-                        edge_tuple in relationships_to_delete
-                        or edge_tuple in relationships_to_rebuild
-                    ):
-                        continue
-
-                    sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
-                    remaining_sources = sources - chunk_ids
-
-                    if not remaining_sources:
-                        relationships_to_delete.add(edge_tuple)
-                    elif remaining_sources != sources:
-                        relationships_to_rebuild[edge_tuple] = remaining_sources
-
-            async with pipeline_status_lock:
-                log_message = (
-                    f"Found {len(relationships_to_rebuild)} affected relations"
-                )
-                logger.info(log_message)
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-        except Exception as e:
-            logger.error(f"Failed to process graph analysis results: {e}")
-            raise Exception(f"Failed to process graph dependencies: {e}") from e
-
+        # Use graph database lock to prevent dirty read
         graph_db_lock = get_graph_db_lock(enable_logging=False)
         async with graph_db_lock:
+            try:
+                # Get affected entities and relations from full_entities and full_relations storage
+                doc_entities_data = await self.full_entities.get_by_id(doc_id)
+                doc_relations_data = await self.full_relations.get_by_id(doc_id)
+
+                affected_nodes = []
+                affected_edges = []
+
+                # Get entity data from graph storage using entity names from full_entities
+                if doc_entities_data and "entity_names" in doc_entities_data:
+                    entity_names = doc_entities_data["entity_names"]
+                    # get_nodes_batch returns dict[str, dict], need to convert to list[dict]
+                    nodes_dict = (
+                        await self.chunk_entity_relation_graph.get_nodes_batch(
+                            entity_names
+                        )
+                    )
+                    for entity_name in entity_names:
+                        node_data = nodes_dict.get(entity_name)
+                        if node_data:
+                            # Ensure compatibility with existing logic that expects "id" field
+                            if "id" not in node_data:
+                                node_data["id"] = entity_name
+                            affected_nodes.append(node_data)
+
+                # Get relation data from graph storage using relation pairs from full_relations
+                if doc_relations_data and "relation_pairs" in doc_relations_data:
+                    relation_pairs = doc_relations_data["relation_pairs"]
+                    edge_pairs_dicts = [
+                        {"src": pair[0], "tgt": pair[1]} for pair in relation_pairs
+                    ]
+                    # get_edges_batch returns dict[tuple[str, str], dict], need to convert to list[dict]
+                    edges_dict = (
+                        await self.chunk_entity_relation_graph.get_edges_batch(
+                            edge_pairs_dicts
+                        )
+                    )
+                    for pair in relation_pairs:
+                        src, tgt = pair[0], pair[1]
+                        edge_key = (src, tgt)
+                        edge_data = edges_dict.get(edge_key)
+                        if edge_data:
+                            # Ensure compatibility with existing logic that expects "source" and "target" fields
+                            if "source" not in edge_data:
+                                edge_data["source"] = src
+                            if "target" not in edge_data:
+                                edge_data["target"] = tgt
+                            affected_edges.append(edge_data)
+            except Exception as e:
+                logger.error(f"Failed to analyze affected graph elements: {e}")
+                raise Exception(f"Failed to analyze graph dependencies: {e}") from e

+            try:
+                # Process entities
+                for node_data in affected_nodes:
+                    node_label = node_data.get("entity_id")
+                    if node_label and "source_id" in node_data:
+                        sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
+                        remaining_sources = sources - chunk_ids
+
+                        if not remaining_sources:
+                            entities_to_delete.add(node_label)
+                        elif remaining_sources != sources:
+                            entities_to_rebuild[node_label] = remaining_sources
+
+                async with pipeline_status_lock:
+                    log_message = (
+                        f"Found {len(entities_to_rebuild)} affected entities"
+                    )
+                    logger.info(log_message)
+                    pipeline_status["latest_message"] = log_message
+                    pipeline_status["history_messages"].append(log_message)
+
+                # Process relationships
+                for edge_data in affected_edges:
+                    src = edge_data.get("source")
+                    tgt = edge_data.get("target")
+
+                    if src and tgt and "source_id" in edge_data:
+                        edge_tuple = tuple(sorted((src, tgt)))
+                        if (
+                            edge_tuple in relationships_to_delete
+                            or edge_tuple in relationships_to_rebuild
+                        ):
+                            continue
+
+                        sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
+                        remaining_sources = sources - chunk_ids
+
+                        if not remaining_sources:
+                            relationships_to_delete.add(edge_tuple)
+                        elif remaining_sources != sources:
+                            relationships_to_rebuild[edge_tuple] = remaining_sources
+
+                async with pipeline_status_lock:
+                    log_message = (
+                        f"Found {len(relationships_to_rebuild)} affected relations"
+                    )
+                    logger.info(log_message)
+                    pipeline_status["latest_message"] = log_message
+                    pipeline_status["history_messages"].append(log_message)
+            except Exception as e:
+                logger.error(f"Failed to process graph analysis results: {e}")
+                raise Exception(f"Failed to process graph dependencies: {e}") from e
+
             # 5. Delete chunks from storage
             if chunk_ids:
                 try:
# 5. Delete chunks from storage # 5. Delete chunks from storage
if chunk_ids: if chunk_ids:
try: try:
@@ -2754,27 +2789,28 @@ class LightRAG:
                 logger.error(f"Failed to delete relationships: {e}")
                 raise Exception(f"Failed to delete relationships: {e}") from e

-            # 8. Rebuild entities and relationships from remaining chunks
-            if entities_to_rebuild or relationships_to_rebuild:
-                try:
-                    await _rebuild_knowledge_from_chunks(
-                        entities_to_rebuild=entities_to_rebuild,
-                        relationships_to_rebuild=relationships_to_rebuild,
-                        knowledge_graph_inst=self.chunk_entity_relation_graph,
-                        entities_vdb=self.entities_vdb,
-                        relationships_vdb=self.relationships_vdb,
-                        text_chunks_storage=self.text_chunks,
-                        llm_response_cache=self.llm_response_cache,
-                        global_config=asdict(self),
-                        pipeline_status=pipeline_status,
-                        pipeline_status_lock=pipeline_status_lock,
-                    )
-                except Exception as e:
-                    logger.error(f"Failed to rebuild knowledge from chunks: {e}")
-                    raise Exception(
-                        f"Failed to rebuild knowledge graph: {e}"
-                    ) from e
+            # Persist changes to graph database before releasing graph database lock
+            await self._insert_done()
+
+        # 8. Rebuild entities and relationships from remaining chunks
+        if entities_to_rebuild or relationships_to_rebuild:
+            try:
+                await _rebuild_knowledge_from_chunks(
+                    entities_to_rebuild=entities_to_rebuild,
+                    relationships_to_rebuild=relationships_to_rebuild,
+                    knowledge_graph_inst=self.chunk_entity_relation_graph,
+                    entities_vdb=self.entities_vdb,
+                    relationships_vdb=self.relationships_vdb,
+                    text_chunks_storage=self.text_chunks,
+                    llm_response_cache=self.llm_response_cache,
+                    global_config=asdict(self),
+                    pipeline_status=pipeline_status,
+                    pipeline_status_lock=pipeline_status_lock,
+                )
+            except Exception as e:
+                logger.error(f"Failed to rebuild knowledge from chunks: {e}")
+                raise Exception(f"Failed to rebuild knowledge graph: {e}") from e

         # 9. Delete from full_entities and full_relations storage
         try:
View file

@@ -36,7 +36,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
     llm_instance = OpenAI(
         model="gpt-4",
         api_key="your-openai-key",
-        temperature=0.7,
     )
     kwargs['llm_instance'] = llm_instance
@@ -91,7 +90,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
         model=f"openai/{settings.LLM_MODEL}",  # Format: "provider/model_name"
         api_base=settings.LITELLM_URL,
         api_key=settings.LITELLM_KEY,
-        temperature=0.7,
     )
     kwargs['llm_instance'] = llm_instance

View file

@@ -77,14 +77,23 @@ async def anthropic_complete_if_cache(
     if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
         logging.getLogger("anthropic").setLevel(logging.INFO)

+    kwargs.pop("hashing_kv", None)
+    kwargs.pop("keyword_extraction", None)
+    timeout = kwargs.pop("timeout", None)
+
     anthropic_async_client = (
-        AsyncAnthropic(default_headers=default_headers, api_key=api_key)
+        AsyncAnthropic(
+            default_headers=default_headers, api_key=api_key, timeout=timeout
+        )
         if base_url is None
         else AsyncAnthropic(
-            base_url=base_url, default_headers=default_headers, api_key=api_key
+            base_url=base_url,
+            default_headers=default_headers,
+            api_key=api_key,
+            timeout=timeout,
         )
     )
-    kwargs.pop("hashing_kv", None)

     messages: list[dict[str, Any]] = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})

View file

@@ -59,13 +59,17 @@ async def azure_openai_complete_if_cache(
         or os.getenv("OPENAI_API_VERSION")
     )

+    kwargs.pop("hashing_kv", None)
+    kwargs.pop("keyword_extraction", None)
+    timeout = kwargs.pop("timeout", None)
+
     openai_async_client = AsyncAzureOpenAI(
         azure_endpoint=base_url,
         azure_deployment=deployment,
         api_key=api_key,
         api_version=api_version,
+        timeout=timeout,
     )
-    kwargs.pop("hashing_kv", None)

     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})

View file

@@ -99,7 +99,7 @@ class BindingOptions:
         group = parser.add_argument_group(f"{cls._binding_name} binding options")
         for arg_item in cls.args_env_name_type_value():
             # Handle JSON parsing for list types
-            if arg_item["type"] == List[str]:
+            if arg_item["type"] is List[str]:

                 def json_list_parser(value):
                     try:
@@ -126,6 +126,34 @@ class BindingOptions:
                     default=env_value,
                     help=arg_item["help"],
                 )
+            # Handle JSON parsing for dict types
+            elif arg_item["type"] is dict:
+
+                def json_dict_parser(value):
+                    try:
+                        parsed = json.loads(value)
+                        if not isinstance(parsed, dict):
+                            raise argparse.ArgumentTypeError(
+                                f"Expected JSON object, got {type(parsed).__name__}"
+                            )
+                        return parsed
+                    except json.JSONDecodeError as e:
+                        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
+
+                # Get environment variable with JSON parsing
+                env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
+                if env_value is not argparse.SUPPRESS:
+                    try:
+                        env_value = json_dict_parser(env_value)
+                    except argparse.ArgumentTypeError:
+                        env_value = argparse.SUPPRESS
+
+                group.add_argument(
+                    f"--{arg_item['argname']}",
+                    type=json_dict_parser,
+                    default=env_value,
+                    help=arg_item["help"],
+                )
             else:
                 group.add_argument(
                     f"--{arg_item['argname']}",
@@ -234,8 +262,8 @@ class BindingOptions:
             if arg_item["help"]:
                 sample_stream.write(f"# {arg_item['help']}\n")

-            # Handle JSON formatting for list types
-            if arg_item["type"] == List[str]:
+            # Handle JSON formatting for list and dict types
+            if arg_item["type"] is List[str] or arg_item["type"] is dict:
                 default_value = json.dumps(arg_item["default"])
             else:
                 default_value = arg_item["default"]
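The `json_dict_parser` added in this commit lets dict-valued binding options (such as `extra_body`) be supplied as JSON on the command line or via environment variables. A self-contained sketch of the same parsing approach, using an illustrative `--extra-body` flag rather than LightRAG's generated option names:

```python
import argparse
import json

def json_dict_parser(value: str) -> dict:
    # Accept only JSON that decodes to an object, mirroring the hunk above
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as e:
        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(
            f"Expected JSON object, got {type(parsed).__name__}"
        )
    return parsed

parser = argparse.ArgumentParser()
parser.add_argument("--extra-body", type=json_dict_parser, default=None)

# JSON booleans become Python booleans after parsing
args = parser.parse_args(["--extra-body", '{"reasoning": {"enabled": false}}'])
print(args.extra_body)
```

Because `argparse` calls the `type` callable on each raw string, malformed JSON or a JSON array is rejected with a usage error before the value ever reaches the application.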
@@ -431,6 +459,8 @@ class OpenAILLMOptions(BindingOptions):
     stop: List[str] = field(default_factory=list)  # Stop sequences
     temperature: float = DEFAULT_TEMPERATURE  # Controls randomness (0.0 to 2.0)
     top_p: float = 1.0  # Nucleus sampling parameter (0.0 to 1.0)
+    max_tokens: int = None  # Maximum number of tokens to generate(deprecated, use max_completion_tokens instead)
+    extra_body: dict = None  # Extra body parameters for OpenRouter of vLLM

     # Help descriptions
     _help: ClassVar[dict[str, str]] = {
@@ -443,6 +473,8 @@ class OpenAILLMOptions(BindingOptions):
         "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
         "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
         "top_p": "Nucleus sampling parameter (0.0-1.0, lower = more focused)",
+        "max_tokens": "Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)",
+        "extra_body": 'Extra body parameters for OpenRouter of vLLM (JSON dict, e.g., \'"reasoning": {"reasoning": {"enabled": false}}\')',
     }
@@ -493,6 +525,8 @@ if __name__ == "__main__":
             "1000",
             "--openai-llm-stop",
             '["</s>", "\\n\\n"]',
+            "--openai-llm-reasoning",
+            '{"effort": "high", "max_tokens": 2000, "exclude": false, "enabled": true}',
         ]
     )
     print("Final args for LLM and Embedding:")
@@ -518,5 +552,100 @@ if __name__ == "__main__":
     print("\nOpenAI LLM options instance:")
     print(openai_options.asdict())

+    # Test creating OpenAI options instance with reasoning parameter
+    openai_options_with_reasoning = OpenAILLMOptions(
+        temperature=0.9,
+        max_completion_tokens=2000,
+        reasoning={
+            "effort": "medium",
+            "max_tokens": 1500,
+            "exclude": True,
+            "enabled": True,
+        },
+    )
+    print("\nOpenAI LLM options instance with reasoning:")
+    print(openai_options_with_reasoning.asdict())
+
+    # Test dict parsing functionality
+    print("\n" + "=" * 50)
+    print("TESTING DICT PARSING FUNCTIONALITY")
+    print("=" * 50)
+
+    # Test valid JSON dict parsing
+    test_parser = ArgumentParser(description="Test dict parsing")
+    OpenAILLMOptions.add_args(test_parser)
+
+    try:
+        test_args = test_parser.parse_args(
+            ["--openai-llm-reasoning", '{"effort": "low", "max_tokens": 1000}']
+        )
+        print("✓ Valid JSON dict parsing successful:")
+        print(
+            f"  Parsed reasoning: {OpenAILLMOptions.options_dict(test_args)['reasoning']}"
+        )
+    except Exception as e:
+        print(f"✗ Valid JSON dict parsing failed: {e}")
+
+    # Test invalid JSON dict parsing
+    try:
+        test_args = test_parser.parse_args(
+            [
+                "--openai-llm-reasoning",
+                '{"effort": "low", "max_tokens": 1000',  # Missing closing brace
+            ]
+        )
+        print("✗ Invalid JSON should have failed but didn't")
+    except SystemExit:
+        print("✓ Invalid JSON dict parsing correctly rejected")
+    except Exception as e:
+        print(f"✓ Invalid JSON dict parsing correctly rejected: {e}")
+
+    # Test non-dict JSON parsing
+    try:
+        test_args = test_parser.parse_args(
+            [
+                "--openai-llm-reasoning",
+                '["not", "a", "dict"]',  # Array instead of dict
+            ]
+        )
+        print("✗ Non-dict JSON should have failed but didn't")
+    except SystemExit:
+        print("✓ Non-dict JSON parsing correctly rejected")
+    except Exception as e:
+        print(f"✓ Non-dict JSON parsing correctly rejected: {e}")
+
+    print("\n" + "=" * 50)
+    print("TESTING ENVIRONMENT VARIABLE SUPPORT")
+    print("=" * 50)
+
+    # Test environment variable support for dict
+    import os
+
+    os.environ["OPENAI_LLM_REASONING"] = (
+        '{"effort": "high", "max_tokens": 3000, "exclude": false}'
+    )
+
+    env_parser = ArgumentParser(description="Test env var dict parsing")
+    OpenAILLMOptions.add_args(env_parser)
+
+    try:
+        env_args = env_parser.parse_args(
+            []
+        )  # No command line args, should use env var
+        reasoning_from_env = OpenAILLMOptions.options_dict(env_args).get(
+            "reasoning"
+        )
+        if reasoning_from_env:
+            print("✓ Environment variable dict parsing successful:")
+            print(f"  Parsed reasoning from env: {reasoning_from_env}")
+        else:
+            print("✗ Environment variable dict parsing failed: No reasoning found")
+    except Exception as e:
+        print(f"✗ Environment variable dict parsing failed: {e}")
+    finally:
+        # Clean up environment variable
+        if "OPENAI_LLM_REASONING" in os.environ:
+            del os.environ["OPENAI_LLM_REASONING"]
+
 else:
     print(BindingOptions.generate_dot_env_sample())


@@ -59,7 +59,7 @@ async def lollms_model_if_cache(
         "personality": kwargs.get("personality", -1),
         "n_predict": kwargs.get("n_predict", None),
         "stream": stream,
-        "temperature": kwargs.get("temperature", 0.8),
+        "temperature": kwargs.get("temperature", 1.0),
         "top_k": kwargs.get("top_k", 50),
         "top_p": kwargs.get("top_p", 0.95),
         "repeat_penalty": kwargs.get("repeat_penalty", 0.8),


@@ -51,6 +51,8 @@ async def _ollama_model_if_cache(
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
+    if timeout == 0:
+        timeout = None
    kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
     headers = {


@@ -149,18 +149,20 @@ async def openai_complete_if_cache(
     if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
         logging.getLogger("openai").setLevel(logging.INFO)

-    # Remove special kwargs that shouldn't be passed to OpenAI
-    kwargs.pop("hashing_kv", None)
-    kwargs.pop("keyword_extraction", None)
-
     # Extract client configuration options
     client_configs = kwargs.pop("openai_client_configs", {})

     # Create the OpenAI client
     openai_async_client = create_openai_async_client(
-        api_key=api_key, base_url=base_url, client_configs=client_configs
+        api_key=api_key,
+        base_url=base_url,
+        client_configs=client_configs,
     )

+    # Remove special kwargs that shouldn't be passed to OpenAI
+    kwargs.pop("hashing_kv", None)
+    kwargs.pop("keyword_extraction", None)
+
     # Prepare messages
     messages: list[dict[str, Any]] = []
     if system_prompt:
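This hunk keeps the intent that LightRAG-internal keys (`hashing_kv`, `keyword_extraction`) are stripped from `kwargs` before the remaining options are forwarded to the OpenAI API, which rejects unknown parameters. A minimal sketch of that filtering step (the function name is illustrative):

```python
INTERNAL_KEYS = ("hashing_kv", "keyword_extraction")


def strip_internal_kwargs(kwargs: dict) -> dict:
    """Remove LightRAG-internal bookkeeping keys in place so only
    options meaningful to the LLM API remain."""
    for key in INTERNAL_KEYS:
        kwargs.pop(key, None)  # pop with a default: absent keys are fine
    return kwargs


opts = strip_internal_kwargs(
    {"temperature": 0.7, "hashing_kv": "cache-handle", "keyword_extraction": True}
)
print(opts)  # {'temperature': 0.7}
```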

File diff suppressed because it is too large


@@ -4,60 +4,57 @@ from typing import Any

 PROMPTS: dict[str, Any] = {}

-PROMPTS["DEFAULT_LANGUAGE"] = "English"
 PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
 PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
 PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"

-PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
 PROMPTS["DEFAULT_USER_PROMPT"] = "n/a"

-PROMPTS["entity_extraction"] = """---Goal---
-Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Use {language} as output language.
+PROMPTS["entity_extraction"] = """---Task---
+Given a text document and a list of entity types, identify all entities of those types and all relationships among the identified entities.

----Steps---
-1. Identify all entities. For each identified entity, extract the following information:
+---Instructions---
+1. Identify clearly defined entities in the text. For each identified entity, extract the following information:
 - entity_name: Name of the entity, use same language as input text. If English, capitalized the name
-- entity_type: One of the following types: [{entity_types}]
-- entity_description: Provide a comprehensive description of the entity's attributes and activities *based solely on the information present in the input text*. **Do not infer or hallucinate information not explicitly stated.** If the text provides insufficient information to create a comprehensive description, state "Description not available in text."
-Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
+- entity_type: Categorize the entity using the provided `Entity_types` list. If a suitable category cannot be determined, classify it as "Other".
+- entity_description: Provide a comprehensive description of the entity's attributes and activities based on the information present in the input text. To ensure clarity and precision, all descriptions must replace pronouns and referential terms (e.g., "this document," "our company," "I," "you," "he/she") with the specific nouns they represent.
+2. Format each entity as: ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)

-2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+3. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are directly and clearly related based on the text. Unsubstantiated relationships must be excluded from the output.
 For each pair of related entities, extract the following information:
 - source_entity: name of the source entity, as identified in step 1
 - target_entity: name of the target entity, as identified in step 1
-- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
-- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
-- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
-Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
+- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
+- relationship_description: Explain the nature of the relationship between the source and target entities, providing a clear rationale for their connection
+4. Format each relationship as: ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_description>)
+5. Use `{tuple_delimiter}` as field delimiter. Use `{record_delimiter}` as the entity or relation list delimiter.
+6. Return identified entities and relationships in {language}.
+7. Output `{completion_delimiter}` when all the entities and relationships are extracted.

-3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
-Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
+---Quality Guidelines---
+- Only extract entities that are clearly defined and meaningful in the context
+- Avoid over-interpretation; stick to what is explicitly stated in the text
+- For all output content, explicitly name the subject or object rather than using pronouns
+- Include specific numerical data in entity name when relevant
+- Ensure entity names are consistent throughout the extraction

-4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
-5. When finished, output {completion_delimiter}
-
-######################
 ---Examples---
-######################
 {examples}

-#############################
----Real Data---
-######################
+---Input---
 Entity_types: [{entity_types}]
 Text:
+```
 {input_text}
-######################
-Output:"""
+```
+
+---Output---
+"""

 PROMPTS["entity_extraction_examples"] = [
-    """Example 1:
-Entity_types: [person, technology, mission, organization, location]
+    """[Example 1]
+
+---Input---
+Entity_types: [organization,person,location,event,technology,equipment,product,Document,category]
 Text:
 ```
 while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
@@ -69,22 +66,24 @@ The underlying dismissal earlier seemed to falter, replaced by a glimpse of relu
 It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
 ```

-Output:
-("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
-("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
-("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
-("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
-("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
-("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
-("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
-("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
-("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
-("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
-("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
-#############################""",
-"""Example 2:
-Entity_types: [company, index, commodity, market_trend, economic_policy, biological]
+---Output---
+(entity{tuple_delimiter}Alex{tuple_delimiter}person{tuple_delimiter}Alex is a character who experiences frustration and is observant of the dynamics among other characters.){record_delimiter}
+(entity{tuple_delimiter}Taylor{tuple_delimiter}person{tuple_delimiter}Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective.){record_delimiter}
+(entity{tuple_delimiter}Jordan{tuple_delimiter}person{tuple_delimiter}Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device.){record_delimiter}
+(entity{tuple_delimiter}Cruz{tuple_delimiter}person{tuple_delimiter}Cruz is associated with a vision of control and order, influencing the dynamics among other characters.){record_delimiter}
+(entity{tuple_delimiter}The Device{tuple_delimiter}equipment{tuple_delimiter}The Device is central to the story, with potential game-changing implications, and is revered by Taylor.){record_delimiter}
+(relationship{tuple_delimiter}Alex{tuple_delimiter}Taylor{tuple_delimiter}power dynamics, observation{tuple_delimiter}Alex observes Taylor's authoritarian behavior and notes changes in Taylor's attitude toward the device.){record_delimiter}
+(relationship{tuple_delimiter}Alex{tuple_delimiter}Jordan{tuple_delimiter}shared goals, rebellion{tuple_delimiter}Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision.){record_delimiter}
+(relationship{tuple_delimiter}Taylor{tuple_delimiter}Jordan{tuple_delimiter}conflict resolution, mutual respect{tuple_delimiter}Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce.){record_delimiter}
+(relationship{tuple_delimiter}Jordan{tuple_delimiter}Cruz{tuple_delimiter}ideological conflict, rebellion{tuple_delimiter}Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order.){record_delimiter}
+(relationship{tuple_delimiter}Taylor{tuple_delimiter}The Device{tuple_delimiter}reverence, technological significance{tuple_delimiter}Taylor shows reverence towards the device, indicating its importance and potential impact.){record_delimiter}
+{completion_delimiter}
+
+""",
+"""[Example 2]
+
+---Input---
+Entity_types: [organization,person,location,event,technology,equipment,product,Document,category]
 Text:
 ```
 Stock markets faced a sharp downturn today as tech giants saw significant declines, with the Global Tech Index dropping by 3.4% in midday trading. Analysts attribute the selloff to investor concerns over rising interest rates and regulatory uncertainty.
@ -96,101 +95,113 @@ Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1
Financial experts are closely watching the Federal Reserve's next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability. Financial experts are closely watching the Federal Reserve's next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability.
``` ```
Output: ---Output---
("entity"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"index"{tuple_delimiter}"The Global Tech Index tracks the performance of major technology stocks and experienced a 3.4% decline today."){record_delimiter} (entity{tuple_delimiter}Global Tech Index{tuple_delimiter}category{tuple_delimiter}The Global Tech Index tracks the performance of major technology stocks and experienced a 3.4% decline today.){record_delimiter}
("entity"{tuple_delimiter}"Nexon Technologies"{tuple_delimiter}"company"{tuple_delimiter}"Nexon Technologies is a tech company that saw its stock decline by 7.8% after disappointing earnings."){record_delimiter} (entity{tuple_delimiter}Nexon Technologies{tuple_delimiter}organization{tuple_delimiter}Nexon Technologies is a tech company that saw its stock decline by 7.8% after disappointing earnings.){record_delimiter}
("entity"{tuple_delimiter}"Omega Energy"{tuple_delimiter}"company"{tuple_delimiter}"Omega Energy is an energy company that gained 2.1% in stock value due to rising oil prices."){record_delimiter} (entity{tuple_delimiter}Omega Energy{tuple_delimiter}organization{tuple_delimiter}Omega Energy is an energy company that gained 2.1% in stock value due to rising oil prices.){record_delimiter}
("entity"{tuple_delimiter}"Gold Futures"{tuple_delimiter}"commodity"{tuple_delimiter}"Gold futures rose by 1.5%, indicating increased investor interest in safe-haven assets."){record_delimiter} (entity{tuple_delimiter}Gold Futures{tuple_delimiter}product{tuple_delimiter}Gold futures rose by 1.5%, indicating increased investor interest in safe-haven assets.){record_delimiter}
("entity"{tuple_delimiter}"Crude Oil"{tuple_delimiter}"commodity"{tuple_delimiter}"Crude oil prices rose to $87.60 per barrel due to supply constraints and strong demand."){record_delimiter} (entity{tuple_delimiter}Crude Oil{tuple_delimiter}product{tuple_delimiter}Crude oil prices rose to $87.60 per barrel due to supply constraints and strong demand.){record_delimiter}
("entity"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"market_trend"{tuple_delimiter}"Market selloff refers to the significant decline in stock values due to investor concerns over interest rates and regulations."){record_delimiter} (entity{tuple_delimiter}Market Selloff{tuple_delimiter}category{tuple_delimiter}Market selloff refers to the significant decline in stock values due to investor concerns over interest rates and regulations.){record_delimiter}
("entity"{tuple_delimiter}"Federal Reserve Policy Announcement"{tuple_delimiter}"economic_policy"{tuple_delimiter}"The Federal Reserve's upcoming policy announcement is expected to impact investor confidence and market stability."){record_delimiter} (entity{tuple_delimiter}Federal Reserve Policy Announcement{tuple_delimiter}category{tuple_delimiter}The Federal Reserve's upcoming policy announcement is expected to impact investor confidence and market stability.){record_delimiter}
("relationship"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"The decline in the Global Tech Index is part of the broader market selloff driven by investor concerns."{tuple_delimiter}"market performance, investor sentiment"{tuple_delimiter}9){record_delimiter} (entity{tuple_delimiter}3.4% Decline{tuple_delimiter}category{tuple_delimiter}The Global Tech Index experienced a 3.4% decline in midday trading.){record_delimiter}
("relationship"{tuple_delimiter}"Nexon Technologies"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"Nexon Technologies' stock decline contributed to the overall drop in the Global Tech Index."{tuple_delimiter}"company impact, index movement"{tuple_delimiter}8){record_delimiter} (relationship{tuple_delimiter}Global Tech Index{tuple_delimiter}Market Selloff{tuple_delimiter}market performance, investor sentiment{tuple_delimiter}The decline in the Global Tech Index is part of the broader market selloff driven by investor concerns.){record_delimiter}
("relationship"{tuple_delimiter}"Gold Futures"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"Gold prices rose as investors sought safe-haven assets during the market selloff."{tuple_delimiter}"market reaction, safe-haven investment"{tuple_delimiter}10){record_delimiter} (relationship{tuple_delimiter}Nexon Technologies{tuple_delimiter}Global Tech Index{tuple_delimiter}company impact, index movement{tuple_delimiter}Nexon Technologies' stock decline contributed to the overall drop in the Global Tech Index.){record_delimiter}
("relationship"{tuple_delimiter}"Federal Reserve Policy Announcement"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"Speculation over Federal Reserve policy changes contributed to market volatility and investor selloff."{tuple_delimiter}"interest rate impact, financial regulation"{tuple_delimiter}7){record_delimiter} (relationship{tuple_delimiter}Gold Futures{tuple_delimiter}Market Selloff{tuple_delimiter}market reaction, safe-haven investment{tuple_delimiter}Gold prices rose as investors sought safe-haven assets during the market selloff.){record_delimiter}
("content_keywords"{tuple_delimiter}"market downturn, investor sentiment, commodities, Federal Reserve, stock performance"){completion_delimiter} (relationship{tuple_delimiter}Federal Reserve Policy Announcement{tuple_delimiter}Market Selloff{tuple_delimiter}interest rate impact, financial regulation{tuple_delimiter}Speculation over Federal Reserve policy changes contributed to market volatility and investor selloff.){record_delimiter}
#############################""", {completion_delimiter}
"""Example 3:
Entity_types: [economic_policy, athlete, event, location, record, organization, equipment] """,
"""[Example 3]
---Input---
Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text: Text:
``` ```
At the World Athletics Championship in Tokyo, Noah Carter broke the 100m sprint record using cutting-edge carbon-fiber spikes. At the World Athletics Championship in Tokyo, Noah Carter broke the 100m sprint record using cutting-edge carbon-fiber spikes.
``` ```
Output: ---Output---
("entity"{tuple_delimiter}"World Athletics Championship"{tuple_delimiter}"event"{tuple_delimiter}"The World Athletics Championship is a global sports competition featuring top athletes in track and field."){record_delimiter} (entity{tuple_delimiter}World Athletics Championship{tuple_delimiter}event{tuple_delimiter}The World Athletics Championship is a global sports competition featuring top athletes in track and field.){record_delimiter}
("entity"{tuple_delimiter}"Tokyo"{tuple_delimiter}"location"{tuple_delimiter}"Tokyo is the host city of the World Athletics Championship."){record_delimiter} (entity{tuple_delimiter}Tokyo{tuple_delimiter}location{tuple_delimiter}Tokyo is the host city of the World Athletics Championship.){record_delimiter}
("entity"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"athlete"{tuple_delimiter}"Noah Carter is a sprinter who set a new record in the 100m sprint at the World Athletics Championship."){record_delimiter} (entity{tuple_delimiter}Noah Carter{tuple_delimiter}person{tuple_delimiter}Noah Carter is a sprinter who set a new record in the 100m sprint at the World Athletics Championship.){record_delimiter}
("entity"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"record"{tuple_delimiter}"The 100m sprint record is a benchmark in athletics, recently broken by Noah Carter."){record_delimiter} (entity{tuple_delimiter}100m Sprint Record{tuple_delimiter}category{tuple_delimiter}The 100m sprint record is a benchmark in athletics, recently broken by Noah Carter.){record_delimiter}
("entity"{tuple_delimiter}"Carbon-Fiber Spikes"{tuple_delimiter}"equipment"{tuple_delimiter}"Carbon-fiber spikes are advanced sprinting shoes that provide enhanced speed and traction."){record_delimiter} (entity{tuple_delimiter}Carbon-Fiber Spikes{tuple_delimiter}equipment{tuple_delimiter}Carbon-fiber spikes are advanced sprinting shoes that provide enhanced speed and traction.){record_delimiter}
("entity"{tuple_delimiter}"World Athletics Federation"{tuple_delimiter}"organization"{tuple_delimiter}"The World Athletics Federation is the governing body overseeing the World Athletics Championship and record validations."){record_delimiter} (entity{tuple_delimiter}World Athletics Federation{tuple_delimiter}organization{tuple_delimiter}The World Athletics Federation is the governing body overseeing the World Athletics Championship and record validations.){record_delimiter}
("relationship"{tuple_delimiter}"World Athletics Championship"{tuple_delimiter}"Tokyo"{tuple_delimiter}"The World Athletics Championship is being hosted in Tokyo."{tuple_delimiter}"event location, international competition"{tuple_delimiter}8){record_delimiter} (relationship{tuple_delimiter}World Athletics Championship{tuple_delimiter}Tokyo{tuple_delimiter}event location, international competition{tuple_delimiter}The World Athletics Championship is being hosted in Tokyo.){record_delimiter}
("relationship"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"Noah Carter set a new 100m sprint record at the championship."{tuple_delimiter}"athlete achievement, record-breaking"{tuple_delimiter}10){record_delimiter} (relationship{tuple_delimiter}Noah Carter{tuple_delimiter}100m Sprint Record{tuple_delimiter}athlete achievement, record-breaking{tuple_delimiter}Noah Carter set a new 100m sprint record at the championship.){record_delimiter}
("relationship"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"Carbon-Fiber Spikes"{tuple_delimiter}"Noah Carter used carbon-fiber spikes to enhance performance during the race."{tuple_delimiter}"athletic equipment, performance boost"{tuple_delimiter}7){record_delimiter} (relationship{tuple_delimiter}Noah Carter{tuple_delimiter}Carbon-Fiber Spikes{tuple_delimiter}athletic equipment, performance boost{tuple_delimiter}Noah Carter used carbon-fiber spikes to enhance performance during the race.){record_delimiter}
("relationship"{tuple_delimiter}"World Athletics Federation"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"The World Athletics Federation is responsible for validating and recognizing new sprint records."{tuple_delimiter}"sports regulation, record certification"{tuple_delimiter}9){record_delimiter} (relationship{tuple_delimiter}Noah Carter{tuple_delimiter}World Athletics Championship{tuple_delimiter}athlete participation, competition{tuple_delimiter}Noah Carter is competing at the World Athletics Championship.){record_delimiter}
("content_keywords"{tuple_delimiter}"athletics, sprinting, record-breaking, sports technology, competition"){completion_delimiter} {completion_delimiter}
#############################""",
]
PROMPTS[ """,
"summarize_entity_descriptions" """[Example 4]
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we the have full context.
Use {language} as output language.
####### ---Input---
---Data--- Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Entities: {entity_name} Text:
Description List: {description_list} ```
####### 在北京举行的人工智能大会上腾讯公司的首席技术官张伟发布了最新的大语言模型"腾讯智言"该模型在自然语言处理方面取得了重大突破
Output: ```
"""
PROMPTS["entity_continue_extraction"] = """
MANY entities and relationships were missed in the last extraction. Please find only the missing entities and relationships from previous text.
---Remember Steps---
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name
- entity_type: One of the following types: [{entity_types}]
- entity_description: Provide a comprehensive description of the entity's attributes and activities *based solely on the information present in the input text*. **Do not infer or hallucinate information not explicitly stated.** If the text provides insufficient information to create a comprehensive description, state "Description not available in text."
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter}
---Output--- ---Output---
(entity{tuple_delimiter}人工智能大会{tuple_delimiter}event{tuple_delimiter}人工智能大会是在北京举行的技术会议专注于人工智能领域的最新发展){record_delimiter}
(entity{tuple_delimiter}北京{tuple_delimiter}location{tuple_delimiter}北京是人工智能大会的举办城市){record_delimiter}
(entity{tuple_delimiter}腾讯公司{tuple_delimiter}organization{tuple_delimiter}腾讯公司是参与人工智能大会的科技企业发布了新的语言模型产品){record_delimiter}
(entity{tuple_delimiter}张伟{tuple_delimiter}person{tuple_delimiter}张伟是腾讯公司的首席技术官在大会上发布了新产品){record_delimiter}
(entity{tuple_delimiter}腾讯智言{tuple_delimiter}product{tuple_delimiter}腾讯智言是腾讯公司发布的大语言模型产品在自然语言处理方面有重大突破){record_delimiter}
(entity{tuple_delimiter}自然语言处理技术{tuple_delimiter}technology{tuple_delimiter}自然语言处理技术是腾讯智言模型取得重大突破的技术领域){record_delimiter}
(relationship{tuple_delimiter}人工智能大会{tuple_delimiter}北京{tuple_delimiter}会议地点, 举办关系{tuple_delimiter}人工智能大会在北京举行){record_delimiter}
(relationship{tuple_delimiter}张伟{tuple_delimiter}腾讯公司{tuple_delimiter}雇佣关系, 高管职位{tuple_delimiter}张伟担任腾讯公司的首席技术官){record_delimiter}
(relationship{tuple_delimiter}张伟{tuple_delimiter}腾讯智言{tuple_delimiter}产品发布, 技术展示{tuple_delimiter}张伟在大会上发布了腾讯智言大语言模型){record_delimiter}
(relationship{tuple_delimiter}腾讯智言{tuple_delimiter}自然语言处理技术{tuple_delimiter}技术应用, 突破创新{tuple_delimiter}腾讯智言在自然语言处理技术方面取得了重大突破){record_delimiter}
{completion_delimiter}
Add new entities and relations below using the same format, and do not include entities and relations that have been previously extracted. :\n """,
""".strip() ]
PROMPTS["summarize_entity_descriptions"] = """---Role---
You are a Knowledge Graph Specialist responsible for data curation and synthesis.
---Task---
Your task is to synthesize a list of descriptions of a given entity or relation into a single, comprehensive, and cohesive summary.
---Instructions---
1. **Comprehensiveness:** The summary must integrate key information from all provided descriptions. Do not omit important facts.
2. **Context:** The summary must explicitly mention the name of the entity or relation for full context.
3. **Conflict:** In case of conflicting or inconsistent descriptions, determine if they originate from multiple, distinct entities or relationships that share the same name. If so, summarize each entity or relationship separately and then consolidate all summaries.
4. **Style:** The output must be written from an objective, third-person perspective.
5. **Length:** Maintain depth and completeness while ensuring the summary's length does not exceed {summary_length} tokens.
6. **Language:** The entire output must be written in {language}.
---Data---
{description_type} Name: {description_name}
Description List:
{description_list}
---Output---
"""
PROMPTS["entity_continue_extraction"] = """---Task---
Identify any missed entities or relationships in the last extraction task.
---Instructions---
1. Output the entities and relationships in the same format as the previous extraction task.
2. Do not include entities and relations that have been previously extracted.
3. If the entity doesn't clearly fit in any of the `Entity_types` provided, classify it as "Other".
4. Return identified entities and relationships in {language}.
5. Output `{completion_delimiter}` when all the entities and relationships are extracted.
---Output---
"""
# TODO: Deprecated
PROMPTS["entity_if_loop_extraction"] = """ PROMPTS["entity_if_loop_extraction"] = """
---Goal---' ---Goal---'
It appears some entities may have still been missed. Check if it appears some entities may have still been missed. Output "Yes" if so, otherwise "No".
---Output--- ---Output---
Output:"""
Answer ONLY by `YES` OR `NO` if there are still entities that need to be added.
""".strip()
PROMPTS["fail_response"] = ( PROMPTS["fail_response"] = (
"Sorry, I'm not able to provide an answer to that question.[no-context]" "Sorry, I'm not able to provide an answer to that question.[no-context]"
@ -211,7 +222,7 @@ Generate a concise response based on Knowledge Base and follow Response Rules, c
---Knowledge Graph and Document Chunks--- ---Knowledge Graph and Document Chunks---
{context_data} {context_data}
---RESPONSE GUIDELINES--- ---Response Guidelines---
**1. Content & Adherence:** **1. Content & Adherence:**
- Strictly adhere to the provided context from the Knowledge Base. Do not invent, assume, or include any information not present in the source data. - Strictly adhere to the provided context from the Knowledge Base. Do not invent, assume, or include any information not present in the source data.
- If the answer cannot be found in the provided context, state that you do not have enough information to answer. - If the answer cannot be found in the provided context, state that you do not have enough information to answer.
@ -233,8 +244,8 @@ Generate a concise response based on Knowledge Base and follow Response Rules, c
---USER CONTEXT--- ---USER CONTEXT---
- Additional user prompt: {user_prompt} - Additional user prompt: {user_prompt}
---Response---
Response:""" """
PROMPTS["keywords_extraction"] = """---Role--- PROMPTS["keywords_extraction"] = """---Role---
You are an expert keyword extractor, specializing in analyzing user queries for a Retrieval-Augmented Generation (RAG) system. Your purpose is to identify both high-level and low-level keywords in the user's query that will be used for effective document retrieval. You are an expert keyword extractor, specializing in analyzing user queries for a Retrieval-Augmented Generation (RAG) system. Your purpose is to identify both high-level and low-level keywords in the user's query that will be used for effective document retrieval.
@ -257,7 +268,7 @@ Given a user query, your task is to extract two distinct types of keywords:
User Query: {query} User Query: {query}
---Output--- ---Output---
""" Output:"""
PROMPTS["keywords_extraction_examples"] = [ PROMPTS["keywords_extraction_examples"] = [
"""Example 1: """Example 1:
@ -327,5 +338,5 @@ Generate a concise response based on Document Chunks and follow Response Rules,
---USER CONTEXT--- ---USER CONTEXT---
- Additional user prompt: {user_prompt} - Additional user prompt: {user_prompt}
---Response---
Response:""" Output:"""


@@ -2,270 +2,199 @@ from __future__ import annotations

import os
import aiohttp
from typing import Any, List, Dict, Optional
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from .utils import logger
from dotenv import load_dotenv

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(aiohttp.ClientError)
        | retry_if_exception_type(aiohttp.ClientResponseError)
    ),
)
async def generic_rerank_api(
    query: str,
    documents: List[str],
    model: str,
    base_url: str,
    api_key: Optional[str],
    top_n: Optional[int] = None,
    return_documents: Optional[bool] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    response_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
    request_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
) -> List[Dict[str, Any]]:
    """
    Generic rerank API call for Jina/Cohere/Aliyun models.

    Args:
        query: The search query
        documents: List of strings to rerank
        model: Model name to use
        base_url: API endpoint URL
        api_key: API key for authentication
        top_n: Number of top results to return
        return_documents: Whether to return document text (Jina only)
        extra_body: Additional body parameters
        response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)

    Returns:
        List of dictionary of ["index": int, "relevance_score": float]
    """
    if not base_url:
        raise ValueError("Base URL is required")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    # Build request payload based on request format
    if request_format == "aliyun":
        # Aliyun format: nested input/parameters structure
        payload = {
            "model": model,
            "input": {
                "query": query,
                "documents": documents,
            },
            "parameters": {},
        }
        # Add optional parameters to parameters object
        if top_n is not None:
            payload["parameters"]["top_n"] = top_n
        if return_documents is not None:
            payload["parameters"]["return_documents"] = return_documents
        # Add extra parameters to parameters object
        if extra_body:
            payload["parameters"].update(extra_body)
    else:
        # Standard format for Jina/Cohere
        payload = {
            "model": model,
            "query": query,
            "documents": documents,
        }
        # Add optional parameters
        if top_n is not None:
            payload["top_n"] = top_n
        # Only Jina API supports return_documents parameter
        if return_documents is not None:
            payload["return_documents"] = return_documents
        # Add extra parameters
        if extra_body:
            payload.update(extra_body)

    logger.debug(
        f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
    )
    async with aiohttp.ClientSession() as session:
        async with session.post(base_url, headers=headers, json=payload) as response:
            if response.status != 200:
                error_text = await response.text()
                content_type = response.headers.get("content-type", "").lower()
                is_html_error = (
                    error_text.strip().startswith("<!DOCTYPE html>")
                    or "text/html" in content_type
                )
                if is_html_error:
                    if response.status == 502:
                        clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
                    elif response.status == 503:
                        clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
                    elif response.status == 504:
                        clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
                    else:
                        clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
                else:
                    clean_error = error_text
                logger.error(f"Rerank API error {response.status}: {clean_error}")
                raise aiohttp.ClientResponseError(
                    request_info=response.request_info,
                    history=response.history,
                    status=response.status,
                    message=f"Rerank API error: {clean_error}",
                )

            response_json = await response.json()

            if response_format == "aliyun":
                # Aliyun format: {"output": {"results": [...]}}
                results = response_json.get("output", {}).get("results", [])
                if not isinstance(results, list):
                    logger.warning(
                        f"Expected 'output.results' to be list, got {type(results)}: {results}"
                    )
                    results = []
            elif response_format == "standard":
                # Standard format: {"results": [...]}
                results = response_json.get("results", [])
                if not isinstance(results, list):
                    logger.warning(
                        f"Expected 'results' to be list, got {type(results)}: {results}"
                    )
                    results = []
            else:
                raise ValueError(f"Unsupported response format: {response_format}")

            if not results:
                logger.warning("Rerank API returned empty results")
                return []

            # Standardize return format
            return [
                {"index": result["index"], "relevance_score": result["relevance_score"]}
                for result in results
            ]
async def cohere_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "rerank-v3.5",
    base_url: str = "https://api.cohere.com/v2/rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Cohere API.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return
        api_key: API key
        model: rerank model name
        base_url: API endpoint
        extra_body: Additional body for http request(reserved for extra params)

    Returns:
        List of dictionary of ["index": int, "relevance_score": float]
    """
    if api_key is None:
        api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        return_documents=None,  # Cohere doesn't support this parameter
        extra_body=extra_body,
        response_format="standard",
    )
async def jina_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "jina-reranker-v2-base-multilingual",
    base_url: str = "https://api.jina.ai/v1/rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Jina AI API.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return
        api_key: API key
        model: rerank model name
        base_url: API endpoint
        extra_body: Additional body for http request(reserved for extra params)

    Returns:
        List of dictionary of ["index": int, "relevance_score": float]
    """
    if api_key is None:
        api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        return_documents=False,
        extra_body=extra_body,
        response_format="standard",
    )
async def ali_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "gte-rerank-v2",
    base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Aliyun DashScope API.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return
        api_key: Aliyun API key
        model: rerank model name
        base_url: API endpoint
        extra_body: Additional body for http request(reserved for extra params)

    Returns:
        List of dictionary of ["index": int, "relevance_score": float]
    """
    if api_key is None:
        api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        return_documents=False,  # Aliyun doesn't need this parameter
        extra_body=extra_body,
        response_format="aliyun",
        request_format="aliyun",
    )
"""Please run this test as a module:
python -m lightrag.rerank
"""
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
async def main(): async def main():
# Example usage # Example usage - documents should be strings, not dictionaries
docs = [ docs = [
{"content": "The capital of France is Paris."}, "The capital of France is Paris.",
{"content": "Tokyo is the capital of Japan."}, "Tokyo is the capital of Japan.",
{"content": "London is the capital of England."}, "London is the capital of England.",
] ]
query = "What is the capital of France?" query = "What is the capital of France?"
result = await jina_rerank( # Test Jina rerank
query=query, documents=docs, top_n=2, api_key="your-api-key-here" try:
) print("=== Jina Rerank ===")
print(result) result = await jina_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Jina Error: {e}")
# Test Cohere rerank
try:
print("\n=== Cohere Rerank ===")
result = await cohere_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Cohere Error: {e}")
# Test Aliyun rerank
try:
print("\n=== Aliyun Rerank ===")
result = await ali_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Aliyun Error: {e}")
asyncio.run(main()) asyncio.run(main())
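All three rerank helpers above return a list of `{"index", "relevance_score"}` dicts rather than the documents themselves. A minimal sketch of reordering the caller's documents with that result (`apply_rerank` is a hypothetical helper, and the scores below are illustrative, not real API output):

```python
from typing import Any, Dict, List

def apply_rerank(documents: List[str],
                 results: List[Dict[str, Any]]) -> List[str]:
    """Reorder documents by the [{'index', 'relevance_score'}] list
    returned by jina_rerank / cohere_rerank / ali_rerank."""
    ordered = sorted(results, key=lambda r: r["relevance_score"], reverse=True)
    return [documents[r["index"]] for r in ordered]

docs = ["Paris is in France.", "Tokyo is in Japan."]
# Illustrative scores, shaped like the output of the helpers above
results = [{"index": 0, "relevance_score": 0.93},
           {"index": 1, "relevance_score": 0.12}]
print(apply_rerank(docs, results))  # the France document ranks first
```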


@@ -0,0 +1,180 @@
#!/usr/bin/env python3
"""
Diagnostic tool to check LightRAG initialization status.
This tool helps developers verify that their LightRAG instance is properly
initialized before use, preventing common initialization errors.
Usage:
python -m lightrag.tools.check_initialization
"""
import asyncio
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from lightrag import LightRAG
from lightrag.base import StoragesStatus
async def check_lightrag_setup(rag_instance: LightRAG, verbose: bool = False) -> bool:
"""
Check if a LightRAG instance is properly initialized.
Args:
rag_instance: The LightRAG instance to check
verbose: If True, print detailed diagnostic information
Returns:
True if properly initialized, False otherwise
"""
issues = []
warnings = []
print("🔍 Checking LightRAG initialization status...\n")
# Check storage initialization status
if not hasattr(rag_instance, "_storages_status"):
issues.append("LightRAG instance missing _storages_status attribute")
elif rag_instance._storages_status != StoragesStatus.INITIALIZED:
issues.append(
f"Storages not initialized (status: {rag_instance._storages_status.name})"
)
else:
print("✅ Storage status: INITIALIZED")
# Check individual storage components
storage_components = [
("full_docs", "Document storage"),
("text_chunks", "Text chunks storage"),
("entities_vdb", "Entity vector database"),
("relationships_vdb", "Relationship vector database"),
("chunks_vdb", "Chunks vector database"),
("doc_status", "Document status tracker"),
("llm_response_cache", "LLM response cache"),
("full_entities", "Entity storage"),
("full_relations", "Relation storage"),
("chunk_entity_relation_graph", "Graph storage"),
]
if verbose:
print("\n📦 Storage Components:")
for component, description in storage_components:
if not hasattr(rag_instance, component):
issues.append(f"Missing storage component: {component} ({description})")
else:
storage = getattr(rag_instance, component)
if storage is None:
warnings.append(f"Storage {component} is None (might be optional)")
elif hasattr(storage, "_storage_lock"):
if storage._storage_lock is None:
issues.append(f"Storage {component} not initialized (lock is None)")
elif verbose:
print(f"{description}: Ready")
elif verbose:
print(f"{description}: Ready")
# Check pipeline status
try:
from lightrag.kg.shared_storage import get_namespace_data
get_namespace_data("pipeline_status")
print("✅ Pipeline status: INITIALIZED")
except KeyError:
issues.append(
"Pipeline status not initialized - call initialize_pipeline_status()"
)
except Exception as e:
issues.append(f"Error checking pipeline status: {str(e)}")
# Print results
print("\n" + "=" * 50)
if issues:
print("❌ Issues found:\n")
for issue in issues:
print(f"{issue}")
print("\n📝 To fix, run this initialization sequence:\n")
print(" await rag.initialize_storages()")
print(" from lightrag.kg.shared_storage import initialize_pipeline_status")
print(" await initialize_pipeline_status()")
print(
"\n📚 Documentation: https://github.com/HKUDS/LightRAG#important-initialization-requirements"
)
if warnings and verbose:
print("\n⚠️ Warnings (might be normal):")
for warning in warnings:
print(f"{warning}")
return False
else:
print("✅ LightRAG is properly initialized and ready to use!")
if warnings and verbose:
print("\n⚠️ Warnings (might be normal):")
for warning in warnings:
print(f"{warning}")
return True
async def demo():
"""Demonstrate the diagnostic tool with a test instance."""
from lightrag.llm.openai import openai_embed, gpt_4o_mini_complete
from lightrag.kg.shared_storage import initialize_pipeline_status
print("=" * 50)
print("LightRAG Initialization Diagnostic Tool")
print("=" * 50)
# Create test instance
rag = LightRAG(
working_dir="./test_diagnostic",
embedding_func=openai_embed,
llm_model_func=gpt_4o_mini_complete,
)
print("\n🔴 BEFORE initialization:\n")
await check_lightrag_setup(rag, verbose=True)
print("\n" + "=" * 50)
print("\n🔄 Initializing...\n")
await rag.initialize_storages()
await initialize_pipeline_status()
print("\n🟢 AFTER initialization:\n")
await check_lightrag_setup(rag, verbose=True)
# Cleanup
import shutil
shutil.rmtree("./test_diagnostic", ignore_errors=True)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Check LightRAG initialization status")
parser.add_argument(
"--demo", action="store_true", help="Run a demonstration with a test instance"
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Show detailed diagnostic information",
)
args = parser.parse_args()
if args.demo:
asyncio.run(demo())
else:
print("Run with --demo to see the diagnostic tool in action")
print("Or import this module and use check_lightrag_setup() with your instance")

File diff suppressed because it is too large.