From e27031587d0687939616e0f051436c75212d8a55 Mon Sep 17 00:00:00 2001
From: hzywhite <1569582518@qq.com>
Date: Thu, 4 Sep 2025 10:27:38 +0800
Subject: [PATCH] merge
---
lightrag/api/README-zh.md | 101 +-
lightrag/api/README.md | 95 +-
lightrag/api/__init__.py | 2 +-
lightrag/api/config.py | 76 +-
lightrag/api/lightrag_server.py | 344 +++---
lightrag/api/routers/graph_routes.py | 5 +
lightrag/api/routers/ollama_api.py | 4 +-
lightrag/api/run_with_gunicorn.py | 2 +-
lightrag/api/utils_api.py | 12 +-
lightrag/base.py | 20 +-
lightrag/constants.py | 35 +-
lightrag/exceptions.py | 38 +
lightrag/kg/faiss_impl.py | 16 +-
lightrag/kg/json_doc_status_impl.py | 9 +
lightrag/kg/json_kv_impl.py | 3 +
lightrag/kg/milvus_impl.py | 12 +-
lightrag/kg/mongo_impl.py | 91 +-
lightrag/kg/nano_vector_db_impl.py | 16 +-
lightrag/kg/postgres_impl.py | 98 +-
lightrag/kg/qdrant_impl.py | 15 +-
lightrag/kg/shared_storage.py | 23 +-
lightrag/lightrag.py | 320 +++---
lightrag/llm/Readme.md | 2 -
lightrag/llm/anthropic.py | 15 +-
lightrag/llm/azure_openai.py | 6 +-
lightrag/llm/binding_options.py | 135 ++-
lightrag/llm/lollms.py | 2 +-
lightrag/llm/ollama.py | 2 +
lightrag/llm/openai.py | 12 +-
lightrag/operate.py | 1368 +++++++++++++++---------
lightrag/prompt.py | 279 ++---
lightrag/rerank.py | 502 +++++----
lightrag/tools/check_initialization.py | 180 ++++
lightrag/utils.py | 834 +++++++++++----
34 files changed, 2906 insertions(+), 1768 deletions(-)
create mode 100644 lightrag/tools/check_initialization.py
diff --git a/lightrag/api/README-zh.md b/lightrag/api/README-zh.md
index b74e4d12..9f940df1 100644
--- a/lightrag/api/README-zh.md
+++ b/lightrag/api/README-zh.md
@@ -357,7 +357,7 @@ API 服务器可以通过三种方式配置(优先级从高到低):
LightRAG 支持绑定到各种 LLM/嵌入后端:
* ollama
-* openai 和 openai 兼容
+* openai (含openai 兼容)
* azure_openai
* lollms
* aws_bedrock
@@ -372,7 +372,10 @@ lightrag-server --llm-binding ollama --help
lightrag-server --embedding-binding ollama --help
```
+> 请使用openai兼容方式访问OpenRouter或vLLM部署的LLM。可以通过 `OPENAI_LLM_EXTRA_BODY` 环境变量给OpenRouter或vLLM传递额外的参数,实现推理模式的关闭或者其它个性化控制。
+
### 实体提取配置
+
* ENABLE_LLM_CACHE_FOR_EXTRACT:为实体提取启用 LLM 缓存(默认:true)
在测试环境中将 `ENABLE_LLM_CACHE_FOR_EXTRACT` 设置为 true 以减少 LLM 调用成本是很常见的做法。
@@ -386,51 +389,9 @@ LightRAG 使用 4 种类型的存储用于不同目的:
* GRAPH_STORAGE:实体关系图
* DOC_STATUS_STORAGE:文档索引状态
-每种存储类型都有几种实现:
+每种存储类型都有多种存储实现方式。LightRAG Server默认的存储实现为内存数据库,数据通过文件持久化保存到WORKING_DIR目录。LightRAG还支持PostgreSQL、MongoDB、FAISS、Milvus、Qdrant、Neo4j、Memgraph和Redis等存储实现方式。详细的存储支持方式请参考根目录下的`README.md`文件中关于存储的相关内容。
-* KV_STORAGE 支持的实现名称
-
-```
-JsonKVStorage JsonFile(默认)
-PGKVStorage Postgres
-RedisKVStorage Redis
-MongoKVStorage MogonDB
-```
-
-* GRAPH_STORAGE 支持的实现名称
-
-```
-NetworkXStorage NetworkX(默认)
-Neo4JStorage Neo4J
-PGGraphStorage PostgreSQL with AGE plugin
-```
-
-> 在测试中Neo4j图形数据库相比PostgreSQL AGE有更好的性能表现。
-
-* VECTOR_STORAGE 支持的实现名称
-
-```
-NanoVectorDBStorage NanoVector(默认)
-PGVectorStorage Postgres
-MilvusVectorDBStorge Milvus
-FaissVectorDBStorage Faiss
-QdrantVectorDBStorage Qdrant
-MongoVectorDBStorage MongoDB
-```
-
-* DOC_STATUS_STORAGE 支持的实现名称
-
-```
-JsonDocStatusStorage JsonFile(默认)
-PGDocStatusStorage Postgres
-MongoDocStatusStorage MongoDB
-```
-
-每一种存储类型的链接配置范例可以在 `env.example` 文件中找到。链接字符串中的数据库实例是需要你预先在数据库服务器上创建好的,LightRAG 仅负责在数据库实例中创建数据表,不负责创建数据库实例。如果使用 Redis 作为存储,记得给 Redis 配置自动持久化数据规则,否则 Redis 服务重启后数据会丢失。如果使用PostgreSQL数据库,推荐使用16.6版本或以上。
-
-### 如何选择存储实现
-
-您可以通过环境变量选择存储实现。在首次启动 API 服务器之前,您可以将以下环境变量设置为特定的存储实现名称:
+您可以通过环境变量选择存储实现。例如,在首次启动 API 服务器之前,您可以将以下环境变量设置为特定的存储实现名称:
```
LIGHTRAG_KV_STORAGE=PGKVStorage
@@ -439,7 +400,7 @@ LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
```
-在向 LightRAG 添加文档后,您不能更改存储实现选择。目前尚不支持从一个存储实现迁移到另一个存储实现。更多信息请阅读示例 env 文件或 config.ini 文件。
+在向 LightRAG 添加文档后,您不能更改存储实现选择。目前尚不支持从一个存储实现迁移到另一个存储实现。更多配置信息请阅读示例 `env.exampl`e文件。
### LightRag API 服务器命令行选项
@@ -450,20 +411,54 @@ LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
| --working-dir | ./rag_storage | RAG 存储的工作目录 |
| --input-dir | ./inputs | 包含输入文档的目录 |
| --max-async | 4 | 最大异步操作数 |
-| --max-tokens | 32768 | 最大 token 大小 |
-| --timeout | 150 | 超时时间(秒)。None 表示无限超时(不推荐) |
| --log-level | INFO | 日志级别(DEBUG、INFO、WARNING、ERROR、CRITICAL) |
| --verbose | - | 详细调试输出(True、False) |
| --key | None | 用于认证的 API 密钥。保护 lightrag 服务器免受未授权访问 |
| --ssl | False | 启用 HTTPS |
| --ssl-certfile | None | SSL 证书文件路径(如果启用 --ssl 则必需) |
| --ssl-keyfile | None | SSL 私钥文件路径(如果启用 --ssl 则必需) |
-| --top-k | 50 | 要检索的 top-k 项目数;在"local"模式下对应实体,在"global"模式下对应关系。 |
-| --cosine-threshold | 0.4 | 节点和关系检索的余弦阈值,与 top-k 一起控制节点和关系的检索。 |
-| --llm-binding | ollama | LLM 绑定类型(lollms、ollama、openai、openai-ollama、azure_openai) |
-| --embedding-binding | ollama | 嵌入绑定类型(lollms、ollama、openai、azure_openai) |
+| --llm-binding | ollama | LLM 绑定类型(lollms、ollama、openai、openai-ollama、azure_openai、aws_bedrock) |
+| --embedding-binding | ollama | 嵌入绑定类型(lollms、ollama、openai、azure_openai、aws_bedrock) |
| auto-scan-at-startup | - | 扫描输入目录中的新文件并开始索引 |
+### Reranking 配置
+
+Reranking 查询召回的块可以显著提高检索质量,它通过基于优化的相关性评分模型对文档重新排序。LightRAG 目前支持以下 rerank 提供商:
+
+- **Cohere / vLLM**:提供与 Cohere AI 的 `v2/rerank` 端点的完整 API 集成。由于 vLLM 提供了与 Cohere 兼容的 reranker API,因此也支持所有通过 vLLM 部署的 reranker 模型。
+- **Jina AI**:提供与所有 Jina rerank 模型的完全实现兼容性。
+- **阿里云**:具有旨在支持阿里云 rerank API 格式的自定义实现。
+
+Rerank 提供商通过 `.env` 文件进行配置。以下是使用 vLLM 本地部署的 rerank 模型的示例配置:
+
+```
+RERANK_BINDING=cohere
+RERANK_MODEL=BAAI/bge-reranker-v2-m3
+RERANK_BINDING_HOST=http://localhost:8000/v1/rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+
+以下是使用阿里云提供的 Reranker 服务的示例配置:
+
+```
+RERANK_BINDING=aliyun
+RERANK_MODEL=gte-rerank-v2
+RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+
+有关完整的 reranker 配置示例,请参阅 `env.example` 文件。
+
+### 启用 Reranking
+
+可以按查询启用或禁用 Reranking。
+
+`/query` 和 `/query/stream` API 端点包含一个 `enable_rerank` 参数,默认设置为 `true`,用于控制当前查询是否激活 reranking。要将 `enable_rerank` 参数的默认值更改为 `false`,请设置以下环境变量:
+
+```
+RERANK_BY_DEFAULT=False
+```
+
### .env 文件示例
```bash
@@ -478,7 +473,7 @@ SUMMARY_LANGUAGE=Chinese
MAX_PARALLEL_INSERT=2
### LLM Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
-TIMEOUT=200
+TIMEOUT=150
MAX_ASYNC=4
LLM_BINDING=openai
diff --git a/lightrag/api/README.md b/lightrag/api/README.md
index da59b38f..0f010234 100644
--- a/lightrag/api/README.md
+++ b/lightrag/api/README.md
@@ -360,7 +360,7 @@ Most of the configurations come with default settings; check out the details in
LightRAG supports binding to various LLM/Embedding backends:
* ollama
-* openai & openai compatible
+* openai (including openai compatible)
* azure_openai
* lollms
* aws_bedrock
@@ -374,6 +374,8 @@ lightrag-server --llm-binding ollama --help
lightrag-server --embedding-binding ollama --help
```
+> Please use OpenAI-compatible method to access LLMs deployed by OpenRouter or vLLM. You can pass additional parameters to OpenRouter or vLLM through the `OPENAI_LLM_EXTRA_BODY` environment variable to disable reasoning mode or achieve other personalized controls.
+
### Entity Extraction Configuration
* ENABLE_LLM_CACHE_FOR_EXTRACT: Enable LLM cache for entity extraction (default: true)
@@ -388,52 +390,9 @@ LightRAG uses 4 types of storage for different purposes:
* GRAPH_STORAGE: entity relation graph
* DOC_STATUS_STORAGE: document indexing status
-Each storage type has several implementations:
+LightRAG Server offers various storage implementations, with the default being an in-memory database that persists data to the WORKING_DIR directory. Additionally, LightRAG supports a wide range of storage solutions including PostgreSQL, MongoDB, FAISS, Milvus, Qdrant, Neo4j, Memgraph, and Redis. For detailed information on supported storage options, please refer to the storage section in the README.md file located in the root directory.
-* KV_STORAGE supported implementations:
-
-```
-JsonKVStorage JsonFile (default)
-PGKVStorage Postgres
-RedisKVStorage Redis
-MongoKVStorage MongoDB
-```
-
-* GRAPH_STORAGE supported implementations:
-
-```
-NetworkXStorage NetworkX (default)
-Neo4JStorage Neo4J
-PGGraphStorage PostgreSQL with AGE plugin
-MemgraphStorage. Memgraph
-```
-
-> Testing has shown that Neo4J delivers superior performance in production environments compared to PostgreSQL with AGE plugin.
-
-* VECTOR_STORAGE supported implementations:
-
-```
-NanoVectorDBStorage NanoVector (default)
-PGVectorStorage Postgres
-MilvusVectorDBStorage Milvus
-FaissVectorDBStorage Faiss
-QdrantVectorDBStorage Qdrant
-MongoVectorDBStorage MongoDB
-```
-
-* DOC_STATUS_STORAGE: supported implementations:
-
-```
-JsonDocStatusStorage JsonFile (default)
-PGDocStatusStorage Postgres
-MongoDocStatusStorage MongoDB
-```
-Example connection configurations for each storage type can be found in the `env.example` file. The database instance in the connection string needs to be created by you on the database server beforehand. LightRAG is only responsible for creating tables within the database instance, not for creating the database instance itself. If using Redis as storage, remember to configure automatic data persistence rules for Redis, otherwise data will be lost after the Redis service restarts. If using PostgreSQL, it is recommended to use version 16.6 or above.
-
-
-### How to Select Storage Implementation
-
-You can select storage implementation by environment variables. You can set the following environment variables to a specific storage implementation name before the first start of the API Server:
+You can select the storage implementation by configuring environment variables. For instance, prior to the initial launch of the API server, you can set the following environment variable to specify your desired storage implementation:
```
LIGHTRAG_KV_STORAGE=PGKVStorage
@@ -453,23 +412,53 @@ You cannot change storage implementation selection after adding documents to Lig
| --working-dir | ./rag_storage | Working directory for RAG storage |
| --input-dir | ./inputs | Directory containing input documents |
| --max-async | 4 | Maximum number of async operations |
-| --max-tokens | 32768 | Maximum token size |
-| --timeout | 150 | Timeout in seconds. None for infinite timeout (not recommended) |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --verbose | - | Verbose debug output (True, False) |
| --key | None | API key for authentication. Protects the LightRAG server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
-| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
-| --cosine-threshold | 0.4 | The cosine threshold for nodes and relation retrieval, works with top-k to control the retrieval of nodes and relations. |
| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai, aws_bedrock) |
| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai, aws_bedrock) |
| --auto-scan-at-startup| - | Scan input directory for new files and start indexing |
-### Additional Ollama Binding Options
+### Reranking Configuration
-When using `--llm-binding ollama` or `--embedding-binding ollama`, additional Ollama-specific configuration options are available. To see all available Ollama binding options, add `--help` to the command line when starting the server. These additional options allow for fine-tuning of Ollama model parameters and connection settings.
+Reranking query-recalled chunks can significantly enhance retrieval quality by re-ordering documents based on an optimized relevance scoring model. LightRAG currently supports the following rerank providers:
+
+- **Cohere / vLLM**: Offers full API integration with Cohere AI's `v2/rerank` endpoint. As vLLM provides a Cohere-compatible reranker API, all reranker models deployed via vLLM are also supported.
+- **Jina AI**: Provides complete implementation compatibility with all Jina rerank models.
+- **Aliyun**: Features a custom implementation designed to support Aliyun's rerank API format.
+
+The rerank provider is configured via the `.env` file. Below is an example configuration for a rerank model deployed locally using vLLM:
+
+```
+RERANK_BINDING=cohere
+RERANK_MODEL=BAAI/bge-reranker-v2-m3
+RERANK_BINDING_HOST=http://localhost:8000/v1/rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+
+Here is an example configuration for utilizing the Reranker service provided by Aliyun:
+
+```
+RERANK_BINDING=aliyun
+RERANK_MODEL=gte-rerank-v2
+RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
+RERANK_BINDING_API_KEY=your_rerank_api_key_here
+```
+
+For comprehensive reranker configuration examples, please refer to the `env.example` file.
+
+### Enable Reranking
+
+Reranking can be enabled or disabled on a per-query basis.
+
+The `/query` and `/query/stream` API endpoints include an `enable_rerank` parameter, which is set to `true` by default, controlling whether reranking is active for the current query. To change the default value of the `enable_rerank` parameter to `false`, set the following environment variable:
+
+```
+RERANK_BY_DEFAULT=False
+```
### .env Examples
@@ -485,7 +474,7 @@ SUMMARY_LANGUAGE=Chinese
MAX_PARALLEL_INSERT=2
### LLM Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
-TIMEOUT=200
+TIMEOUT=150
MAX_ASYNC=4
LLM_BINDING=openai
diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py
index 39c729cd..7cc89012 100644
--- a/lightrag/api/__init__.py
+++ b/lightrag/api/__init__.py
@@ -1 +1 @@
-__api_version__ = "0205"
+__api_version__ = "0213"
diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index 01d0dd75..f17d50f0 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -30,12 +30,15 @@ from lightrag.constants import (
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
DEFAULT_MAX_ASYNC,
DEFAULT_SUMMARY_MAX_TOKENS,
+ DEFAULT_SUMMARY_LENGTH_RECOMMENDED,
+ DEFAULT_SUMMARY_CONTEXT_SIZE,
DEFAULT_SUMMARY_LANGUAGE,
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC,
DEFAULT_EMBEDDING_BATCH_NUM,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
- DEFAULT_TEMPERATURE,
+ DEFAULT_RERANK_BINDING,
+ DEFAULT_ENTITY_TYPES,
)
# use the .env that is inside the current folder
@@ -77,9 +80,7 @@ def parse_args() -> argparse.Namespace:
argparse.Namespace: Parsed arguments
"""
- parser = argparse.ArgumentParser(
- description="LightRAG FastAPI Server with separate working and input directories"
- )
+ parser = argparse.ArgumentParser(description="LightRAG API Server")
# Server configuration
parser.add_argument(
@@ -121,10 +122,26 @@ def parse_args() -> argparse.Namespace:
help=f"Maximum async operations (default: from env or {DEFAULT_MAX_ASYNC})",
)
parser.add_argument(
- "--max-tokens",
+ "--summary-max-tokens",
type=int,
- default=get_env_value("MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS, int),
- help=f"Maximum token size (default: from env or {DEFAULT_SUMMARY_MAX_TOKENS})",
+ default=get_env_value("SUMMARY_MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS, int),
+ help=f"Maximum token size for entity/relation summary(default: from env or {DEFAULT_SUMMARY_MAX_TOKENS})",
+ )
+ parser.add_argument(
+ "--summary-context-size",
+ type=int,
+ default=get_env_value(
+ "SUMMARY_CONTEXT_SIZE", DEFAULT_SUMMARY_CONTEXT_SIZE, int
+ ),
+ help=f"LLM Summary Context size (default: from env or {DEFAULT_SUMMARY_CONTEXT_SIZE})",
+ )
+ parser.add_argument(
+ "--summary-length-recommended",
+ type=int,
+ default=get_env_value(
+ "SUMMARY_LENGTH_RECOMMENDED", DEFAULT_SUMMARY_LENGTH_RECOMMENDED, int
+ ),
+ help=f"LLM Summary Context size (default: from env or {DEFAULT_SUMMARY_LENGTH_RECOMMENDED})",
)
# Logging configuration
@@ -226,6 +243,13 @@ def parse_args() -> argparse.Namespace:
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
help="Embedding binding type (default: from env or ollama)",
)
+ parser.add_argument(
+ "--rerank-binding",
+ type=str,
+ default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
+ choices=["null", "cohere", "jina", "aliyun"],
+ help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
+ )
# Conditionally add binding options defined in binding_options module
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
@@ -264,14 +288,6 @@ def parse_args() -> argparse.Namespace:
elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
OpenAILLMOptions.add_args(parser)
- # Add global temperature command line argument
- parser.add_argument(
- "--temperature",
- type=float,
- default=get_env_value("TEMPERATURE", DEFAULT_TEMPERATURE, float),
- help="Global temperature setting for LLM (default: from env TEMPERATURE or 0.1)",
- )
-
args = parser.parse_args()
# convert relative path to absolute path
@@ -330,38 +346,13 @@ def parse_args() -> argparse.Namespace:
)
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
- # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
- if args.llm_binding == "ollama":
- # Priority order (highest to lowest):
- # 1. --ollama-llm-temperature command argument
- # 2. OLLAMA_LLM_TEMPERATURE environment variable
- # 3. --temperature command argument
- # 4. TEMPERATURE environment variable
-
- # Check if --ollama-llm-temperature was explicitly provided in command line
- if "--ollama-llm-temperature" not in sys.argv:
- # Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
- args.ollama_llm_temperature = args.temperature
-
- # Handle OpenAI LLM temperature with priority cascade when llm-binding is openai or azure_openai
- if args.llm_binding in ["openai", "azure_openai"]:
- # Priority order (highest to lowest):
- # 1. --openai-llm-temperature command argument
- # 2. OPENAI_LLM_TEMPERATURE environment variable
- # 3. --temperature command argument
- # 4. TEMPERATURE environment variable
-
- # Check if --openai-llm-temperature was explicitly provided in command line
- if "--openai-llm-temperature" not in sys.argv:
- # Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
- args.openai_llm_temperature = args.temperature
-
# Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
# Add environment variables that were previously read directly
args.cors_origins = get_env_value("CORS_ORIGINS", "*")
args.summary_language = get_env_value("SUMMARY_LANGUAGE", DEFAULT_SUMMARY_LANGUAGE)
+ args.entity_types = get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES, list)
args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
# For JWT Auth
@@ -372,9 +363,10 @@ def parse_args() -> argparse.Namespace:
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
# Rerank model configuration
- args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
+ args.rerank_model = get_env_value("RERANK_MODEL", None)
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
+ # Note: rerank_binding is already set by argparse, no need to override from env
# Min rerank score configuration
args.min_rerank_score = get_env_value(
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index e936e74c..7268208a 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -2,7 +2,7 @@
LightRAG FastAPI Server
"""
-from fastapi import FastAPI, Depends, HTTPException, status
+from fastapi import FastAPI, Depends, HTTPException
import asyncio
import os
import logging
@@ -11,6 +11,7 @@ import signal
import sys
import uvicorn
import pipmaster as pm
+import inspect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
@@ -38,6 +39,8 @@ from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME,
+ DEFAULT_LLM_TIMEOUT,
+ DEFAULT_EMBEDDING_TIMEOUT,
)
from lightrag.api.routers.document_routes import (
DocumentManager,
@@ -236,25 +239,106 @@ def create_app(args):
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
- if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
- from lightrag.llm.lollms import lollms_model_complete, lollms_embed
- if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
- from lightrag.llm.ollama import ollama_model_complete, ollama_embed
- from lightrag.llm.binding_options import OllamaLLMOptions
- if args.llm_binding == "openai" or args.embedding_binding == "openai":
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
- from lightrag.llm.binding_options import OpenAILLMOptions
- if args.llm_binding == "azure_openai" or args.embedding_binding == "azure_openai":
- from lightrag.llm.azure_openai import (
- azure_openai_complete_if_cache,
- azure_openai_embed,
- )
- if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
- from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
- if args.embedding_binding == "ollama":
- from lightrag.llm.binding_options import OllamaEmbeddingOptions
- if args.embedding_binding == "jina":
- from lightrag.llm.jina import jina_embed
+
+ def create_llm_model_func(binding: str):
+ """
+ Create LLM model function based on binding type.
+ Uses lazy import to avoid unnecessary dependencies.
+ """
+ try:
+ if binding == "lollms":
+ from lightrag.llm.lollms import lollms_model_complete
+
+ return lollms_model_complete
+ elif binding == "ollama":
+ from lightrag.llm.ollama import ollama_model_complete
+
+ return ollama_model_complete
+ elif binding == "aws_bedrock":
+ return bedrock_model_complete # Already defined locally
+ elif binding == "azure_openai":
+ return azure_openai_model_complete # Already defined locally
+ else: # openai and compatible
+ return openai_alike_model_complete # Already defined locally
+ except ImportError as e:
+ raise Exception(f"Failed to import {binding} LLM binding: {e}")
+
+ def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
+ """
+ Create LLM model kwargs based on binding type.
+ Uses lazy import for binding-specific options.
+ """
+ if binding in ["lollms", "ollama"]:
+ try:
+ from lightrag.llm.binding_options import OllamaLLMOptions
+
+ return {
+ "host": args.llm_binding_host,
+ "timeout": llm_timeout,
+ "options": OllamaLLMOptions.options_dict(args),
+ "api_key": args.llm_binding_api_key,
+ }
+ except ImportError as e:
+ raise Exception(f"Failed to import {binding} options: {e}")
+ return {}
+
+ def create_embedding_function_with_lazy_import(
+ binding, model, host, api_key, dimensions, args
+ ):
+ """
+ Create embedding function with lazy imports for all bindings.
+ Replaces the current create_embedding_function with full lazy import support.
+ """
+
+ async def embedding_function(texts):
+ try:
+ if binding == "lollms":
+ from lightrag.llm.lollms import lollms_embed
+
+ return await lollms_embed(
+ texts, embed_model=model, host=host, api_key=api_key
+ )
+ elif binding == "ollama":
+ from lightrag.llm.binding_options import OllamaEmbeddingOptions
+ from lightrag.llm.ollama import ollama_embed
+
+ ollama_options = OllamaEmbeddingOptions.options_dict(args)
+ return await ollama_embed(
+ texts,
+ embed_model=model,
+ host=host,
+ api_key=api_key,
+ options=ollama_options,
+ )
+ elif binding == "azure_openai":
+ from lightrag.llm.azure_openai import azure_openai_embed
+
+ return await azure_openai_embed(texts, model=model, api_key=api_key)
+ elif binding == "aws_bedrock":
+ from lightrag.llm.bedrock import bedrock_embed
+
+ return await bedrock_embed(texts, model=model)
+ elif binding == "jina":
+ from lightrag.llm.jina import jina_embed
+
+ return await jina_embed(
+ texts, dimensions=dimensions, base_url=host, api_key=api_key
+ )
+ else: # openai and compatible
+ from lightrag.llm.openai import openai_embed
+
+ return await openai_embed(
+ texts, model=model, base_url=host, api_key=api_key
+ )
+ except ImportError as e:
+ raise Exception(f"Failed to import {binding} embedding: {e}")
+
+ return embedding_function
+
+ llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
+ embedding_timeout = get_env_value(
+ "EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
+ )
async def openai_alike_model_complete(
prompt,
@@ -263,18 +347,20 @@ def create_app(args):
keyword_extraction=False,
**kwargs,
) -> str:
+ # Lazy import
+ from lightrag.llm.openai import openai_complete_if_cache
+ from lightrag.llm.binding_options import OpenAILLMOptions
+
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
- # Use OpenAI LLM options if available, otherwise fallback to global temperature
- if args.llm_binding == "openai":
- openai_options = OpenAILLMOptions.options_dict(args)
- kwargs.update(openai_options)
- else:
- kwargs["temperature"] = args.temperature
+ # Use OpenAI LLM options if available
+ openai_options = OpenAILLMOptions.options_dict(args)
+ kwargs["timeout"] = llm_timeout
+ kwargs.update(openai_options)
return await openai_complete_if_cache(
args.llm_model,
@@ -293,18 +379,20 @@ def create_app(args):
keyword_extraction=False,
**kwargs,
) -> str:
+ # Lazy import
+ from lightrag.llm.azure_openai import azure_openai_complete_if_cache
+ from lightrag.llm.binding_options import OpenAILLMOptions
+
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
- # Use OpenAI LLM options if available, otherwise fallback to global temperature
- if args.llm_binding == "azure_openai":
- openai_options = OpenAILLMOptions.options_dict(args)
- kwargs.update(openai_options)
- else:
- kwargs["temperature"] = args.temperature
+ # Use OpenAI LLM options
+ openai_options = OpenAILLMOptions.options_dict(args)
+ kwargs["timeout"] = llm_timeout
+ kwargs.update(openai_options)
return await azure_openai_complete_if_cache(
args.llm_model,
@@ -324,6 +412,9 @@ def create_app(args):
keyword_extraction=False,
**kwargs,
) -> str:
+ # Lazy import
+ from lightrag.llm.bedrock import bedrock_complete_if_cache
+
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
@@ -331,7 +422,7 @@ def create_app(args):
history_messages = []
# Use global temperature for Bedrock
- kwargs["temperature"] = args.temperature
+ kwargs["temperature"] = get_env_value("BEDROCK_LLM_TEMPERATURE", 1.0, float)
return await bedrock_complete_if_cache(
args.llm_model,
@@ -341,86 +432,73 @@ def create_app(args):
**kwargs,
)
+ # Create embedding function with lazy imports
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
- func=lambda texts: (
- lollms_embed(
- texts,
- embed_model=args.embedding_model,
- host=args.embedding_binding_host,
- api_key=args.embedding_binding_api_key,
- )
- if args.embedding_binding == "lollms"
- else (
- ollama_embed(
- texts,
- embed_model=args.embedding_model,
- host=args.embedding_binding_host,
- api_key=args.embedding_binding_api_key,
- options=OllamaEmbeddingOptions.options_dict(args),
- )
- if args.embedding_binding == "ollama"
- else (
- azure_openai_embed(
- texts,
- model=args.embedding_model, # no host is used for openai,
- api_key=args.embedding_binding_api_key,
- )
- if args.embedding_binding == "azure_openai"
- else (
- bedrock_embed(
- texts,
- model=args.embedding_model,
- )
- if args.embedding_binding == "aws_bedrock"
- else (
- jina_embed(
- texts,
- dimensions=args.embedding_dim,
- base_url=args.embedding_binding_host,
- api_key=args.embedding_binding_api_key,
- )
- if args.embedding_binding == "jina"
- else openai_embed(
- texts,
- model=args.embedding_model,
- base_url=args.embedding_binding_host,
- api_key=args.embedding_binding_api_key,
- )
- )
- )
- )
- )
+ func=create_embedding_function_with_lazy_import(
+ binding=args.embedding_binding,
+ model=args.embedding_model,
+ host=args.embedding_binding_host,
+ api_key=args.embedding_binding_api_key,
+ dimensions=args.embedding_dim,
+ args=args, # Pass args object for dynamic option generation
),
)
- # Configure rerank function if model and API are configured
+ # Configure rerank function based on args.rerank_bindingparameter
rerank_model_func = None
- if args.rerank_binding_api_key and args.rerank_binding_host:
- from lightrag.rerank import custom_rerank
+ if args.rerank_binding != "null":
+ from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
+
+ # Map rerank binding to corresponding function
+ rerank_functions = {
+ "cohere": cohere_rerank,
+ "jina": jina_rerank,
+ "aliyun": ali_rerank,
+ }
+
+ # Select the appropriate rerank function based on binding
+ selected_rerank_func = rerank_functions.get(args.rerank_binding)
+ if not selected_rerank_func:
+ logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
+ raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
+
+ # Get default values from selected_rerank_func if args values are None
+ if args.rerank_model is None or args.rerank_binding_host is None:
+ sig = inspect.signature(selected_rerank_func)
+
+ # Set default model if args.rerank_model is None
+ if args.rerank_model is None and "model" in sig.parameters:
+ default_model = sig.parameters["model"].default
+ if default_model != inspect.Parameter.empty:
+ args.rerank_model = default_model
+
+ # Set default base_url if args.rerank_binding_host is None
+ if args.rerank_binding_host is None and "base_url" in sig.parameters:
+ default_base_url = sig.parameters["base_url"].default
+ if default_base_url != inspect.Parameter.empty:
+ args.rerank_binding_host = default_base_url
async def server_rerank_func(
- query: str, documents: list, top_n: int = None, **kwargs
+ query: str, documents: list, top_n: int = None, extra_body: dict = None
):
"""Server rerank function with configuration from environment variables"""
- return await custom_rerank(
+ return await selected_rerank_func(
query=query,
documents=documents,
+ top_n=top_n,
+ api_key=args.rerank_binding_api_key,
model=args.rerank_model,
base_url=args.rerank_binding_host,
- api_key=args.rerank_binding_api_key,
- top_n=top_n,
- **kwargs,
+ extra_body=extra_body,
)
rerank_model_func = server_rerank_func
logger.info(
- f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
+ f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
)
else:
- logger.info(
- "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
- )
+ logger.info("Reranking is disabled")
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
@@ -429,38 +507,24 @@ def create_app(args):
name=args.simulated_model_name, tag=args.simulated_model_tag
)
- # Initialize RAG
- if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
+ # Initialize RAG with unified configuration
+ try:
rag = LightRAG(
working_dir=args.working_dir,
workspace=args.workspace,
- llm_model_func=(
- lollms_model_complete
- if args.llm_binding == "lollms"
- else (
- ollama_model_complete
- if args.llm_binding == "ollama"
- else bedrock_model_complete
- if args.llm_binding == "aws_bedrock"
- else openai_alike_model_complete
- )
- ),
+ llm_model_func=create_llm_model_func(args.llm_binding),
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
- summary_max_tokens=args.max_tokens,
+ summary_max_tokens=args.summary_max_tokens,
+ summary_context_size=args.summary_context_size,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
- llm_model_kwargs=(
- {
- "host": args.llm_binding_host,
- "timeout": args.timeout,
- "options": OllamaLLMOptions.options_dict(args),
- "api_key": args.llm_binding_api_key,
- }
- if args.llm_binding == "lollms" or args.llm_binding == "ollama"
- else {}
+ llm_model_kwargs=create_llm_model_kwargs(
+ args.llm_binding, args, llm_timeout
),
embedding_func=embedding_func,
+ default_llm_timeout=llm_timeout,
+ default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
@@ -473,36 +537,10 @@ def create_app(args):
rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
- addon_params={"language": args.summary_language},
- ollama_server_infos=ollama_server_infos,
- )
- else: # azure_openai
- rag = LightRAG(
- working_dir=args.working_dir,
- workspace=args.workspace,
- llm_model_func=azure_openai_model_complete,
- chunk_token_size=int(args.chunk_size),
- chunk_overlap_token_size=int(args.chunk_overlap_size),
- llm_model_kwargs={
- "timeout": args.timeout,
+ addon_params={
+ "language": args.summary_language,
+ "entity_types": args.entity_types,
},
- llm_model_name=args.llm_model,
- llm_model_max_async=args.max_async,
- summary_max_tokens=args.max_tokens,
- embedding_func=embedding_func,
- kv_storage=args.kv_storage,
- graph_storage=args.graph_storage,
- vector_storage=args.vector_storage,
- doc_status_storage=args.doc_status_storage,
- vector_db_storage_cls_kwargs={
- "cosine_better_than_threshold": args.cosine_threshold
- },
- enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
- enable_llm_cache=args.enable_llm_cache,
- rerank_model_func=rerank_model_func,
- max_parallel_insert=args.max_parallel_insert,
- max_graph_nodes=args.max_graph_nodes,
- addon_params={"language": args.summary_language},
ollama_server_infos=ollama_server_infos,
)
@@ -709,9 +747,7 @@ def create_app(args):
}
username = form_data.username
if auth_handler.accounts.get(username) != form_data.password:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
- )
+ raise HTTPException(status_code=401, detail="Incorrect credentials")
# Regular user login
user_token = auth_handler.create_token(
@@ -754,7 +790,8 @@ def create_app(args):
"embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
- "max_tokens": args.max_tokens,
+ "summary_max_tokens": args.summary_max_tokens,
+ "summary_context_size": args.summary_context_size,
"kv_storage": args.kv_storage,
"doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage,
@@ -763,13 +800,12 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes,
- # Rerank configuration (based on whether rerank model is configured)
+ # Rerank configuration
"enable_rerank": rerank_model_func is not None,
- "rerank_model": args.rerank_model
- if rerank_model_func is not None
- else None,
+ "rerank_binding": args.rerank_binding,
+ "rerank_model": args.rerank_model if rerank_model_func else None,
"rerank_binding_host": args.rerank_binding_host
- if rerank_model_func is not None
+ if rerank_model_func
else None,
# Environment variable status (requested configuration)
"summary_language": args.summary_language,
diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py
index f02779df..42c20e6a 100644
--- a/lightrag/api/routers/graph_routes.py
+++ b/lightrag/api/routers/graph_routes.py
@@ -66,6 +66,11 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Dict[str, List[str]]: Knowledge graph for label
"""
try:
+ # Log the label parameter to check for leading spaces
+ logger.debug(
+ f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})"
+ )
+
return await rag.get_knowledge_graph(
node_label=label,
max_depth=max_depth,
diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py
index c38018f6..9ef7c55f 100644
--- a/lightrag/api/routers/ollama_api.py
+++ b/lightrag/api/routers/ollama_api.py
@@ -469,8 +469,8 @@ class OllamaAPI:
"/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
)
async def chat(raw_request: Request):
- """Process chat completion requests acting as an Ollama model
- Routes user queries through LightRAG by selecting query mode based on prefix indicators.
+ """Process chat completion requests by acting as an Ollama model.
+ Routes user queries through LightRAG by selecting query mode based on query prefix.
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
Supports both application/json and application/octet-stream Content-Types.
"""
diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index 8c8a029d..929db019 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -153,7 +153,7 @@ def main():
# Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = (
- global_args.timeout * 2
+ global_args.timeout + 30
if global_args.timeout is not None
else get_env_value(
"TIMEOUT", DEFAULT_TIMEOUT + 30, int, special_none=True
diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py
index 90a1eb96..563cc2e4 100644
--- a/lightrag/api/utils_api.py
+++ b/lightrag/api/utils_api.py
@@ -201,6 +201,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.port}")
ASCIIColors.white(" ├─ Workers: ", end="")
ASCIIColors.yellow(f"{args.workers}")
+ ASCIIColors.white(" ├─ Timeout: ", end="")
+ ASCIIColors.yellow(f"{args.timeout}")
ASCIIColors.white(" ├─ CORS Origins: ", end="")
ASCIIColors.yellow(f"{args.cors_origins}")
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
@@ -238,14 +240,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.llm_binding_host}")
ASCIIColors.white(" ├─ Model: ", end="")
ASCIIColors.yellow(f"{args.llm_model}")
- ASCIIColors.white(" ├─ Temperature: ", end="")
- ASCIIColors.yellow(f"{args.temperature}")
ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
ASCIIColors.yellow(f"{args.max_async}")
- ASCIIColors.white(" ├─ Max Tokens: ", end="")
- ASCIIColors.yellow(f"{args.max_tokens}")
- ASCIIColors.white(" ├─ Timeout: ", end="")
- ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
+ ASCIIColors.white(" ├─ Summary Context Size: ", end="")
+ ASCIIColors.yellow(f"{args.summary_context_size}")
ASCIIColors.white(" ├─ LLM Cache Enabled: ", end="")
ASCIIColors.yellow(f"{args.enable_llm_cache}")
ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
@@ -266,6 +264,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
ASCIIColors.white(" ├─ Summary Language: ", end="")
ASCIIColors.yellow(f"{args.summary_language}")
+ ASCIIColors.white(" ├─ Entity Types: ", end="")
+ ASCIIColors.yellow(f"{args.entity_types}")
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
ASCIIColors.yellow(f"{args.max_parallel_insert}")
ASCIIColors.white(" ├─ Chunk Size: ", end="")
diff --git a/lightrag/base.py b/lightrag/base.py
index 55a7c0cc..a29bcc33 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -22,7 +22,6 @@ from .constants import (
DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_HISTORY_TURNS,
- DEFAULT_ENABLE_RERANK,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_OLLAMA_MODEL_SIZE,
@@ -143,10 +142,6 @@ class QueryParam:
history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS)))
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
- # TODO: TODO: Deprecated - ID-based filtering only applies to chunks, not entities or relations, and implemented only in PostgreSQL storage
- ids: list[str] | None = None
- """List of doc ids to filter the results."""
-
model_func: Callable[..., object] | None = None
"""Optional override for the LLM model function to use for this specific query.
If provided, this will be used instead of the global model function.
@@ -158,9 +153,7 @@ class QueryParam:
If proivded, this will be use instead of the default vaulue from prompt template.
"""
- enable_rerank: bool = (
- os.getenv("ENABLE_RERANK", str(DEFAULT_ENABLE_RERANK).lower()).lower() == "true"
- )
+ enable_rerank: bool = os.getenv("RERANK_BY_DEFAULT", "true").lower() == "true"
"""Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
Default is True to enable reranking when rerank model is available.
"""
@@ -219,9 +212,16 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod
async def query(
- self, query: str, top_k: int, ids: list[str] | None = None
+ self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
- """Query the vector storage and retrieve top_k results."""
+ """Query the vector storage and retrieve top_k results.
+
+ Args:
+ query: The query string to search for
+ top_k: Number of top results to return
+ query_embedding: Optional pre-computed embedding for the query.
+ If provided, skips embedding computation for better performance.
+ """
@abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
diff --git a/lightrag/constants.py b/lightrag/constants.py
index e3ed9d7f..0b3962d0 100644
--- a/lightrag/constants.py
+++ b/lightrag/constants.py
@@ -11,10 +11,29 @@ DEFAULT_WOKERS = 2
DEFAULT_MAX_GRAPH_NODES = 1000
# Default values for extraction settings
-DEFAULT_SUMMARY_LANGUAGE = "English" # Default language for summaries
-DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 4
+DEFAULT_SUMMARY_LANGUAGE = "English" # Default language for document processing
DEFAULT_MAX_GLEANING = 1
-DEFAULT_SUMMARY_MAX_TOKENS = 10000 # Default maximum token size
+
+# Number of description fragments to trigger LLM summary
+DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 8
+# Max description token size to trigger LLM summary
+DEFAULT_SUMMARY_MAX_TOKENS = 1200
+# Recommended LLM summary output length in tokens
+DEFAULT_SUMMARY_LENGTH_RECOMMENDED = 600
+# Maximum token size sent to LLM for summary
+DEFAULT_SUMMARY_CONTEXT_SIZE = 12000
+# Default entities to extract if ENTITY_TYPES is not specified in .env
+DEFAULT_ENTITY_TYPES = [
+ "Organization",
+ "Person",
+ "Location",
+ "Event",
+ "Technology",
+ "Equipment",
+ "Product",
+ "Document",
+ "Category",
+]
# Separator for graph fields
GRAPH_FIELD_SEP = ""
@@ -32,8 +51,8 @@ DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
DEFAULT_HISTORY_TURNS = 0
# Rerank configuration defaults
-DEFAULT_ENABLE_RERANK = True
DEFAULT_MIN_RERANK_SCORE = 0.0
+DEFAULT_RERANK_BINDING = "null"
# File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
DEFAULT_MAX_FILE_PATH_LENGTH = 32768
@@ -49,8 +68,12 @@ DEFAULT_MAX_PARALLEL_INSERT = 2 # Default maximum parallel insert operations
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding functions
DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations
-# Ollama Server Timetout in seconds
-DEFAULT_TIMEOUT = 150
+# Gunicorn worker timeout
+DEFAULT_TIMEOUT = 210
+
+# Default llm and embedding timeout
+DEFAULT_LLM_TIMEOUT = 180
+DEFAULT_EMBEDDING_TIMEOUT = 30
# Logging configuration defaults
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py
index ae756f85..d57df1ac 100644
--- a/lightrag/exceptions.py
+++ b/lightrag/exceptions.py
@@ -58,3 +58,41 @@ class RateLimitError(APIStatusError):
class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)
+
+
+class StorageNotInitializedError(RuntimeError):
+ """Raised when storage operations are attempted before initialization."""
+
+ def __init__(self, storage_type: str = "Storage"):
+ super().__init__(
+ f"{storage_type} not initialized. Please ensure proper initialization:\n"
+ f"\n"
+ f" rag = LightRAG(...)\n"
+ f" await rag.initialize_storages() # Required\n"
+ f" \n"
+ f" from lightrag.kg.shared_storage import initialize_pipeline_status\n"
+ f" await initialize_pipeline_status() # Required for pipeline operations\n"
+ f"\n"
+ f"See: https://github.com/HKUDS/LightRAG#important-initialization-requirements"
+ )
+
+
+class PipelineNotInitializedError(KeyError):
+ """Raised when pipeline status is accessed before initialization."""
+
+ def __init__(self, namespace: str = ""):
+ msg = (
+ f"Pipeline namespace '{namespace}' not found. "
+ f"This usually means pipeline status was not initialized.\n"
+ f"\n"
+ f"Please call 'await initialize_pipeline_status()' after initializing storages:\n"
+ f"\n"
+ f" from lightrag.kg.shared_storage import initialize_pipeline_status\n"
+ f" await initialize_pipeline_status()\n"
+ f"\n"
+ f"Full initialization sequence:\n"
+ f" rag = LightRAG(...)\n"
+ f" await rag.initialize_storages()\n"
+ f" await initialize_pipeline_status()"
+ )
+ super().__init__(msg)
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index 5098ebf7..7d6a6dac 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -180,16 +180,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
return [m["__id__"] for m in list_data]
async def query(
- self, query: str, top_k: int, ids: list[str] | None = None
+ self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
"""
Search by a textual query; returns top_k results with their metadata + similarity distance.
"""
- embedding = await self.embedding_func(
- [query], _priority=5
- ) # higher priority for query
- # embedding is shape (1, dim)
- embedding = np.array(embedding, dtype=np.float32)
+ if query_embedding is not None:
+ embedding = np.array([query_embedding], dtype=np.float32)
+ else:
+ embedding = await self.embedding_func(
+ [query], _priority=5
+ ) # higher priority for query
+ # embedding is shape (1, dim)
+ embedding = np.array(embedding, dtype=np.float32)
+
faiss.normalize_L2(embedding) # we do in-place normalization
# Perform the similarity search
diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 13054cde..5464d0c3 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -13,6 +13,7 @@ from lightrag.utils import (
write_json,
get_pinyin_sort_key,
)
+from lightrag.exceptions import StorageNotInitializedError
from .shared_storage import (
get_namespace_data,
get_storage_lock,
@@ -65,11 +66,15 @@ class JsonDocStatusStorage(DocStatusStorage):
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
+ if self._storage_lock is None:
+ raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock:
return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
+ if self._storage_lock is None:
+ raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock:
for id in ids:
data = self._data.get(id, None)
@@ -80,6 +85,8 @@ class JsonDocStatusStorage(DocStatusStorage):
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus}
+ if self._storage_lock is None:
+ raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock:
for doc in self._data.values():
counts[doc["status"]] += 1
@@ -166,6 +173,8 @@ class JsonDocStatusStorage(DocStatusStorage):
logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
+ if self._storage_lock is None:
+ raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock:
# Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items():
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index ca3aa453..553ba417 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -10,6 +10,7 @@ from lightrag.utils import (
logger,
write_json,
)
+from lightrag.exceptions import StorageNotInitializedError
from .shared_storage import (
get_namespace_data,
get_storage_lock,
@@ -154,6 +155,8 @@ class JsonKVStorage(BaseKVStorage):
logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
+ if self._storage_lock is None:
+ raise StorageNotInitializedError("JsonKVStorage")
async with self._storage_lock:
# Add timestamps to data based on whether key exists
for k, v in data.items():
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index 82dce30c..f2368afe 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -1047,14 +1047,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return results
async def query(
- self, query: str, top_k: int, ids: list[str] | None = None
+ self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
# Ensure collection is loaded before querying
self._ensure_collection_loaded()
- embedding = await self.embedding_func(
- [query], _priority=5
- ) # higher priority for query
+ # Use provided embedding or compute it
+ if query_embedding is not None:
+ embedding = [query_embedding] # Milvus expects a list of embeddings
+ else:
+ embedding = await self.embedding_func(
+ [query], _priority=5
+ ) # higher priority for query
# Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields)
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index b8d30c44..9e4d7e67 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -280,6 +280,30 @@ class MongoDocStatusStorage(DocStatusStorage):
db: AsyncDatabase = field(default=None)
_data: AsyncCollection = field(default=None)
+ def _prepare_doc_status_data(self, doc: dict[str, Any]) -> dict[str, Any]:
+ """Normalize and migrate a raw Mongo document to DocProcessingStatus-compatible dict."""
+ # Make a copy of the data to avoid modifying the original
+ data = doc.copy()
+ # Remove deprecated content field if it exists
+ data.pop("content", None)
+ # Remove MongoDB _id field if it exists
+ data.pop("_id", None)
+ # If file_path is not in data, use document id as file path
+ if "file_path" not in data:
+ data["file_path"] = "no-file-path"
+ # Ensure new fields exist with default values
+ if "metadata" not in data:
+ data["metadata"] = {}
+ if "error_msg" not in data:
+ data["error_msg"] = None
+ # Backward compatibility: migrate legacy 'error' field to 'error_msg'
+ if "error" in data:
+ if "error_msg" not in data or data["error_msg"] in (None, ""):
+ data["error_msg"] = data.pop("error")
+ else:
+ data.pop("error", None)
+ return data
+
def __init__(self, namespace, global_config, embedding_func, workspace=None):
super().__init__(
namespace=namespace,
@@ -389,20 +413,7 @@ class MongoDocStatusStorage(DocStatusStorage):
processed_result = {}
for doc in result:
try:
- # Make a copy of the data to avoid modifying the original
- data = doc.copy()
- # Remove deprecated content field if it exists
- data.pop("content", None)
- # Remove MongoDB _id field if it exists
- data.pop("_id", None)
- # If file_path is not in data, use document id as file path
- if "file_path" not in data:
- data["file_path"] = "no-file-path"
- # Ensure new fields exist with default values
- if "metadata" not in data:
- data["metadata"] = {}
- if "error_msg" not in data:
- data["error_msg"] = None
+ data = self._prepare_doc_status_data(doc)
processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(
@@ -420,20 +431,7 @@ class MongoDocStatusStorage(DocStatusStorage):
processed_result = {}
for doc in result:
try:
- # Make a copy of the data to avoid modifying the original
- data = doc.copy()
- # Remove deprecated content field if it exists
- data.pop("content", None)
- # Remove MongoDB _id field if it exists
- data.pop("_id", None)
- # If file_path is not in data, use document id as file path
- if "file_path" not in data:
- data["file_path"] = "no-file-path"
- # Ensure new fields exist with default values
- if "metadata" not in data:
- data["metadata"] = {}
- if "error_msg" not in data:
- data["error_msg"] = None
+ data = self._prepare_doc_status_data(doc)
processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(
@@ -661,20 +659,7 @@ class MongoDocStatusStorage(DocStatusStorage):
try:
doc_id = doc["_id"]
- # Make a copy of the data to avoid modifying the original
- data = doc.copy()
- # Remove deprecated content field if it exists
- data.pop("content", None)
- # Remove MongoDB _id field if it exists
- data.pop("_id", None)
- # If file_path is not in data, use document id as file path
- if "file_path" not in data:
- data["file_path"] = "no-file-path"
- # Ensure new fields exist with default values
- if "metadata" not in data:
- data["metadata"] = {}
- if "error_msg" not in data:
- data["error_msg"] = None
+ data = self._prepare_doc_status_data(doc)
doc_status = DocProcessingStatus(**data)
documents.append((doc_id, doc_status))
@@ -1825,16 +1810,22 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data
async def query(
- self, query: str, top_k: int, ids: list[str] | None = None
+ self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search."""
- # Generate the embedding
- embedding = await self.embedding_func(
- [query], _priority=5
- ) # higher priority for query
-
- # Convert numpy array to a list to ensure compatibility with MongoDB
- query_vector = embedding[0].tolist()
+ if query_embedding is not None:
+ # Convert numpy array to list if needed for MongoDB compatibility
+ if hasattr(query_embedding, "tolist"):
+ query_vector = query_embedding.tolist()
+ else:
+ query_vector = list(query_embedding)
+ else:
+ # Generate the embedding
+ embedding = await self.embedding_func(
+ [query], _priority=5
+ ) # higher priority for query
+ # Convert numpy array to a list to ensure compatibility with MongoDB
+ query_vector = embedding[0].tolist()
# Define the aggregation pipeline with the converted query vector
pipeline = [
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index 5bec06f4..def5a83d 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -137,13 +137,17 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
async def query(
- self, query: str, top_k: int, ids: list[str] | None = None
+ self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
- # Execute embedding outside of lock to avoid improve cocurrent
- embedding = await self.embedding_func(
- [query], _priority=5
- ) # higher priority for query
- embedding = embedding[0]
+ # Use provided embedding or compute it
+ if query_embedding is not None:
+ embedding = query_embedding
+ else:
+ # Execute embedding outside of lock to avoid improve cocurrent
+ embedding = await self.embedding_func(
+ [query], _priority=5
+ ) # higher priority for query
+ embedding = embedding[0]
client = await self._get_client()
results = client.query(
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 88a75ba5..03a26f54 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -2005,18 +2005,21 @@ class PGVectorStorage(BaseVectorStorage):
#################### query method ###############
async def query(
- self, query: str, top_k: int, ids: list[str] | None = None
+ self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
- embeddings = await self.embedding_func(
- [query], _priority=5
- ) # higher priority for query
- embedding = embeddings[0]
+ if query_embedding is not None:
+ embedding = query_embedding
+ else:
+ embeddings = await self.embedding_func(
+ [query], _priority=5
+ ) # higher priority for query
+ embedding = embeddings[0]
+
embedding_string = ",".join(map(str, embedding))
- # Use parameterized document IDs (None means search across all documents)
+
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
params = {
"workspace": self.workspace,
- "doc_ids": ids,
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
"top_k": top_k,
}
@@ -4582,85 +4585,34 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time
""",
"relationships": """
- WITH relevant_chunks AS (SELECT id as chunk_id
- FROM LIGHTRAG_VDB_CHUNKS
- WHERE $2
- :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
- )
- , rc AS (
- SELECT array_agg(chunk_id) AS chunk_arr
- FROM relevant_chunks
- ), cand AS (
- SELECT
- r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
+ SELECT r.source_id AS src_id,
+ r.target_id AS tgt_id,
+ EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_RELATION r
WHERE r.workspace = $1
+ AND r.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
- LIMIT ($4 * 50)
- )
- SELECT c.src_id,
- c.tgt_id,
- EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
- FROM cand c
- JOIN rc ON TRUE
- WHERE c.dist < $3
- AND c.chunk_ids && (rc.chunk_arr::varchar[])
- ORDER BY c.dist, c.id
- LIMIT $4;
+ LIMIT $3;
""",
"entities": """
- WITH relevant_chunks AS (SELECT id as chunk_id
- FROM LIGHTRAG_VDB_CHUNKS
- WHERE $2
- :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
- )
- , rc AS (
- SELECT array_agg(chunk_id) AS chunk_arr
- FROM relevant_chunks
- ), cand AS (
- SELECT
- e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
+ SELECT e.entity_name,
+ EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_ENTITY e
WHERE e.workspace = $1
+ AND e.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
- LIMIT ($4 * 50)
- )
- SELECT c.entity_name,
- EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
- FROM cand c
- JOIN rc ON TRUE
- WHERE c.dist < $3
- AND c.chunk_ids && (rc.chunk_arr::varchar[])
- ORDER BY c.dist, c.id
- LIMIT $4;
+ LIMIT $3;
""",
"chunks": """
- WITH relevant_chunks AS (SELECT id as chunk_id
- FROM LIGHTRAG_VDB_CHUNKS
- WHERE $2
- :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
- )
- , rc AS (
- SELECT array_agg(chunk_id) AS chunk_arr
- FROM relevant_chunks
- ), cand AS (
- SELECT
- id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
- FROM LIGHTRAG_VDB_CHUNKS
- WHERE workspace = $1
- ORDER BY content_vector <=> '[{embedding_string}]'::vector
- LIMIT ($4 * 50)
- )
SELECT c.id,
c.content,
c.file_path,
- EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
- FROM cand c
- JOIN rc ON TRUE
- WHERE c.dist < $3
- AND c.id = ANY (rc.chunk_arr)
- ORDER BY c.dist, c.id
- LIMIT $4;
+ EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
+ FROM LIGHTRAG_VDB_CHUNKS c
+ WHERE c.workspace = $1
+ AND c.content_vector <=> '[{embedding_string}]'::vector < $2
+ ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
+ LIMIT $3;
""",
# DROP tables
"drop_specifiy_table_workspace": """
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index 4ece163c..dad95bbc 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -200,14 +200,19 @@ class QdrantVectorDBStorage(BaseVectorStorage):
return results
async def query(
- self, query: str, top_k: int, ids: list[str] | None = None
+ self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
- embedding = await self.embedding_func(
- [query], _priority=5
- ) # higher priority for query
+ if query_embedding is not None:
+ embedding = query_embedding
+ else:
+ embedding_result = await self.embedding_func(
+ [query], _priority=5
+ ) # higher priority for query
+ embedding = embedding_result[0]
+
results = self._client.search(
collection_name=self.final_namespace,
- query_vector=embedding[0],
+ query_vector=embedding,
limit=top_k,
with_payload=True,
score_threshold=self.cosine_better_than_threshold,
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 228bf272..e20dce52 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -8,6 +8,8 @@ import time
import logging
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic
+from lightrag.exceptions import PipelineNotInitializedError
+
# Define a direct print function for critical logs that must be visible in all processes
def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
@@ -1057,7 +1059,7 @@ async def initialize_pipeline_status():
Initialize pipeline namespace with default values.
This function is called during FASTAPI lifespan for each worker.
"""
- pipeline_namespace = await get_namespace_data("pipeline_status")
+ pipeline_namespace = await get_namespace_data("pipeline_status", first_init=True)
async with get_internal_lock():
# Check if already initialized by checking for required fields
@@ -1192,8 +1194,16 @@ async def try_initialize_namespace(namespace: str) -> bool:
return False
-async def get_namespace_data(namespace: str) -> Dict[str, Any]:
- """get the shared data reference for specific namespace"""
+async def get_namespace_data(
+ namespace: str, first_init: bool = False
+) -> Dict[str, Any]:
+ """get the shared data reference for specific namespace
+
+ Args:
+ namespace: The namespace to retrieve
+ allow_create: If True, allows creation of the namespace if it doesn't exist.
+ Used internally by initialize_pipeline_status().
+ """
if _shared_dicts is None:
direct_log(
f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
@@ -1203,6 +1213,13 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
async with get_internal_lock():
if namespace not in _shared_dicts:
+ # Special handling for pipeline_status namespace
+ if namespace == "pipeline_status" and not first_init:
+ # Check if pipeline_status should have been initialized but wasn't
+ # This helps users understand they need to call initialize_pipeline_status()
+ raise PipelineNotInitializedError(namespace)
+
+ # For other namespaces or when allow_create=True, create them dynamically
if _is_multiprocess and _manager is not None:
_shared_dicts[namespace] = _manager.dict()
else:
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 274f4c21..b4842924 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -9,7 +9,6 @@ import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from functools import partial
-from pathlib import Path
from typing import (
Any,
AsyncIterator,
@@ -35,9 +34,15 @@ from lightrag.constants import (
DEFAULT_KG_CHUNK_PICK_METHOD,
DEFAULT_MIN_RERANK_SCORE,
DEFAULT_SUMMARY_MAX_TOKENS,
+ DEFAULT_SUMMARY_CONTEXT_SIZE,
+ DEFAULT_SUMMARY_LENGTH_RECOMMENDED,
DEFAULT_MAX_ASYNC,
DEFAULT_MAX_PARALLEL_INSERT,
DEFAULT_MAX_GRAPH_NODES,
+ DEFAULT_ENTITY_TYPES,
+ DEFAULT_SUMMARY_LANGUAGE,
+ DEFAULT_LLM_TIMEOUT,
+ DEFAULT_EMBEDDING_TIMEOUT,
)
from lightrag.utils import get_env_value
@@ -278,6 +283,10 @@ class LightRAG:
- use_llm_check: If True, validates cached embeddings using an LLM.
"""
+ default_embedding_timeout: int = field(
+ default=int(os.getenv("EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT))
+ )
+
# LLM Configuration
# ---
@@ -288,10 +297,22 @@ class LightRAG:
"""Name of the LLM model used for generating responses."""
summary_max_tokens: int = field(
- default=int(os.getenv("MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS))
+ default=int(os.getenv("SUMMARY_MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS))
+ )
+ """Maximum tokens allowed for entity/relation description."""
+
+ summary_context_size: int = field(
+ default=int(os.getenv("SUMMARY_CONTEXT_SIZE", DEFAULT_SUMMARY_CONTEXT_SIZE))
)
"""Maximum number of tokens allowed per LLM response."""
+ summary_length_recommended: int = field(
+ default=int(
+ os.getenv("SUMMARY_LENGTH_RECOMMENDED", DEFAULT_SUMMARY_LENGTH_RECOMMENDED)
+ )
+ )
+ """Recommended length of LLM summary output."""
+
llm_model_max_async: int = field(
default=int(os.getenv("MAX_ASYNC", DEFAULT_MAX_ASYNC))
)
@@ -300,6 +321,10 @@ class LightRAG:
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function."""
+ default_llm_timeout: int = field(
+ default=int(os.getenv("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT))
+ )
+
# Rerank Configuration
# ---
@@ -338,7 +363,10 @@ class LightRAG:
addon_params: dict[str, Any] = field(
default_factory=lambda: {
- "language": get_env_value("SUMMARY_LANGUAGE", "English", str)
+ "language": get_env_value(
+ "SUMMARY_LANGUAGE", DEFAULT_SUMMARY_LANGUAGE, str
+ ),
+ "entity_types": get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES, list),
}
)
@@ -421,6 +449,20 @@ class LightRAG:
if self.ollama_server_infos is None:
self.ollama_server_infos = OllamaServerInfos()
+ # Validate config
+ if self.force_llm_summary_on_merge < 3:
+ logger.warning(
+ f"force_llm_summary_on_merge should be at least 3, got {self.force_llm_summary_on_merge}"
+ )
+ if self.summary_context_size > self.max_total_tokens:
+ logger.warning(
+ f"summary_context_size({self.summary_context_size}) should no greater than max_total_tokens({self.max_total_tokens})"
+ )
+ if self.summary_length_recommended > self.summary_max_tokens:
+ logger.warning(
+ f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})"
+ )
+
# Fix global_config now
global_config = asdict(self)
@@ -429,7 +471,9 @@ class LightRAG:
# Init Embedding
self.embedding_func = priority_limit_async_func_call(
- self.embedding_func_max_async
+ self.embedding_func_max_async,
+ llm_timeout=self.default_embedding_timeout,
+ queue_name="Embedding func:",
)(self.embedding_func)
# Initialize all storages
@@ -522,7 +566,12 @@ class LightRAG:
# Directly use llm_response_cache, don't create a new object
hashing_kv = self.llm_response_cache
- self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
+ # Get timeout from LLM model kwargs for dynamic timeout calculation
+ self.llm_model_func = priority_limit_async_func_call(
+ self.llm_model_max_async,
+ llm_timeout=self.default_llm_timeout,
+ queue_name="LLM func:",
+ )(
partial(
self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
@@ -530,14 +579,6 @@ class LightRAG:
)
)
- # Init Rerank
- if self.rerank_model_func:
- logger.info("Rerank model initialized for improved retrieval quality")
- else:
- logger.warning(
- "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
- )
-
self._storages_status = StoragesStatus.CREATED
async def initialize_storages(self):
@@ -2573,117 +2614,111 @@ class LightRAG:
relationships_to_delete = set()
relationships_to_rebuild = {} # (src, tgt) -> remaining_chunk_ids
- # Use graph database lock to ensure atomic merges and updates
+ try:
+ # Get affected entities and relations from full_entities and full_relations storage
+ doc_entities_data = await self.full_entities.get_by_id(doc_id)
+ doc_relations_data = await self.full_relations.get_by_id(doc_id)
+
+ affected_nodes = []
+ affected_edges = []
+
+ # Get entity data from graph storage using entity names from full_entities
+ if doc_entities_data and "entity_names" in doc_entities_data:
+ entity_names = doc_entities_data["entity_names"]
+ # get_nodes_batch returns dict[str, dict], need to convert to list[dict]
+ nodes_dict = await self.chunk_entity_relation_graph.get_nodes_batch(
+ entity_names
+ )
+ for entity_name in entity_names:
+ node_data = nodes_dict.get(entity_name)
+ if node_data:
+ # Ensure compatibility with existing logic that expects "id" field
+ if "id" not in node_data:
+ node_data["id"] = entity_name
+ affected_nodes.append(node_data)
+
+ # Get relation data from graph storage using relation pairs from full_relations
+ if doc_relations_data and "relation_pairs" in doc_relations_data:
+ relation_pairs = doc_relations_data["relation_pairs"]
+ edge_pairs_dicts = [
+ {"src": pair[0], "tgt": pair[1]} for pair in relation_pairs
+ ]
+ # get_edges_batch returns dict[tuple[str, str], dict], need to convert to list[dict]
+ edges_dict = await self.chunk_entity_relation_graph.get_edges_batch(
+ edge_pairs_dicts
+ )
+
+ for pair in relation_pairs:
+ src, tgt = pair[0], pair[1]
+ edge_key = (src, tgt)
+ edge_data = edges_dict.get(edge_key)
+ if edge_data:
+ # Ensure compatibility with existing logic that expects "source" and "target" fields
+ if "source" not in edge_data:
+ edge_data["source"] = src
+ if "target" not in edge_data:
+ edge_data["target"] = tgt
+ affected_edges.append(edge_data)
+
+ except Exception as e:
+ logger.error(f"Failed to analyze affected graph elements: {e}")
+ raise Exception(f"Failed to analyze graph dependencies: {e}") from e
+
+ try:
+ # Process entities
+ for node_data in affected_nodes:
+ node_label = node_data.get("entity_id")
+ if node_label and "source_id" in node_data:
+ sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
+ remaining_sources = sources - chunk_ids
+
+ if not remaining_sources:
+ entities_to_delete.add(node_label)
+ elif remaining_sources != sources:
+ entities_to_rebuild[node_label] = remaining_sources
+
+ async with pipeline_status_lock:
+ log_message = f"Found {len(entities_to_rebuild)} affected entities"
+ logger.info(log_message)
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
+
+ # Process relationships
+ for edge_data in affected_edges:
+ src = edge_data.get("source")
+ tgt = edge_data.get("target")
+
+ if src and tgt and "source_id" in edge_data:
+ edge_tuple = tuple(sorted((src, tgt)))
+ if (
+ edge_tuple in relationships_to_delete
+ or edge_tuple in relationships_to_rebuild
+ ):
+ continue
+
+ sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
+ remaining_sources = sources - chunk_ids
+
+ if not remaining_sources:
+ relationships_to_delete.add(edge_tuple)
+ elif remaining_sources != sources:
+ relationships_to_rebuild[edge_tuple] = remaining_sources
+
+ async with pipeline_status_lock:
+ log_message = (
+ f"Found {len(relationships_to_rebuild)} affected relations"
+ )
+ logger.info(log_message)
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
+
+ except Exception as e:
+ logger.error(f"Failed to process graph analysis results: {e}")
+ raise Exception(f"Failed to process graph dependencies: {e}") from e
+
+ # Use graph database lock to prevent dirty read
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
- try:
- # Get affected entities and relations from full_entities and full_relations storage
- doc_entities_data = await self.full_entities.get_by_id(doc_id)
- doc_relations_data = await self.full_relations.get_by_id(doc_id)
-
- affected_nodes = []
- affected_edges = []
-
- # Get entity data from graph storage using entity names from full_entities
- if doc_entities_data and "entity_names" in doc_entities_data:
- entity_names = doc_entities_data["entity_names"]
- # get_nodes_batch returns dict[str, dict], need to convert to list[dict]
- nodes_dict = (
- await self.chunk_entity_relation_graph.get_nodes_batch(
- entity_names
- )
- )
- for entity_name in entity_names:
- node_data = nodes_dict.get(entity_name)
- if node_data:
- # Ensure compatibility with existing logic that expects "id" field
- if "id" not in node_data:
- node_data["id"] = entity_name
- affected_nodes.append(node_data)
-
- # Get relation data from graph storage using relation pairs from full_relations
- if doc_relations_data and "relation_pairs" in doc_relations_data:
- relation_pairs = doc_relations_data["relation_pairs"]
- edge_pairs_dicts = [
- {"src": pair[0], "tgt": pair[1]} for pair in relation_pairs
- ]
- # get_edges_batch returns dict[tuple[str, str], dict], need to convert to list[dict]
- edges_dict = (
- await self.chunk_entity_relation_graph.get_edges_batch(
- edge_pairs_dicts
- )
- )
-
- for pair in relation_pairs:
- src, tgt = pair[0], pair[1]
- edge_key = (src, tgt)
- edge_data = edges_dict.get(edge_key)
- if edge_data:
- # Ensure compatibility with existing logic that expects "source" and "target" fields
- if "source" not in edge_data:
- edge_data["source"] = src
- if "target" not in edge_data:
- edge_data["target"] = tgt
- affected_edges.append(edge_data)
-
- except Exception as e:
- logger.error(f"Failed to analyze affected graph elements: {e}")
- raise Exception(f"Failed to analyze graph dependencies: {e}") from e
-
- try:
- # Process entities
- for node_data in affected_nodes:
- node_label = node_data.get("entity_id")
- if node_label and "source_id" in node_data:
- sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
- remaining_sources = sources - chunk_ids
-
- if not remaining_sources:
- entities_to_delete.add(node_label)
- elif remaining_sources != sources:
- entities_to_rebuild[node_label] = remaining_sources
-
- async with pipeline_status_lock:
- log_message = (
- f"Found {len(entities_to_rebuild)} affected entities"
- )
- logger.info(log_message)
- pipeline_status["latest_message"] = log_message
- pipeline_status["history_messages"].append(log_message)
-
- # Process relationships
- for edge_data in affected_edges:
- src = edge_data.get("source")
- tgt = edge_data.get("target")
-
- if src and tgt and "source_id" in edge_data:
- edge_tuple = tuple(sorted((src, tgt)))
- if (
- edge_tuple in relationships_to_delete
- or edge_tuple in relationships_to_rebuild
- ):
- continue
-
- sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
- remaining_sources = sources - chunk_ids
-
- if not remaining_sources:
- relationships_to_delete.add(edge_tuple)
- elif remaining_sources != sources:
- relationships_to_rebuild[edge_tuple] = remaining_sources
-
- async with pipeline_status_lock:
- log_message = (
- f"Found {len(relationships_to_rebuild)} affected relations"
- )
- logger.info(log_message)
- pipeline_status["latest_message"] = log_message
- pipeline_status["history_messages"].append(log_message)
-
- except Exception as e:
- logger.error(f"Failed to process graph analysis results: {e}")
- raise Exception(f"Failed to process graph dependencies: {e}") from e
-
# 5. Delete chunks from storage
if chunk_ids:
try:
@@ -2754,27 +2789,28 @@ class LightRAG:
logger.error(f"Failed to delete relationships: {e}")
raise Exception(f"Failed to delete relationships: {e}") from e
- # 8. Rebuild entities and relationships from remaining chunks
- if entities_to_rebuild or relationships_to_rebuild:
- try:
- await _rebuild_knowledge_from_chunks(
- entities_to_rebuild=entities_to_rebuild,
- relationships_to_rebuild=relationships_to_rebuild,
- knowledge_graph_inst=self.chunk_entity_relation_graph,
- entities_vdb=self.entities_vdb,
- relationships_vdb=self.relationships_vdb,
- text_chunks_storage=self.text_chunks,
- llm_response_cache=self.llm_response_cache,
- global_config=asdict(self),
- pipeline_status=pipeline_status,
- pipeline_status_lock=pipeline_status_lock,
- )
+ # Persist changes to graph database before releasing graph database lock
+ await self._insert_done()
- except Exception as e:
- logger.error(f"Failed to rebuild knowledge from chunks: {e}")
- raise Exception(
- f"Failed to rebuild knowledge graph: {e}"
- ) from e
+ # 8. Rebuild entities and relationships from remaining chunks
+ if entities_to_rebuild or relationships_to_rebuild:
+ try:
+ await _rebuild_knowledge_from_chunks(
+ entities_to_rebuild=entities_to_rebuild,
+ relationships_to_rebuild=relationships_to_rebuild,
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
+ entities_vdb=self.entities_vdb,
+ relationships_vdb=self.relationships_vdb,
+ text_chunks_storage=self.text_chunks,
+ llm_response_cache=self.llm_response_cache,
+ global_config=asdict(self),
+ pipeline_status=pipeline_status,
+ pipeline_status_lock=pipeline_status_lock,
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to rebuild knowledge from chunks: {e}")
+ raise Exception(f"Failed to rebuild knowledge graph: {e}") from e
# 9. Delete from full_entities and full_relations storage
try:
diff --git a/lightrag/llm/Readme.md b/lightrag/llm/Readme.md
index c907fd4d..fc00d071 100644
--- a/lightrag/llm/Readme.md
+++ b/lightrag/llm/Readme.md
@@ -36,7 +36,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
llm_instance = OpenAI(
model="gpt-4",
api_key="your-openai-key",
- temperature=0.7,
)
kwargs['llm_instance'] = llm_instance
@@ -91,7 +90,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
model=f"openai/{settings.LLM_MODEL}", # Format: "provider/model_name"
api_base=settings.LITELLM_URL,
api_key=settings.LITELLM_KEY,
- temperature=0.7,
)
kwargs['llm_instance'] = llm_instance
diff --git a/lightrag/llm/anthropic.py b/lightrag/llm/anthropic.py
index 7878c8f0..98a997d5 100644
--- a/lightrag/llm/anthropic.py
+++ b/lightrag/llm/anthropic.py
@@ -77,14 +77,23 @@ async def anthropic_complete_if_cache(
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
logging.getLogger("anthropic").setLevel(logging.INFO)
+ kwargs.pop("hashing_kv", None)
+ kwargs.pop("keyword_extraction", None)
+ timeout = kwargs.pop("timeout", None)
+
anthropic_async_client = (
- AsyncAnthropic(default_headers=default_headers, api_key=api_key)
+ AsyncAnthropic(
+ default_headers=default_headers, api_key=api_key, timeout=timeout
+ )
if base_url is None
else AsyncAnthropic(
- base_url=base_url, default_headers=default_headers, api_key=api_key
+ base_url=base_url,
+ default_headers=default_headers,
+ api_key=api_key,
+ timeout=timeout,
)
)
- kwargs.pop("hashing_kv", None)
+
messages: list[dict[str, Any]] = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
index 60d2c18e..0ede0824 100644
--- a/lightrag/llm/azure_openai.py
+++ b/lightrag/llm/azure_openai.py
@@ -59,13 +59,17 @@ async def azure_openai_complete_if_cache(
or os.getenv("OPENAI_API_VERSION")
)
+ kwargs.pop("hashing_kv", None)
+ kwargs.pop("keyword_extraction", None)
+ timeout = kwargs.pop("timeout", None)
+
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=base_url,
azure_deployment=deployment,
api_key=api_key,
api_version=api_version,
+ timeout=timeout,
)
- kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py
index 827620ee..c2f2c9d7 100644
--- a/lightrag/llm/binding_options.py
+++ b/lightrag/llm/binding_options.py
@@ -99,7 +99,7 @@ class BindingOptions:
group = parser.add_argument_group(f"{cls._binding_name} binding options")
for arg_item in cls.args_env_name_type_value():
# Handle JSON parsing for list types
- if arg_item["type"] == List[str]:
+ if arg_item["type"] is List[str]:
def json_list_parser(value):
try:
@@ -126,6 +126,34 @@ class BindingOptions:
default=env_value,
help=arg_item["help"],
)
+ # Handle JSON parsing for dict types
+ elif arg_item["type"] is dict:
+
+ def json_dict_parser(value):
+ try:
+ parsed = json.loads(value)
+ if not isinstance(parsed, dict):
+ raise argparse.ArgumentTypeError(
+ f"Expected JSON object, got {type(parsed).__name__}"
+ )
+ return parsed
+ except json.JSONDecodeError as e:
+ raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
+
+ # Get environment variable with JSON parsing
+ env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
+ if env_value is not argparse.SUPPRESS:
+ try:
+ env_value = json_dict_parser(env_value)
+ except argparse.ArgumentTypeError:
+ env_value = argparse.SUPPRESS
+
+ group.add_argument(
+ f"--{arg_item['argname']}",
+ type=json_dict_parser,
+ default=env_value,
+ help=arg_item["help"],
+ )
else:
group.add_argument(
f"--{arg_item['argname']}",
@@ -234,8 +262,8 @@ class BindingOptions:
if arg_item["help"]:
sample_stream.write(f"# {arg_item['help']}\n")
- # Handle JSON formatting for list types
- if arg_item["type"] == List[str]:
+ # Handle JSON formatting for list and dict types
+ if arg_item["type"] is List[str] or arg_item["type"] is dict:
default_value = json.dumps(arg_item["default"])
else:
default_value = arg_item["default"]
@@ -431,6 +459,8 @@ class OpenAILLMOptions(BindingOptions):
stop: List[str] = field(default_factory=list) # Stop sequences
temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0 to 2.0)
top_p: float = 1.0 # Nucleus sampling parameter (0.0 to 1.0)
+ max_tokens: int = None # Maximum number of tokens to generate(deprecated, use max_completion_tokens instead)
+ extra_body: dict = None # Extra body parameters for OpenRouter of vLLM
# Help descriptions
_help: ClassVar[dict[str, str]] = {
@@ -443,6 +473,8 @@ class OpenAILLMOptions(BindingOptions):
"stop": 'Stop sequences (JSON array of strings, e.g., \'["", "\\n\\n"]\')',
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
"top_p": "Nucleus sampling parameter (0.0-1.0, lower = more focused)",
+ "max_tokens": "Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)",
+ "extra_body": 'Extra body parameters for OpenRouter of vLLM (JSON dict, e.g., \'"reasoning": {"reasoning": {"enabled": false}}\')',
}
@@ -493,6 +525,8 @@ if __name__ == "__main__":
"1000",
"--openai-llm-stop",
'["", "\\n\\n"]',
+ "--openai-llm-reasoning",
+ '{"effort": "high", "max_tokens": 2000, "exclude": false, "enabled": true}',
]
)
print("Final args for LLM and Embedding:")
@@ -518,5 +552,100 @@ if __name__ == "__main__":
print("\nOpenAI LLM options instance:")
print(openai_options.asdict())
+ # Test creating OpenAI options instance with reasoning parameter
+ openai_options_with_reasoning = OpenAILLMOptions(
+ temperature=0.9,
+ max_completion_tokens=2000,
+ reasoning={
+ "effort": "medium",
+ "max_tokens": 1500,
+ "exclude": True,
+ "enabled": True,
+ },
+ )
+ print("\nOpenAI LLM options instance with reasoning:")
+ print(openai_options_with_reasoning.asdict())
+
+ # Test dict parsing functionality
+ print("\n" + "=" * 50)
+ print("TESTING DICT PARSING FUNCTIONALITY")
+ print("=" * 50)
+
+ # Test valid JSON dict parsing
+ test_parser = ArgumentParser(description="Test dict parsing")
+ OpenAILLMOptions.add_args(test_parser)
+
+ try:
+ test_args = test_parser.parse_args(
+ ["--openai-llm-reasoning", '{"effort": "low", "max_tokens": 1000}']
+ )
+ print("✓ Valid JSON dict parsing successful:")
+ print(
+ f" Parsed reasoning: {OpenAILLMOptions.options_dict(test_args)['reasoning']}"
+ )
+ except Exception as e:
+ print(f"✗ Valid JSON dict parsing failed: {e}")
+
+ # Test invalid JSON dict parsing
+ try:
+ test_args = test_parser.parse_args(
+ [
+ "--openai-llm-reasoning",
+ '{"effort": "low", "max_tokens": 1000', # Missing closing brace
+ ]
+ )
+ print("✗ Invalid JSON should have failed but didn't")
+ except SystemExit:
+ print("✓ Invalid JSON dict parsing correctly rejected")
+ except Exception as e:
+ print(f"✓ Invalid JSON dict parsing correctly rejected: {e}")
+
+ # Test non-dict JSON parsing
+ try:
+ test_args = test_parser.parse_args(
+ [
+ "--openai-llm-reasoning",
+ '["not", "a", "dict"]', # Array instead of dict
+ ]
+ )
+ print("✗ Non-dict JSON should have failed but didn't")
+ except SystemExit:
+ print("✓ Non-dict JSON parsing correctly rejected")
+ except Exception as e:
+ print(f"✓ Non-dict JSON parsing correctly rejected: {e}")
+
+ print("\n" + "=" * 50)
+ print("TESTING ENVIRONMENT VARIABLE SUPPORT")
+ print("=" * 50)
+
+ # Test environment variable support for dict
+ import os
+
+ os.environ["OPENAI_LLM_REASONING"] = (
+ '{"effort": "high", "max_tokens": 3000, "exclude": false}'
+ )
+
+ env_parser = ArgumentParser(description="Test env var dict parsing")
+ OpenAILLMOptions.add_args(env_parser)
+
+ try:
+ env_args = env_parser.parse_args(
+ []
+ ) # No command line args, should use env var
+ reasoning_from_env = OpenAILLMOptions.options_dict(env_args).get(
+ "reasoning"
+ )
+ if reasoning_from_env:
+ print("✓ Environment variable dict parsing successful:")
+ print(f" Parsed reasoning from env: {reasoning_from_env}")
+ else:
+ print("✗ Environment variable dict parsing failed: No reasoning found")
+ except Exception as e:
+ print(f"✗ Environment variable dict parsing failed: {e}")
+ finally:
+ # Clean up environment variable
+ if "OPENAI_LLM_REASONING" in os.environ:
+ del os.environ["OPENAI_LLM_REASONING"]
+
else:
print(BindingOptions.generate_dot_env_sample())
diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py
index 357b65bf..39b64ce3 100644
--- a/lightrag/llm/lollms.py
+++ b/lightrag/llm/lollms.py
@@ -59,7 +59,7 @@ async def lollms_model_if_cache(
"personality": kwargs.get("personality", -1),
"n_predict": kwargs.get("n_predict", None),
"stream": stream,
- "temperature": kwargs.get("temperature", 0.8),
+ "temperature": kwargs.get("temperature", 1.0),
"top_k": kwargs.get("top_k", 50),
"top_p": kwargs.get("top_p", 0.95),
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 1ca5504e..6423fa90 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -51,6 +51,8 @@ async def _ollama_model_if_cache(
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
+ if timeout == 0:
+ timeout = None
kwargs.pop("hashing_kv", None)
api_key = kwargs.pop("api_key", None)
headers = {
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 910d1812..3bd652f4 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -149,18 +149,20 @@ async def openai_complete_if_cache(
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
logging.getLogger("openai").setLevel(logging.INFO)
+ # Remove special kwargs that shouldn't be passed to OpenAI
+ kwargs.pop("hashing_kv", None)
+ kwargs.pop("keyword_extraction", None)
+
# Extract client configuration options
client_configs = kwargs.pop("openai_client_configs", {})
# Create the OpenAI client
openai_async_client = create_openai_async_client(
- api_key=api_key, base_url=base_url, client_configs=client_configs
+ api_key=api_key,
+ base_url=base_url,
+ client_configs=client_configs,
)
- # Remove special kwargs that shouldn't be passed to OpenAI
- kwargs.pop("hashing_kv", None)
- kwargs.pop("keyword_extraction", None)
-
# Prepare messages
messages: list[dict[str, Any]] = []
if system_prompt:
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 2899b1fb..22ed9117 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -11,11 +11,10 @@ from collections import Counter, defaultdict
from .utils import (
logger,
- clean_str,
compute_mdhash_id,
Tokenizer,
is_float_regex,
- normalize_extracted_info,
+ sanitize_and_normalize_extracted_text,
pack_user_ass_to_openai_messages,
split_string_by_multi_markers,
truncate_list_by_token_size,
@@ -31,6 +30,7 @@ from .utils import (
pick_by_vector_similarity,
process_chunks_unified,
build_file_path,
+ safe_vdb_operation_with_exception,
)
from .base import (
BaseGraphStorage,
@@ -47,6 +47,8 @@ from .constants import (
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_RELATED_CHUNK_NUMBER,
DEFAULT_KG_CHUNK_PICK_METHOD,
+ DEFAULT_ENTITY_TYPES,
+ DEFAULT_SUMMARY_LANGUAGE,
)
from .kg.shared_storage import get_storage_keyed_lock
import time
@@ -114,48 +116,195 @@ def chunking_by_token_size(
async def _handle_entity_relation_summary(
+ description_type: str,
entity_or_relation_name: str,
- description: str,
+ description_list: list[str],
+ seperator: str,
+ global_config: dict,
+ llm_response_cache: BaseKVStorage | None = None,
+) -> tuple[str, bool]:
+ """Handle entity relation description summary using map-reduce approach.
+
+ This function summarizes a list of descriptions using a map-reduce strategy:
+ 1. If total tokens < summary_context_size and len(description_list) < force_llm_summary_on_merge, no need to summarize
+ 2. If total tokens < summary_max_tokens, summarize with LLM directly
+ 3. Otherwise, split descriptions into chunks that fit within token limits
+ 4. Summarize each chunk, then recursively process the summaries
+ 5. Continue until we get a final summary within token limits or num of descriptions is less than force_llm_summary_on_merge
+
+ Args:
+ entity_or_relation_name: Name of the entity or relation being summarized
+ description_list: List of description strings to summarize
+ global_config: Global configuration containing tokenizer and limits
+ llm_response_cache: Optional cache for LLM responses
+
+ Returns:
+ Tuple of (final_summarized_description_string, llm_was_used_boolean)
+ """
+ # Handle empty input
+ if not description_list:
+ return "", False
+
+ # If only one description, return it directly (no need for LLM call)
+ if len(description_list) == 1:
+ return description_list[0], False
+
+ # Get configuration
+ tokenizer: Tokenizer = global_config["tokenizer"]
+ summary_context_size = global_config["summary_context_size"]
+ summary_max_tokens = global_config["summary_max_tokens"]
+ force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
+
+ current_list = description_list[:] # Copy the list to avoid modifying original
+ llm_was_used = False # Track whether LLM was used during the entire process
+
+ # Iterative map-reduce process
+ while True:
+ # Calculate total tokens in current list
+ total_tokens = sum(len(tokenizer.encode(desc)) for desc in current_list)
+
+ # If total length is within limits, perform final summarization
+ if total_tokens <= summary_context_size or len(current_list) <= 2:
+ if (
+ len(current_list) < force_llm_summary_on_merge
+ and total_tokens < summary_max_tokens
+ ):
+ # no LLM needed, just join the descriptions
+ final_description = seperator.join(current_list)
+ return final_description if final_description else "", llm_was_used
+ else:
+ if total_tokens > summary_context_size and len(current_list) <= 2:
+ logger.warning(
+ f"Summarizing {entity_or_relation_name}: Oversize descpriton found"
+ )
+ # Final summarization of remaining descriptions - LLM will be used
+ final_summary = await _summarize_descriptions(
+ description_type,
+ entity_or_relation_name,
+ current_list,
+ global_config,
+ llm_response_cache,
+ )
+ return final_summary, True # LLM was used for final summarization
+
+ # Need to split into chunks - Map phase
+ # Ensure each chunk has minimum 2 descriptions to guarantee progress
+ chunks = []
+ current_chunk = []
+ current_tokens = 0
+
+ # Currently least 3 descriptions in current_list
+ for i, desc in enumerate(current_list):
+ desc_tokens = len(tokenizer.encode(desc))
+
+ # If adding current description would exceed limit, finalize current chunk
+ if current_tokens + desc_tokens > summary_context_size and current_chunk:
+ # Ensure we have at least 2 descriptions in the chunk (when possible)
+ if len(current_chunk) == 1:
+ # Force add one more description to ensure minimum 2 per chunk
+ current_chunk.append(desc)
+ chunks.append(current_chunk)
+ logger.warning(
+ f"Summarizing {entity_or_relation_name}: Oversize descpriton found"
+ )
+ current_chunk = [] # next group is empty
+ current_tokens = 0
+ else: # curren_chunk is ready for summary in reduce phase
+ chunks.append(current_chunk)
+ current_chunk = [desc] # leave it for next group
+ current_tokens = desc_tokens
+ else:
+ current_chunk.append(desc)
+ current_tokens += desc_tokens
+
+ # Add the last chunk if it exists
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ logger.info(
+ f" Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups"
+ )
+
+ # Reduce phase: summarize each group from chunks
+ new_summaries = []
+ for chunk in chunks:
+ if len(chunk) == 1:
+ # Optimization: single description chunks don't need LLM summarization
+ new_summaries.append(chunk[0])
+ else:
+ # Multiple descriptions need LLM summarization
+ summary = await _summarize_descriptions(
+ description_type,
+ entity_or_relation_name,
+ chunk,
+ global_config,
+ llm_response_cache,
+ )
+ new_summaries.append(summary)
+ llm_was_used = True # Mark that LLM was used in reduce phase
+
+ # Update current list with new summaries for next iteration
+ current_list = new_summaries
+
+
+async def _summarize_descriptions(
+ description_type: str,
+ description_name: str,
+ description_list: list[str],
global_config: dict,
llm_response_cache: BaseKVStorage | None = None,
) -> str:
- """Handle entity relation summary
- For each entity or relation, input is the combined description of already existing description and new description.
- If too long, use LLM to summarize.
+ """Helper function to summarize a list of descriptions using LLM.
+
+ Args:
+ entity_or_relation_name: Name of the entity or relation being summarized
+ descriptions: List of description strings to summarize
+ global_config: Global configuration containing LLM function and settings
+ llm_response_cache: Optional cache for LLM responses
+
+ Returns:
+ Summarized description string
"""
use_llm_func: callable = global_config["llm_model_func"]
# Apply higher priority (8) to entity/relation summary tasks
use_llm_func = partial(use_llm_func, _priority=8)
- tokenizer: Tokenizer = global_config["tokenizer"]
- llm_max_tokens = global_config["summary_max_tokens"]
+ language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE)
- language = global_config["addon_params"].get(
- "language", PROMPTS["DEFAULT_LANGUAGE"]
- )
-
- tokens = tokenizer.encode(description)
-
- ### summarize is not determined here anymore (It's determined by num_fragment now)
- # if len(tokens) < summary_max_tokens: # No need for summary
- # return description
+ summary_length_recommended = global_config["summary_length_recommended"]
prompt_template = PROMPTS["summarize_entity_descriptions"]
- use_description = tokenizer.decode(tokens[:llm_max_tokens])
+
+ # Join descriptions and apply token-based truncation if necessary
+ joined_descriptions = "\n\n".join(description_list)
+ tokenizer = global_config["tokenizer"]
+ summary_context_size = global_config["summary_context_size"]
+
+ # Token-based truncation to ensure input fits within limits
+ tokens = tokenizer.encode(joined_descriptions)
+ if len(tokens) > summary_context_size:
+ truncated_tokens = tokens[:summary_context_size]
+ joined_descriptions = tokenizer.decode(truncated_tokens)
+
+ # Prepare context for the prompt
context_base = dict(
- entity_name=entity_or_relation_name,
- description_list=use_description.split(GRAPH_FIELD_SEP),
+ description_type=description_type,
+ description_name=description_name,
+ description_list=joined_descriptions,
+ summary_length=summary_length_recommended,
language=language,
)
use_prompt = prompt_template.format(**context_base)
- logger.debug(f"Trigger summary: {entity_or_relation_name}")
+
+ logger.debug(
+ f"Summarizing {len(description_list)} descriptions for: {description_name}"
+ )
# Use LLM function with cache (higher priority for summary generation)
summary = await use_llm_func_with_cache(
use_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
- # max_tokens=summary_max_tokens,
cache_type="extract",
)
return summary
@@ -166,112 +315,148 @@ async def _handle_single_entity_extraction(
chunk_key: str,
file_path: str = "unknown_source",
):
- if len(record_attributes) < 4 or '"entity"' not in record_attributes[0]:
+ if len(record_attributes) < 4 or "entity" not in record_attributes[0]:
+ if len(record_attributes) > 1 and "entity" in record_attributes[0]:
+ logger.warning(
+ f"Entity extraction failed in {chunk_key}: expecting 4 fields but got {len(record_attributes)}"
+ )
+ logger.warning(f"Entity extracted: {record_attributes[1]}")
return None
- # Clean and validate entity name
- entity_name = clean_str(record_attributes[1]).strip()
- if not entity_name:
- logger.warning(
- f"Entity extraction error: empty entity name in: {record_attributes}"
+ try:
+ entity_name = sanitize_and_normalize_extracted_text(
+ record_attributes[1], remove_inner_quotes=True
+ )
+
+ # Validate entity name after all cleaning steps
+ if not entity_name or not entity_name.strip():
+ logger.warning(
+ f"Entity extraction error: entity name became empty after cleaning. Original: '{record_attributes[1]}'"
+ )
+ return None
+
+ # Process entity type with same cleaning pipeline
+ entity_type = sanitize_and_normalize_extracted_text(
+ record_attributes[2], remove_inner_quotes=True
+ )
+
+ if not entity_type.strip() or any(
+ char in entity_type for char in ["'", "(", ")", "<", ">", "|", "/", "\\"]
+ ):
+ logger.warning(
+ f"Entity extraction error: invalid entity type in: {record_attributes}"
+ )
+ return None
+
+ # Captitalize first letter of entity_type
+ entity_type = entity_type.title()
+
+ # Process entity description with same cleaning pipeline
+ entity_description = sanitize_and_normalize_extracted_text(record_attributes[3])
+
+ if not entity_description.strip():
+ logger.warning(
+ f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
+ )
+ return None
+
+ return dict(
+ entity_name=entity_name,
+ entity_type=entity_type,
+ description=entity_description,
+ source_id=chunk_key,
+ file_path=file_path,
+ )
+
+ except ValueError as e:
+ logger.error(
+ f"Entity extraction failed due to encoding issues in chunk {chunk_key}: {e}"
)
return None
-
- # Normalize entity name
- entity_name = normalize_extracted_info(entity_name, is_entity=True)
-
- # Check if entity name became empty after normalization
- if not entity_name or not entity_name.strip():
- logger.warning(
- f"Entity extraction error: entity name became empty after normalization. Original: '{record_attributes[1]}'"
+ except Exception as e:
+ logger.error(
+ f"Entity extraction failed with unexpected error in chunk {chunk_key}: {e}"
)
return None
- # Clean and validate entity type
- entity_type = clean_str(record_attributes[2]).strip('"')
- if not entity_type.strip() or entity_type.startswith('("'):
- logger.warning(
- f"Entity extraction error: invalid entity type in: {record_attributes}"
- )
- return None
-
- # Clean and validate description
- entity_description = clean_str(record_attributes[3])
- entity_description = normalize_extracted_info(entity_description)
-
- if not entity_description.strip():
- logger.warning(
- f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
- )
- return None
-
- return dict(
- entity_name=entity_name,
- entity_type=entity_type,
- description=entity_description,
- source_id=chunk_key,
- file_path=file_path,
- )
-
async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
file_path: str = "unknown_source",
):
- if len(record_attributes) < 5 or '"relationship"' not in record_attributes[0]:
+ if len(record_attributes) < 5 or "relationship" not in record_attributes[0]:
+ if len(record_attributes) > 1 and "relationship" in record_attributes[0]:
+ logger.warning(
+ f"Relation extraction failed in {chunk_key}: expecting 5 fields but got {len(record_attributes)}"
+ )
+ logger.warning(f"Relation extracted: {record_attributes[1]}")
return None
- # add this record as edge
- source = clean_str(record_attributes[1])
- target = clean_str(record_attributes[2])
- # Normalize source and target entity names
- source = normalize_extracted_info(source, is_entity=True)
- target = normalize_extracted_info(target, is_entity=True)
+ try:
+ source = sanitize_and_normalize_extracted_text(
+ record_attributes[1], remove_inner_quotes=True
+ )
+ target = sanitize_and_normalize_extracted_text(
+ record_attributes[2], remove_inner_quotes=True
+ )
- # Check if source or target became empty after normalization
- if not source or not source.strip():
+ # Validate entity names after all cleaning steps
+ if not source:
+ logger.warning(
+ f"Relationship extraction error: source entity became empty after cleaning. Original: '{record_attributes[1]}'"
+ )
+ return None
+
+ if not target:
+ logger.warning(
+ f"Relationship extraction error: target entity became empty after cleaning. Original: '{record_attributes[2]}'"
+ )
+ return None
+
+ if source == target:
+ logger.debug(
+ f"Relationship source and target are the same in: {record_attributes}"
+ )
+ return None
+
+ # Process keywords with same cleaning pipeline
+ edge_keywords = sanitize_and_normalize_extracted_text(
+ record_attributes[3], remove_inner_quotes=True
+ )
+ edge_keywords = edge_keywords.replace(",", ",")
+
+ # Process relationship description with same cleaning pipeline
+ edge_description = sanitize_and_normalize_extracted_text(record_attributes[4])
+
+ edge_source_id = chunk_key
+ weight = (
+ float(record_attributes[-1].strip('"').strip("'"))
+ if is_float_regex(record_attributes[-1].strip('"').strip("'"))
+ else 1.0
+ )
+
+ return dict(
+ src_id=source,
+ tgt_id=target,
+ weight=weight,
+ description=edge_description,
+ keywords=edge_keywords,
+ source_id=edge_source_id,
+ file_path=file_path,
+ )
+
+ except ValueError as e:
logger.warning(
- f"Relationship extraction error: source entity became empty after normalization. Original: '{record_attributes[1]}'"
+ f"Relationship extraction failed due to encoding issues in chunk {chunk_key}: {e}"
)
return None
-
- if not target or not target.strip():
+ except Exception as e:
logger.warning(
- f"Relationship extraction error: target entity became empty after normalization. Original: '{record_attributes[2]}'"
+ f"Relationship extraction failed with unexpected error in chunk {chunk_key}: {e}"
)
return None
- if source == target:
- logger.debug(
- f"Relationship source and target are the same in: {record_attributes}"
- )
- return None
-
- edge_description = clean_str(record_attributes[3])
- edge_description = normalize_extracted_info(edge_description)
-
- edge_keywords = normalize_extracted_info(
- clean_str(record_attributes[4]), is_entity=True
- )
- edge_keywords = edge_keywords.replace(",", ",")
-
- edge_source_id = chunk_key
- weight = (
- float(record_attributes[-1].strip('"').strip("'"))
- if is_float_regex(record_attributes[-1].strip('"').strip("'"))
- else 1.0
- )
- return dict(
- src_id=source,
- tgt_id=target,
- weight=weight,
- description=edge_description,
- keywords=edge_keywords,
- source_id=edge_source_id,
- file_path=file_path,
- )
-
async def _rebuild_knowledge_from_chunks(
entities_to_rebuild: dict[str, set[str]],
@@ -413,7 +598,7 @@ async def _rebuild_knowledge_from_chunks(
)
rebuilt_entities_count += 1
status_message = (
- f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks"
+ f"Rebuilt `{entity_name}` from {len(chunk_ids)} chunks"
)
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
@@ -422,7 +607,7 @@ async def _rebuild_knowledge_from_chunks(
pipeline_status["history_messages"].append(status_message)
except Exception as e:
failed_entities_count += 1
- status_message = f"Failed to rebuild entity {entity_name}: {e}"
+ status_message = f"Failed to rebuild `{entity_name}`: {e}"
logger.info(status_message) # Per requirement, change to info
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
@@ -453,7 +638,9 @@ async def _rebuild_knowledge_from_chunks(
global_config=global_config,
)
rebuilt_relationships_count += 1
- status_message = f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks"
+ status_message = (
+ f"Rebuilt `{src} - {tgt}` from {len(chunk_ids)} chunks"
+ )
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
@@ -461,7 +648,7 @@ async def _rebuild_knowledge_from_chunks(
pipeline_status["history_messages"].append(status_message)
except Exception as e:
failed_relationships_count += 1
- status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}"
+ status_message = f"Failed to rebuild `{src} - {tgt}`: {e}"
logger.info(status_message) # Per requirement, change to info
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
@@ -525,14 +712,20 @@ async def _get_cached_extraction_results(
) -> dict[str, list[str]]:
"""Get cached extraction results for specific chunk IDs
+ This function retrieves cached LLM extraction results for the given chunk IDs and returns
+ them sorted by creation time. The results are sorted at two levels:
+ 1. Individual extraction results within each chunk are sorted by create_time (earliest first)
+ 2. Chunks themselves are sorted by the create_time of their earliest extraction result
+
Args:
llm_response_cache: LLM response cache storage
chunk_ids: Set of chunk IDs to get cached results for
- text_chunks_data: Pre-loaded chunk data (optional, for performance)
- text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None)
+ text_chunks_storage: Text chunks storage for retrieving chunk data and LLM cache references
Returns:
- Dict mapping chunk_id -> list of extraction_result_text
+ Dict mapping chunk_id -> list of extraction_result_text, where:
+ - Keys (chunk_ids) are ordered by the create_time of their first extraction result
+ - Values (extraction results) are ordered by create_time within each chunk
"""
cached_results = {}
@@ -541,15 +734,13 @@ async def _get_cached_extraction_results(
# Read from storage
chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
- for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
+ for chunk_data in chunk_data_list:
if chunk_data and isinstance(chunk_data, dict):
llm_cache_list = chunk_data.get("llm_cache_list", [])
if llm_cache_list:
all_cache_ids.update(llm_cache_list)
else:
- logger.warning(
- f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
- )
+ logger.warning(f"Chunk data is invalid or None: {chunk_data}")
if not all_cache_ids:
logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
@@ -560,7 +751,7 @@ async def _get_cached_extraction_results(
# Process cache entries and group by chunk_id
valid_entries = 0
- for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
+ for cache_entry in cache_data_list:
if (
cache_entry is not None
and isinstance(cache_entry, dict)
@@ -580,16 +771,111 @@ async def _get_cached_extraction_results(
# Store tuple with extraction result and creation time for sorting
cached_results[chunk_id].append((extraction_result, create_time))
- # Sort extraction results by create_time for each chunk
+ # Sort extraction results by create_time for each chunk and collect earliest times
+ chunk_earliest_times = {}
for chunk_id in cached_results:
# Sort by create_time (x[1]), then extract only extraction_result (x[0])
cached_results[chunk_id].sort(key=lambda x: x[1])
+ # Store the earliest create_time for this chunk (first item after sorting)
+ chunk_earliest_times[chunk_id] = cached_results[chunk_id][0][1]
+ # Extract only extraction_result (x[0])
cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
- logger.info(
- f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
+ # Sort cached_results by the earliest create_time of each chunk
+ sorted_chunk_ids = sorted(
+ chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id]
)
- return cached_results
+
+ # Rebuild cached_results in sorted order
+ sorted_cached_results = {}
+ for chunk_id in sorted_chunk_ids:
+ sorted_cached_results[chunk_id] = cached_results[chunk_id]
+
+ logger.info(
+ f"Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results"
+ )
+ return sorted_cached_results
+
+
+async def _process_extraction_result(
+ result: str,
+ chunk_key: str,
+ file_path: str = "unknown_source",
+ tuple_delimiter: str = "<|>",
+ record_delimiter: str = "##",
+ completion_delimiter: str = "<|COMPLETE|>",
+) -> tuple[dict, dict]:
+ """Process a single extraction result (either initial or gleaning)
+ Args:
+ result (str): The extraction result to process
+ chunk_key (str): The chunk key for source tracking
+ file_path (str): The file path for citation
+ tuple_delimiter (str): Delimiter for tuple fields
+ record_delimiter (str): Delimiter for records
+ completion_delimiter (str): Delimiter for completion
+ Returns:
+ tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
+ """
+ maybe_nodes = defaultdict(list)
+ maybe_edges = defaultdict(list)
+
+ # Standardize Chinese brackets around record_delimiter to English brackets
+ bracket_pattern = f"[))](\\s*{re.escape(record_delimiter)}\\s*)[((]"
+ result = re.sub(bracket_pattern, ")\\1(", result)
+
+ records = split_string_by_multi_markers(
+ result,
+ [record_delimiter, completion_delimiter],
+ )
+
+ for record in records:
+ # Remove outer brackets (support English and Chinese brackets)
+ record = record.strip()
+ if record.startswith("(") or record.startswith("("):
+ record = record[1:]
+ if record.endswith(")") or record.endswith(")"):
+ record = record[:-1]
+
+ record = record.strip()
+ if record is None:
+ continue
+
+ if tuple_delimiter == "<|>":
+ # fix entity<| with entity<|>
+ record = re.sub(r"^entity<\|(?!>)", r"entity<|>", record)
+ # fix relationship<| with relationship<|>
+ record = re.sub(r"^relationship<\|(?!>)", r"relationship<|>", record)
+ # fix <||> with <|>
+ record = record.replace("<||>", "<|>")
+ # fix < | > with <|>
+ record = record.replace("< | >", "<|>")
+ # fix <<|>> with <|>
+ record = record.replace("<<|>>", "<|>")
+ # fix <|>> with <|>
+ record = record.replace("<|>>", "<|>")
+ # fix <<|> with <|>
+ record = record.replace("<<|>", "<|>")
+
+ record_attributes = split_string_by_multi_markers(record, [tuple_delimiter])
+
+ # Try to parse as entity
+ entity_data = await _handle_single_entity_extraction(
+ record_attributes, chunk_key, file_path
+ )
+ if entity_data is not None:
+ maybe_nodes[entity_data["entity_name"]].append(entity_data)
+ continue
+
+ # Try to parse as relationship
+ relationship_data = await _handle_single_relationship_extraction(
+ record_attributes, chunk_key, file_path
+ )
+ if relationship_data is not None:
+ maybe_edges[
+ (relationship_data["src_id"], relationship_data["tgt_id"])
+ ].append(relationship_data)
+
+ return dict(maybe_nodes), dict(maybe_edges)
async def _parse_extraction_result(
@@ -613,46 +899,16 @@ async def _parse_extraction_result(
if chunk_data
else "unknown_source"
)
- context_base = dict(
+
+ # Call the shared processing function
+ return await _process_extraction_result(
+ extraction_result,
+ chunk_id,
+ file_path,
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
)
- maybe_nodes = defaultdict(list)
- maybe_edges = defaultdict(list)
-
- # Parse the extraction result using the same logic as in extract_entities
- records = split_string_by_multi_markers(
- extraction_result,
- [context_base["record_delimiter"], context_base["completion_delimiter"]],
- )
- for record in records:
- record = re.search(r"\((.*)\)", record)
- if record is None:
- continue
- record = record.group(1)
- record_attributes = split_string_by_multi_markers(
- record, [context_base["tuple_delimiter"]]
- )
-
- # Try to parse as entity
- entity_data = await _handle_single_entity_extraction(
- record_attributes, chunk_id, file_path
- )
- if entity_data is not None:
- maybe_nodes[entity_data["entity_name"]].append(entity_data)
- continue
-
- # Try to parse as relationship
- relationship_data = await _handle_single_relationship_extraction(
- record_attributes, chunk_id, file_path
- )
- if relationship_data is not None:
- maybe_edges[
- (relationship_data["src_id"], relationship_data["tgt_id"])
- ].append(relationship_data)
-
- return dict(maybe_nodes), dict(maybe_edges)
async def _rebuild_single_entity(
@@ -675,33 +931,24 @@ async def _rebuild_single_entity(
async def _update_entity_storage(
final_description: str, entity_type: str, file_paths: set[str]
):
- # Update entity in graph storage
- updated_entity_data = {
- **current_entity,
- "description": final_description,
- "entity_type": entity_type,
- "source_id": GRAPH_FIELD_SEP.join(chunk_ids),
- "file_path": GRAPH_FIELD_SEP.join(file_paths)
- if file_paths
- else current_entity.get("file_path", "unknown_source"),
- }
- await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data)
-
- # Update entity in vector database
- entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-")
-
- # Delete old vector record first
try:
- await entities_vdb.delete([entity_vdb_id])
- except Exception as e:
- logger.debug(
- f"Could not delete old entity vector record {entity_vdb_id}: {e}"
- )
+ # Update entity in graph storage (critical path)
+ updated_entity_data = {
+ **current_entity,
+ "description": final_description,
+ "entity_type": entity_type,
+ "source_id": GRAPH_FIELD_SEP.join(chunk_ids),
+ "file_path": GRAPH_FIELD_SEP.join(file_paths)
+ if file_paths
+ else current_entity.get("file_path", "unknown_source"),
+ }
+ await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data)
- # Insert new vector record
- entity_content = f"{entity_name}\n{final_description}"
- await entities_vdb.upsert(
- {
+ # Update entity in vector database (equally critical)
+ entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-")
+ entity_content = f"{entity_name}\n{final_description}"
+
+ vdb_data = {
entity_vdb_id: {
"content": entity_content,
"entity_name": entity_name,
@@ -711,22 +958,20 @@ async def _rebuild_single_entity(
"file_path": updated_entity_data["file_path"],
}
}
- )
- # Helper function to generate final description with optional LLM summary
- async def _generate_final_description(combined_description: str) -> str:
- force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
- num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
-
- if num_fragment >= force_llm_summary_on_merge:
- return await _handle_entity_relation_summary(
- entity_name,
- combined_description,
- global_config,
- llm_response_cache=llm_response_cache,
+ # Use safe operation wrapper - VDB failure must throw exception
+ await safe_vdb_operation_with_exception(
+ operation=lambda: entities_vdb.upsert(vdb_data),
+ operation_name="rebuild_entity_upsert",
+ entity_name=entity_name,
+ max_retries=3,
+ retry_delay=0.1,
)
- else:
- return combined_description
+
+ except Exception as e:
+ error_msg = f"Failed to update entity storage for `{entity_name}`: {e}"
+ logger.error(error_msg)
+ raise # Re-raise exception
# Collect all entity data from relevant chunks
all_entity_data = []
@@ -736,13 +981,13 @@ async def _rebuild_single_entity(
if not all_entity_data:
logger.warning(
- f"No cached entity data found for {entity_name}, trying to rebuild from relationships"
+ f"No entity data found for `{entity_name}`, trying to rebuild from relationships"
)
# Get all edges connected to this entity
edges = await knowledge_graph_inst.get_node_edges(entity_name)
if not edges:
- logger.warning(f"No relationships found for entity {entity_name}")
+ logger.warning(f"No relations attached to entity `{entity_name}`")
return
# Collect relationship data to extract entity information
@@ -760,10 +1005,19 @@ async def _rebuild_single_entity(
edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP)
file_paths.update(edge_file_paths)
- # Generate description from relationships or fallback to current
- if relationship_descriptions:
- combined_description = GRAPH_FIELD_SEP.join(relationship_descriptions)
- final_description = await _generate_final_description(combined_description)
+ # deduplicate descriptions
+ description_list = list(dict.fromkeys(relationship_descriptions))
+
+ # Generate final description from relationships or fallback to current
+ if description_list:
+ final_description, _ = await _handle_entity_relation_summary(
+ "Entity",
+ entity_name,
+ description_list,
+ GRAPH_FIELD_SEP,
+ global_config,
+ llm_response_cache=llm_response_cache,
+ )
else:
final_description = current_entity.get("description", "")
@@ -784,12 +1038,9 @@ async def _rebuild_single_entity(
if entity_data.get("file_path"):
file_paths.add(entity_data["file_path"])
- # Combine all descriptions
- combined_description = (
- GRAPH_FIELD_SEP.join(descriptions)
- if descriptions
- else current_entity.get("description", "")
- )
+ # Remove duplicates while preserving order
+ description_list = list(dict.fromkeys(descriptions))
+ entity_types = list(dict.fromkeys(entity_types))
# Get most common entity type
entity_type = (
@@ -798,8 +1049,19 @@ async def _rebuild_single_entity(
else current_entity.get("entity_type", "UNKNOWN")
)
- # Generate final description and update storage
- final_description = await _generate_final_description(combined_description)
+ # Generate final description from entities or fallback to current
+ if description_list:
+ final_description, _ = await _handle_entity_relation_summary(
+ "Entity",
+ entity_name,
+ description_list,
+ GRAPH_FIELD_SEP,
+ global_config,
+ llm_response_cache=llm_response_cache,
+ )
+ else:
+ final_description = current_entity.get("description", "")
+
await _update_entity_storage(final_description, entity_type, file_paths)
@@ -836,7 +1098,7 @@ async def _rebuild_single_relationship(
)
if not all_relationship_data:
- logger.warning(f"No cached relationship data found for {src}-{tgt}")
+ logger.warning(f"No relation data found for `{src}-{tgt}`")
return
# Merge descriptions and keywords
@@ -855,42 +1117,38 @@ async def _rebuild_single_relationship(
if rel_data.get("file_path"):
file_paths.add(rel_data["file_path"])
- # Combine descriptions and keywords
- combined_description = (
- GRAPH_FIELD_SEP.join(descriptions)
- if descriptions
- else current_relationship.get("description", "")
- )
+ # Remove duplicates while preserving order
+ description_list = list(dict.fromkeys(descriptions))
+ keywords = list(dict.fromkeys(keywords))
+
combined_keywords = (
", ".join(set(keywords))
if keywords
else current_relationship.get("keywords", "")
)
- # weight = (
- # sum(weights) / len(weights)
- # if weights
- # else current_relationship.get("weight", 1.0)
- # )
+
weight = sum(weights) if weights else current_relationship.get("weight", 1.0)
- # Use summary if description has too many fragments
- force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
- num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
-
- if num_fragment >= force_llm_summary_on_merge:
- final_description = await _handle_entity_relation_summary(
+ # Generate final description from relations or fallback to current
+ if description_list:
+ final_description, _ = await _handle_entity_relation_summary(
+ "Relation",
f"{src}-{tgt}",
- combined_description,
+ description_list,
+ GRAPH_FIELD_SEP,
global_config,
llm_response_cache=llm_response_cache,
)
else:
- final_description = combined_description
+ # fallback to keep current(unchanged)
+ final_description = current_relationship.get("description", "")
# Update relationship in graph storage
updated_relationship_data = {
**current_relationship,
- "description": final_description,
+ "description": final_description
+ if final_description
+ else current_relationship.get("description", ""),
"keywords": combined_keywords,
"weight": weight,
"source_id": GRAPH_FIELD_SEP.join(chunk_ids),
@@ -901,21 +1159,21 @@ async def _rebuild_single_relationship(
await knowledge_graph_inst.upsert_edge(src, tgt, updated_relationship_data)
# Update relationship in vector database
- rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-")
- rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-")
-
- # Delete old vector records first (both directions to be safe)
try:
- await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse])
- except Exception as e:
- logger.debug(
- f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}"
- )
+ rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-")
+ rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-")
- # Insert new vector record
- rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}"
- await relationships_vdb.upsert(
- {
+ # Delete old vector records first (both directions to be safe)
+ try:
+ await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse])
+ except Exception as e:
+ logger.debug(
+ f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}"
+ )
+
+ # Insert new vector record
+ rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}"
+ vdb_data = {
rel_vdb_id: {
"src_id": src,
"tgt_id": tgt,
@@ -927,7 +1185,20 @@ async def _rebuild_single_relationship(
"file_path": updated_relationship_data["file_path"],
}
}
- )
+
+ # Use safe operation wrapper - VDB failure must throw exception
+ await safe_vdb_operation_with_exception(
+ operation=lambda: relationships_vdb.upsert(vdb_data),
+ operation_name="rebuild_relationship_upsert",
+ entity_name=f"{src}-{tgt}",
+ max_retries=3,
+ retry_delay=0.2,
+ )
+
+ except Exception as e:
+ error_msg = f"Failed to rebuild relationship storage for `{src}-{tgt}`: {e}"
+ logger.error(error_msg)
+ raise # Re-raise exception
async def _merge_nodes_then_upsert(
@@ -948,13 +1219,9 @@ async def _merge_nodes_then_upsert(
already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node:
already_entity_types.append(already_node["entity_type"])
- already_source_ids.extend(
- split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
- )
- already_file_paths.extend(
- split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP])
- )
- already_description.append(already_node["description"])
+ already_source_ids.extend(already_node["source_id"].split(GRAPH_FIELD_SEP))
+ already_file_paths.extend(already_node["file_path"].split(GRAPH_FIELD_SEP))
+ already_description.extend(already_node["description"].split(GRAPH_FIELD_SEP))
entity_type = sorted(
Counter(
@@ -962,42 +1229,54 @@ async def _merge_nodes_then_upsert(
).items(),
key=lambda x: x[1],
reverse=True,
- )[0][0]
- description = GRAPH_FIELD_SEP.join(
- sorted(set([dp["description"] for dp in nodes_data] + already_description))
+ )[0][0] # Get the entity type with the highest count
+
+ # merge and deduplicate description
+ description_list = list(
+ dict.fromkeys(
+ already_description
+ + [dp["description"] for dp in nodes_data if dp.get("description")]
+ )
)
+
+ num_fragment = len(description_list)
+ already_fragment = len(already_description)
+ deduplicated_num = already_fragment + len(nodes_data) - num_fragment
+ if deduplicated_num > 0:
+ dd_message = f"(dd:{deduplicated_num})"
+ else:
+ dd_message = ""
+ if num_fragment > 0:
+ # Get summary and LLM usage status
+ description, llm_was_used = await _handle_entity_relation_summary(
+ "Entity",
+ entity_name,
+ description_list,
+ GRAPH_FIELD_SEP,
+ global_config,
+ llm_response_cache,
+ )
+
+ # Log based on actual LLM usage
+ if llm_was_used:
+ status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+ else:
+ status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+
+ logger.info(status_message)
+ if pipeline_status is not None and pipeline_status_lock is not None:
+ async with pipeline_status_lock:
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
+ else:
+ logger.error(f"Entity {entity_name} has no description")
+ description = "(no description)"
+
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
)
file_path = build_file_path(already_file_paths, nodes_data, entity_name)
- force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
-
- num_fragment = description.count(GRAPH_FIELD_SEP) + 1
- num_new_fragment = len(set([dp["description"] for dp in nodes_data]))
-
- if num_fragment > 1:
- if num_fragment >= force_llm_summary_on_merge:
- status_message = f"LLM merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}"
- logger.info(status_message)
- if pipeline_status is not None and pipeline_status_lock is not None:
- async with pipeline_status_lock:
- pipeline_status["latest_message"] = status_message
- pipeline_status["history_messages"].append(status_message)
- description = await _handle_entity_relation_summary(
- entity_name,
- description,
- global_config,
- llm_response_cache,
- )
- else:
- status_message = f"Merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}"
- logger.info(status_message)
- if pipeline_status is not None and pipeline_status_lock is not None:
- async with pipeline_status_lock:
- pipeline_status["latest_message"] = status_message
- pipeline_status["history_messages"].append(status_message)
-
node_data = dict(
entity_id=entity_name,
entity_type=entity_type,
@@ -1044,22 +1323,20 @@ async def _merge_edges_then_upsert(
# Get source_id with empty string default if missing or None
if already_edge.get("source_id") is not None:
already_source_ids.extend(
- split_string_by_multi_markers(
- already_edge["source_id"], [GRAPH_FIELD_SEP]
- )
+ already_edge["source_id"].split(GRAPH_FIELD_SEP)
)
# Get file_path with empty string default if missing or None
if already_edge.get("file_path") is not None:
already_file_paths.extend(
- split_string_by_multi_markers(
- already_edge["file_path"], [GRAPH_FIELD_SEP]
- )
+ already_edge["file_path"].split(GRAPH_FIELD_SEP)
)
# Get description with empty string default if missing or None
if already_edge.get("description") is not None:
- already_description.append(already_edge["description"])
+ already_description.extend(
+ already_edge["description"].split(GRAPH_FIELD_SEP)
+ )
# Get keywords with empty string default if missing or None
if already_edge.get("keywords") is not None:
@@ -1071,15 +1348,47 @@ async def _merge_edges_then_upsert(
# Process edges_data with None checks
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
- description = GRAPH_FIELD_SEP.join(
- sorted(
- set(
- [dp["description"] for dp in edges_data if dp.get("description")]
- + already_description
- )
+
+ description_list = list(
+ dict.fromkeys(
+ already_description
+ + [dp["description"] for dp in edges_data if dp.get("description")]
)
)
+ num_fragment = len(description_list)
+ already_fragment = len(already_description)
+ deduplicated_num = already_fragment + len(edges_data) - num_fragment
+ if deduplicated_num > 0:
+ dd_message = f"(dd:{deduplicated_num})"
+ else:
+ dd_message = ""
+ if num_fragment > 0:
+ # Get summary and LLM usage status
+ description, llm_was_used = await _handle_entity_relation_summary(
+ "Relation",
+ f"({src_id}, {tgt_id})",
+ description_list,
+ GRAPH_FIELD_SEP,
+ global_config,
+ llm_response_cache,
+ )
+
+ # Log based on actual LLM usage
+ if llm_was_used:
+ status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+ else:
+ status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+
+ logger.info(status_message)
+ if pipeline_status is not None and pipeline_status_lock is not None:
+ async with pipeline_status_lock:
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
+ else:
+ logger.error(f"Edge {src_id} - {tgt_id} has no description")
+ description = "(no description)"
+
# Split all existing and new keywords into individual terms, then combine and deduplicate
all_keywords = set()
# Process already_keywords (which are comma-separated)
@@ -1127,35 +1436,6 @@ async def _merge_edges_then_upsert(
}
added_entities.append(entity_data)
- force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
-
- num_fragment = description.count(GRAPH_FIELD_SEP) + 1
- num_new_fragment = len(
- set([dp["description"] for dp in edges_data if dp.get("description")])
- )
-
- if num_fragment > 1:
- if num_fragment >= force_llm_summary_on_merge:
- status_message = f"LLM merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}"
- logger.info(status_message)
- if pipeline_status is not None and pipeline_status_lock is not None:
- async with pipeline_status_lock:
- pipeline_status["latest_message"] = status_message
- pipeline_status["history_messages"].append(status_message)
- description = await _handle_entity_relation_summary(
- f"({src_id}, {tgt_id})",
- description,
- global_config,
- llm_response_cache,
- )
- else:
- status_message = f"Merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}"
- logger.info(status_message)
- if pipeline_status is not None and pipeline_status_lock is not None:
- async with pipeline_status_lock:
- pipeline_status["latest_message"] = status_message
- pipeline_status["history_messages"].append(status_message)
-
await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
@@ -1263,27 +1543,68 @@ async def merge_nodes_and_edges(
async with get_storage_keyed_lock(
[entity_name], namespace=namespace, enable_logging=False
):
- entity_data = await _merge_nodes_then_upsert(
- entity_name,
- entities,
- knowledge_graph_inst,
- global_config,
- pipeline_status,
- pipeline_status_lock,
- llm_response_cache,
- )
- if entity_vdb is not None:
- data_for_vdb = {
- compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
- "entity_name": entity_data["entity_name"],
- "entity_type": entity_data["entity_type"],
- "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
- "source_id": entity_data["source_id"],
- "file_path": entity_data.get("file_path", "unknown_source"),
+ try:
+ # Graph database operation (critical path, must succeed)
+ entity_data = await _merge_nodes_then_upsert(
+ entity_name,
+ entities,
+ knowledge_graph_inst,
+ global_config,
+ pipeline_status,
+ pipeline_status_lock,
+ llm_response_cache,
+ )
+
+ # Vector database operation (equally critical, must succeed)
+ if entity_vdb is not None and entity_data:
+ data_for_vdb = {
+ compute_mdhash_id(
+ entity_data["entity_name"], prefix="ent-"
+ ): {
+ "entity_name": entity_data["entity_name"],
+ "entity_type": entity_data["entity_type"],
+ "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
+ "source_id": entity_data["source_id"],
+ "file_path": entity_data.get(
+ "file_path", "unknown_source"
+ ),
+ }
}
- }
- await entity_vdb.upsert(data_for_vdb)
- return entity_data
+
+ # Use safe operation wrapper - VDB failure must throw exception
+ await safe_vdb_operation_with_exception(
+ operation=lambda: entity_vdb.upsert(data_for_vdb),
+ operation_name="entity_upsert",
+ entity_name=entity_name,
+ max_retries=3,
+ retry_delay=0.1,
+ )
+
+ return entity_data
+
+ except Exception as e:
+ # Any database operation failure is critical
+ error_msg = (
+ f"Critical error in entity processing for `{entity_name}`: {e}"
+ )
+ logger.error(error_msg)
+
+ # Try to update pipeline status, but don't let status update failure affect main exception
+ try:
+ if (
+ pipeline_status is not None
+ and pipeline_status_lock is not None
+ ):
+ async with pipeline_status_lock:
+ pipeline_status["latest_message"] = error_msg
+ pipeline_status["history_messages"].append(error_msg)
+ except Exception as status_error:
+ logger.error(
+ f"Failed to update pipeline status: {status_error}"
+ )
+
+ # Re-raise the original exception
+ raise
# Create entity processing tasks
entity_tasks = []
@@ -1331,38 +1652,75 @@ async def merge_nodes_and_edges(
namespace=namespace,
enable_logging=False,
):
- added_entities = [] # Track entities added during edge processing
- edge_data = await _merge_edges_then_upsert(
- edge_key[0],
- edge_key[1],
- edges,
- knowledge_graph_inst,
- global_config,
- pipeline_status,
- pipeline_status_lock,
- llm_response_cache,
- added_entities, # Pass list to collect added entities
- )
+ try:
+ added_entities = [] # Track entities added during edge processing
- if edge_data is None:
- return None, []
+ # Graph database operation (critical path, must succeed)
+ edge_data = await _merge_edges_then_upsert(
+ edge_key[0],
+ edge_key[1],
+ edges,
+ knowledge_graph_inst,
+ global_config,
+ pipeline_status,
+ pipeline_status_lock,
+ llm_response_cache,
+ added_entities, # Pass list to collect added entities
+ )
- if relationships_vdb is not None:
- data_for_vdb = {
- compute_mdhash_id(
- edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"
- ): {
- "src_id": edge_data["src_id"],
- "tgt_id": edge_data["tgt_id"],
- "keywords": edge_data["keywords"],
- "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
- "source_id": edge_data["source_id"],
- "file_path": edge_data.get("file_path", "unknown_source"),
- "weight": edge_data.get("weight", 1.0),
+ if edge_data is None:
+ return None, []
+
+ # Vector database operation (equally critical, must succeed)
+ if relationships_vdb is not None:
+ data_for_vdb = {
+ compute_mdhash_id(
+ edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"
+ ): {
+ "src_id": edge_data["src_id"],
+ "tgt_id": edge_data["tgt_id"],
+ "keywords": edge_data["keywords"],
+ "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
+ "source_id": edge_data["source_id"],
+ "file_path": edge_data.get(
+ "file_path", "unknown_source"
+ ),
+ "weight": edge_data.get("weight", 1.0),
+ }
}
- }
- await relationships_vdb.upsert(data_for_vdb)
- return edge_data, added_entities
+
+ # Use safe operation wrapper - VDB failure must throw exception
+ await safe_vdb_operation_with_exception(
+ operation=lambda: relationships_vdb.upsert(data_for_vdb),
+ operation_name="relationship_upsert",
+ entity_name=f"{edge_data['src_id']}-{edge_data['tgt_id']}",
+ max_retries=3,
+ retry_delay=0.1,
+ )
+
+ return edge_data, added_entities
+
+ except Exception as e:
+ # Any database operation failure is critical
+ error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}"
+ logger.error(error_msg)
+
+ # Try to update pipeline status, but don't let status update failure affect main exception
+ try:
+ if (
+ pipeline_status is not None
+ and pipeline_status_lock is not None
+ ):
+ async with pipeline_status_lock:
+ pipeline_status["latest_message"] = error_msg
+ pipeline_status["history_messages"].append(error_msg)
+ except Exception as status_error:
+ logger.error(
+ f"Failed to update pipeline status: {status_error}"
+ )
+
+ # Re-raise the original exception
+ raise
# Create relationship processing tasks
edge_tasks = []
@@ -1402,29 +1760,7 @@ async def merge_nodes_and_edges(
if full_entities_storage and full_relations_storage and doc_id:
try:
# Merge all entities: original entities + entities added during edge processing
- existing_entites_data = None
- existing_relations_data = None
-
- try:
- existing_entites_data = await full_entities_storage.get_by_id(doc_id)
- existing_relations_data = await full_relations_storage.get_by_id(doc_id)
- except Exception as e:
- logger.debug(
- f"Could not retrieve existing entity/relation data for {doc_id}: {e}"
- )
-
- existing_entites_names = set()
- if existing_entites_data and existing_entites_data.get("entity_names"):
- existing_entites_names.update(existing_entites_data["entity_names"])
-
- existing_relation_pairs = set()
- if existing_relations_data and existing_relations_data.get(
- "relation_pairs"
- ):
- for pair in existing_relations_data["relation_pairs"]:
- existing_relation_pairs.add(tuple(sorted(pair)))
-
- final_entity_names = existing_entites_names.copy()
+ final_entity_names = set()
# Add original processed entities
for entity_data in processed_entities:
@@ -1437,7 +1773,7 @@ async def merge_nodes_and_edges(
final_entity_names.add(added_entity["entity_name"])
# Collect all relation pairs
- final_relation_pairs = existing_relation_pairs.copy()
+ final_relation_pairs = set()
for edge_data in processed_edges:
if edge_data:
src_id = edge_data.get("src_id")
@@ -1447,12 +1783,6 @@ async def merge_nodes_and_edges(
final_relation_pairs.add(relation_pair)
log_message = f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}"
- new_entities_count = len(final_entity_names) - len(existing_entites_names)
- new_relation_count = len(final_relation_pairs) - len(
- existing_relation_pairs
- )
-
- log_message = f"Phase 3: Merging storage - existing: {len(existing_entites_names)} entitites, {len(existing_relation_pairs)} relations; new: {new_entities_count} entities. {new_relation_count} relations; total: {len(final_entity_names)} entities, {len(final_relation_pairs)} relations"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
@@ -1491,7 +1821,7 @@ async def merge_nodes_and_edges(
)
# Don't raise exception to avoid affecting main flow
- log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} added entities, {len(processed_edges)} relations"
+ log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
@@ -1511,19 +1841,12 @@ async def extract_entities(
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
- language = global_config["addon_params"].get(
- "language", PROMPTS["DEFAULT_LANGUAGE"]
- )
+ language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE)
entity_types = global_config["addon_params"].get(
- "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
+ "entity_types", DEFAULT_ENTITY_TYPES
)
- example_number = global_config["addon_params"].get("example_number", None)
- if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
- examples = "\n".join(
- PROMPTS["entity_extraction_examples"][: int(example_number)]
- )
- else:
- examples = "\n".join(PROMPTS["entity_extraction_examples"])
+
+ examples = "\n".join(PROMPTS["entity_extraction_examples"])
example_context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
@@ -1546,56 +1869,10 @@ async def extract_entities(
)
continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
- if_loop_prompt = PROMPTS["entity_if_loop_extraction"]
processed_chunks = 0
total_chunks = len(ordered_chunks)
- async def _process_extraction_result(
- result: str, chunk_key: str, file_path: str = "unknown_source"
- ):
- """Process a single extraction result (either initial or gleaning)
- Args:
- result (str): The extraction result to process
- chunk_key (str): The chunk key for source tracking
- file_path (str): The file path for citation
- Returns:
- tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
- """
- maybe_nodes = defaultdict(list)
- maybe_edges = defaultdict(list)
-
- records = split_string_by_multi_markers(
- result,
- [context_base["record_delimiter"], context_base["completion_delimiter"]],
- )
-
- for record in records:
- record = re.search(r"\((.*)\)", record)
- if record is None:
- continue
- record = record.group(1)
- record_attributes = split_string_by_multi_markers(
- record, [context_base["tuple_delimiter"]]
- )
-
- if_entities = await _handle_single_entity_extraction(
- record_attributes, chunk_key, file_path
- )
- if if_entities is not None:
- maybe_nodes[if_entities["entity_name"]].append(if_entities)
- continue
-
- if_relation = await _handle_single_relationship_extraction(
- record_attributes, chunk_key, file_path
- )
- if if_relation is not None:
- maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
- if_relation
- )
-
- return maybe_nodes, maybe_edges
-
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
"""Process a single chunk
Args:
@@ -1633,11 +1910,16 @@ async def extract_entities(
# Process initial extraction with file path
maybe_nodes, maybe_edges = await _process_extraction_result(
- final_result, chunk_key, file_path
+ final_result,
+ chunk_key,
+ file_path,
+ tuple_delimiter=context_base["tuple_delimiter"],
+ record_delimiter=context_base["record_delimiter"],
+ completion_delimiter=context_base["completion_delimiter"],
)
# Process additional gleaning results
- for now_glean_index in range(entity_extract_max_gleaning):
+ if entity_extract_max_gleaning > 0:
glean_result = await use_llm_func_with_cache(
continue_prompt,
use_llm_func,
@@ -1652,7 +1934,12 @@ async def extract_entities(
# Process gleaning result separately with file path
glean_nodes, glean_edges = await _process_extraction_result(
- glean_result, chunk_key, file_path
+ glean_result,
+ chunk_key,
+ file_path,
+ tuple_delimiter=context_base["tuple_delimiter"],
+ record_delimiter=context_base["record_delimiter"],
+ completion_delimiter=context_base["completion_delimiter"],
)
# Merge results - only add entities and edges with new names
@@ -1660,28 +1947,15 @@ async def extract_entities(
if (
entity_name not in maybe_nodes
): # Only accetp entities with new name in gleaning stage
+ maybe_nodes[entity_name] = [] # Explicitly create the list
maybe_nodes[entity_name].extend(entities)
for edge_key, edges in glean_edges.items():
if (
edge_key not in maybe_edges
): # Only accetp edges with new name in gleaning stage
+ maybe_edges[edge_key] = [] # Explicitly create the list
maybe_edges[edge_key].extend(edges)
- if now_glean_index == entity_extract_max_gleaning - 1:
- break
-
- if_loop_result: str = await use_llm_func_with_cache(
- if_loop_prompt,
- use_llm_func,
- llm_response_cache=llm_response_cache,
- history_messages=history,
- cache_type="extract",
- cache_keys_collector=cache_keys_collector,
- )
- if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
- if if_loop_result != "yes":
- break
-
# Batch update chunk's llm_cache_list with all collected cache keys
if cache_keys_collector and text_chunks_storage:
await update_chunk_cache_list(
@@ -1755,6 +2029,9 @@ async def kg_query(
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
) -> str | AsyncIterator[str]:
+ if not query:
+ return PROMPTS["fail_response"]
+
if query_param.model_func:
use_model_func = query_param.model_func
else:
@@ -1791,21 +2068,16 @@ async def kg_query(
logger.debug(f"Low-level keywords: {ll_keywords}")
# Handle empty keywords
+ if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]:
+ logger.warning("low_level_keywords is empty")
+ if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]:
+ logger.warning("high_level_keywords is empty")
if hl_keywords == [] and ll_keywords == []:
- logger.warning("low_level_keywords and high_level_keywords is empty")
- return PROMPTS["fail_response"]
- if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
- logger.warning(
- "low_level_keywords is empty, switching from %s mode to global mode",
- query_param.mode,
- )
- query_param.mode = "global"
- if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
- logger.warning(
- "high_level_keywords is empty, switching from %s mode to local mode",
- query_param.mode,
- )
- query_param.mode = "local"
+ if len(query) < 50:
+ logger.warning(f"Forced low_level_keywords to origin query: {query}")
+ ll_keywords = [query]
+ else:
+ return PROMPTS["fail_response"]
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
@@ -1970,16 +2242,9 @@ async def extract_keywords_only(
)
# 2. Build the examples
- example_number = global_config["addon_params"].get("example_number", None)
- if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
- examples = "\n".join(
- PROMPTS["keywords_extraction_examples"][: int(example_number)]
- )
- else:
- examples = "\n".join(PROMPTS["keywords_extraction_examples"])
- language = global_config["addon_params"].get(
- "language", PROMPTS["DEFAULT_LANGUAGE"]
- )
+ examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+
+ language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE)
# 3. Process conversation history
# history_context = ""
@@ -2066,6 +2331,7 @@ async def _get_vector_context(
query: str,
chunks_vdb: BaseVectorStorage,
query_param: QueryParam,
+ query_embedding: list[float] = None,
) -> list[dict]:
"""
Retrieve text chunks from the vector database without reranking or truncation.
@@ -2077,6 +2343,7 @@ async def _get_vector_context(
query: The query string to search for
chunks_vdb: Vector database containing document chunks
query_param: Query parameters including chunk_top_k and ids
+ query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls
Returns:
List of text chunks with metadata
@@ -2085,8 +2352,11 @@ async def _get_vector_context(
# Use chunk_top_k if specified, otherwise fall back to top_k
search_top_k = query_param.chunk_top_k or query_param.top_k
- results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
+ results = await chunks_vdb.query(
+ query, top_k=search_top_k, query_embedding=query_embedding
+ )
if not results:
+ logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})")
return []
valid_chunks = []
@@ -2122,6 +2392,10 @@ async def _build_query_context(
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None,
):
+ if not query:
+ logger.warning("Query is empty, skipping context building")
+ return ""
+
logger.info(f"Process {os.getpid()} building query context...")
# Collect chunks from different sources separately
@@ -2140,8 +2414,26 @@ async def _build_query_context(
# Track chunk sources and metadata for final logging
chunk_tracking = {} # chunk_id -> {source, frequency, order}
+ # Pre-compute query embedding once for all vector operations
+ kg_chunk_pick_method = text_chunks_db.global_config.get(
+ "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
+ )
+ query_embedding = None
+ if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
+ embedding_func_config = text_chunks_db.embedding_func
+ if embedding_func_config and embedding_func_config.func:
+ try:
+ query_embedding = await embedding_func_config.func([query])
+ query_embedding = query_embedding[
+ 0
+ ] # Extract first embedding from batch result
+ logger.debug("Pre-computed query embedding for all vector operations")
+ except Exception as e:
+ logger.warning(f"Failed to pre-compute query embedding: {e}")
+ query_embedding = None
+
# Handle local and global modes
- if query_param.mode == "local":
+ if query_param.mode == "local" and len(ll_keywords) > 0:
local_entities, local_relations = await _get_node_data(
ll_keywords,
knowledge_graph_inst,
@@ -2149,7 +2441,7 @@ async def _build_query_context(
query_param,
)
- elif query_param.mode == "global":
+ elif query_param.mode == "global" and len(hl_keywords) > 0:
global_relations, global_entities = await _get_edge_data(
hl_keywords,
knowledge_graph_inst,
@@ -2158,18 +2450,20 @@ async def _build_query_context(
)
else: # hybrid or mix mode
- local_entities, local_relations = await _get_node_data(
- ll_keywords,
- knowledge_graph_inst,
- entities_vdb,
- query_param,
- )
- global_relations, global_entities = await _get_edge_data(
- hl_keywords,
- knowledge_graph_inst,
- relationships_vdb,
- query_param,
- )
+ if len(ll_keywords) > 0:
+ local_entities, local_relations = await _get_node_data(
+ ll_keywords,
+ knowledge_graph_inst,
+ entities_vdb,
+ query_param,
+ )
+ if len(hl_keywords) > 0:
+ global_relations, global_entities = await _get_edge_data(
+ hl_keywords,
+ knowledge_graph_inst,
+ relationships_vdb,
+ query_param,
+ )
# Get vector chunks first if in mix mode
if query_param.mode == "mix" and chunks_vdb:
@@ -2177,6 +2471,7 @@ async def _build_query_context(
query,
chunks_vdb,
query_param,
+ query_embedding,
)
# Track vector chunks with source metadata
for i, chunk in enumerate(vector_chunks):
@@ -2187,6 +2482,8 @@ async def _build_query_context(
"frequency": 1, # Vector chunks always have frequency 1
"order": i + 1, # 1-based order in vector search results
}
+ else:
+ logger.warning(f"Vector chunk missing chunk_id: {chunk}")
# Use round-robin merge to combine local and global data fairly
final_entities = []
@@ -2400,6 +2697,7 @@ async def _build_query_context(
query,
chunks_vdb,
chunk_tracking=chunk_tracking,
+ query_embedding=query_embedding,
)
# Find deduplcicated chunks from edge
@@ -2413,6 +2711,7 @@ async def _build_query_context(
query,
chunks_vdb,
chunk_tracking=chunk_tracking,
+ query_embedding=query_embedding,
)
# Round-robin merge chunks from different sources with deduplication by chunk_id
@@ -2470,6 +2769,7 @@ async def _build_query_context(
# Apply token processing to merged chunks
text_units_context = []
+ truncated_chunks = []
if merged_chunks:
# Calculate dynamic token limit for text chunks
entities_str = json.dumps(entities_context, ensure_ascii=False)
@@ -2501,15 +2801,15 @@ async def _build_query_context(
kg_context_tokens = len(tokenizer.encode(kg_context))
# Calculate actual system prompt overhead dynamically
- # 1. Calculate conversation history tokens
+ # 1. Converstion history not included in context length calculation
history_context = ""
- if query_param.conversation_history:
- history_context = get_conversation_turns(
- query_param.conversation_history, query_param.history_turns
- )
- history_tokens = (
- len(tokenizer.encode(history_context)) if history_context else 0
- )
+ # if query_param.conversation_history:
+ # history_context = get_conversation_turns(
+ # query_param.conversation_history, query_param.history_turns
+ # )
+ # history_tokens = (
+ # len(tokenizer.encode(history_context)) if history_context else 0
+ # )
# 2. Calculate system prompt template tokens (excluding context_data)
user_prompt = query_param.user_prompt if query_param.user_prompt else ""
@@ -2544,7 +2844,7 @@ async def _build_query_context(
available_chunk_tokens = max_total_tokens - used_tokens
logger.debug(
- f"Token allocation - Total: {max_total_tokens}, History: {history_tokens}, SysPrompt: {sys_prompt_overhead}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
+ f"Token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_overhead}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
)
# Apply token truncation to chunks using the dynamic limit
@@ -2634,9 +2934,7 @@ async def _get_node_data(
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
)
- results = await entities_vdb.query(
- query, top_k=query_param.top_k, ids=query_param.ids
- )
+ results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return [], []
@@ -2747,6 +3045,7 @@ async def _find_related_text_unit_from_entities(
query: str = None,
chunks_vdb: BaseVectorStorage = None,
chunk_tracking: dict = None,
+ query_embedding=None,
):
"""
Find text chunks related to entities using configurable chunk selection method.
@@ -2842,6 +3141,7 @@ async def _find_related_text_unit_from_entities(
num_of_chunks=num_of_chunks,
entity_info=entities_with_chunks,
embedding_func=actual_embedding_func,
+ query_embedding=query_embedding,
)
if selected_chunk_ids == []:
@@ -2910,9 +3210,7 @@ async def _get_edge_data(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
)
- results = await relationships_vdb.query(
- keywords, top_k=query_param.top_k, ids=query_param.ids
- )
+ results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results):
return [], []
@@ -2999,6 +3297,7 @@ async def _find_related_text_unit_from_relations(
query: str = None,
chunks_vdb: BaseVectorStorage = None,
chunk_tracking: dict = None,
+ query_embedding=None,
):
"""
Find text chunks related to relationships using configurable chunk selection method.
@@ -3134,6 +3433,7 @@ async def _find_related_text_unit_from_relations(
num_of_chunks=num_of_chunks,
entity_info=relations_with_chunks,
embedding_func=actual_embedding_func,
+ query_embedding=query_embedding,
)
if selected_chunk_ids == []:
@@ -3233,7 +3533,7 @@ async def naive_query(
tokenizer: Tokenizer = global_config["tokenizer"]
- chunks = await _get_vector_context(query, chunks_vdb, query_param)
+ chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0:
return PROMPTS["fail_response"]
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 32666bb5..0d21375a 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -4,60 +4,57 @@ from typing import Any
PROMPTS: dict[str, Any] = {}
-PROMPTS["DEFAULT_LANGUAGE"] = "English"
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
-PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
-
PROMPTS["DEFAULT_USER_PROMPT"] = "n/a"
-PROMPTS["entity_extraction"] = """---Goal---
-Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Use {language} as output language.
+PROMPTS["entity_extraction"] = """---Task---
+Given a text document and a list of entity types, identify all entities of those types and all relationships among the identified entities.
----Steps---
-1. Identify all entities. For each identified entity, extract the following information:
-- entity_name: Name of the entity, use same language as input text. If English, capitalized the name
-- entity_type: One of the following types: [{entity_types}]
-- entity_description: Provide a comprehensive description of the entity's attributes and activities *based solely on the information present in the input text*. **Do not infer or hallucinate information not explicitly stated.** If the text provides insufficient information to create a comprehensive description, state "Description not available in text."
-Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter})
-
-2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+---Instructions---
+1. Recognizing definitively conceptualized entities in text. For each identified entity, extract the following information:
+ - entity_name: Name of the entity, use same language as input text. If English, capitalized the name
+ - entity_type: Categorize the entity using the provided `Entity_types` list. If a suitable category cannot be determined, classify it as "Other".
+ - entity_description: Provide a comprehensive description of the entity's attributes and activities based on the information present in the input text. To ensure clarity and precision, all descriptions must replace pronouns and referential terms (e.g., "this document," "our company," "I," "you," "he/she") with the specific nouns they represent.
+2. Format each entity as: ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter})
+3. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are directly and clearly related based on the text. Unsubstantiated relationships must be excluded from the output.
For each pair of related entities, extract the following information:
-- source_entity: name of the source entity, as identified in step 1
-- target_entity: name of the target entity, as identified in step 1
-- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
-- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
-- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
-Format each relationship as ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter})
+ - source_entity: name of the source entity, as identified in step 1
+ - target_entity: name of the target entity, as identified in step 1
+ - relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
+ - relationship_description: Explain the nature of the relationship between the source and target entities, providing a clear rationale for their connection
+4. Format each relationship as: ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter})
+5. Use `{tuple_delimiter}` as field delimiter. Use `{record_delimiter}` as the entity or relation list delimiter.
+6. Return identified entities and relationships in {language}.
+7. Output `{completion_delimiter}` when all the entities and relationships are extracted.
-3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
-Format the content-level key words as ("content_keywords"{tuple_delimiter})
+---Quality Guidelines---
+- Only extract entities that are clearly defined and meaningful in the context
+- Avoid over-interpretation; stick to what is explicitly stated in the text
+- For all output content, explicitly name the subject or object rather than using pronouns
+- Include specific numerical data in entity name when relevant
+- Ensure entity names are consistent throughout the extraction
-4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
-
-5. When finished, output {completion_delimiter}
-
-######################
---Examples---
-######################
{examples}
-#############################
----Real Data---
-######################
+---Input---
Entity_types: [{entity_types}]
Text:
+```
{input_text}
-######################
-Output:"""
+```
+
+---Output---
+"""
PROMPTS["entity_extraction_examples"] = [
- """Example 1:
+ """[Example 1]
-Entity_types: [person, technology, mission, organization, location]
+---Input---
+Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text:
```
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
@@ -69,22 +66,24 @@ The underlying dismissal earlier seemed to falter, replaced by a glimpse of relu
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
```
-Output:
-("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
-("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
-("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
-("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
-("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
-("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
-("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
-("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
-("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
-("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
-("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
-#############################""",
- """Example 2:
+---Output---
+(entity{tuple_delimiter}Alex{tuple_delimiter}person{tuple_delimiter}Alex is a character who experiences frustration and is observant of the dynamics among other characters.){record_delimiter}
+(entity{tuple_delimiter}Taylor{tuple_delimiter}person{tuple_delimiter}Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective.){record_delimiter}
+(entity{tuple_delimiter}Jordan{tuple_delimiter}person{tuple_delimiter}Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device.){record_delimiter}
+(entity{tuple_delimiter}Cruz{tuple_delimiter}person{tuple_delimiter}Cruz is associated with a vision of control and order, influencing the dynamics among other characters.){record_delimiter}
+(entity{tuple_delimiter}The Device{tuple_delimiter}equiment{tuple_delimiter}The Device is central to the story, with potential game-changing implications, and is revered by Taylor.){record_delimiter}
+(relationship{tuple_delimiter}Alex{tuple_delimiter}Taylor{tuple_delimiter}power dynamics, observation{tuple_delimiter}Alex observes Taylor's authoritarian behavior and notes changes in Taylor's attitude toward the device.){record_delimiter}
+(relationship{tuple_delimiter}Alex{tuple_delimiter}Jordan{tuple_delimiter}shared goals, rebellion{tuple_delimiter}Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision.){record_delimiter}
+(relationship{tuple_delimiter}Taylor{tuple_delimiter}Jordan{tuple_delimiter}conflict resolution, mutual respect{tuple_delimiter}Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce.){record_delimiter}
+(relationship{tuple_delimiter}Jordan{tuple_delimiter}Cruz{tuple_delimiter}ideological conflict, rebellion{tuple_delimiter}Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order.){record_delimiter}
+(relationship{tuple_delimiter}Taylor{tuple_delimiter}The Device{tuple_delimiter}reverence, technological significance{tuple_delimiter}Taylor shows reverence towards the device, indicating its importance and potential impact.){record_delimiter}
+{completion_delimiter}
-Entity_types: [company, index, commodity, market_trend, economic_policy, biological]
+""",
+ """[Example 2]
+
+---Input---
+Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text:
```
Stock markets faced a sharp downturn today as tech giants saw significant declines, with the Global Tech Index dropping by 3.4% in midday trading. Analysts attribute the selloff to investor concerns over rising interest rates and regulatory uncertainty.
@@ -96,101 +95,113 @@ Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1
Financial experts are closely watching the Federal Reserve's next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability.
```
-Output:
-("entity"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"index"{tuple_delimiter}"The Global Tech Index tracks the performance of major technology stocks and experienced a 3.4% decline today."){record_delimiter}
-("entity"{tuple_delimiter}"Nexon Technologies"{tuple_delimiter}"company"{tuple_delimiter}"Nexon Technologies is a tech company that saw its stock decline by 7.8% after disappointing earnings."){record_delimiter}
-("entity"{tuple_delimiter}"Omega Energy"{tuple_delimiter}"company"{tuple_delimiter}"Omega Energy is an energy company that gained 2.1% in stock value due to rising oil prices."){record_delimiter}
-("entity"{tuple_delimiter}"Gold Futures"{tuple_delimiter}"commodity"{tuple_delimiter}"Gold futures rose by 1.5%, indicating increased investor interest in safe-haven assets."){record_delimiter}
-("entity"{tuple_delimiter}"Crude Oil"{tuple_delimiter}"commodity"{tuple_delimiter}"Crude oil prices rose to $87.60 per barrel due to supply constraints and strong demand."){record_delimiter}
-("entity"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"market_trend"{tuple_delimiter}"Market selloff refers to the significant decline in stock values due to investor concerns over interest rates and regulations."){record_delimiter}
-("entity"{tuple_delimiter}"Federal Reserve Policy Announcement"{tuple_delimiter}"economic_policy"{tuple_delimiter}"The Federal Reserve's upcoming policy announcement is expected to impact investor confidence and market stability."){record_delimiter}
-("relationship"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"The decline in the Global Tech Index is part of the broader market selloff driven by investor concerns."{tuple_delimiter}"market performance, investor sentiment"{tuple_delimiter}9){record_delimiter}
-("relationship"{tuple_delimiter}"Nexon Technologies"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"Nexon Technologies' stock decline contributed to the overall drop in the Global Tech Index."{tuple_delimiter}"company impact, index movement"{tuple_delimiter}8){record_delimiter}
-("relationship"{tuple_delimiter}"Gold Futures"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"Gold prices rose as investors sought safe-haven assets during the market selloff."{tuple_delimiter}"market reaction, safe-haven investment"{tuple_delimiter}10){record_delimiter}
-("relationship"{tuple_delimiter}"Federal Reserve Policy Announcement"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"Speculation over Federal Reserve policy changes contributed to market volatility and investor selloff."{tuple_delimiter}"interest rate impact, financial regulation"{tuple_delimiter}7){record_delimiter}
-("content_keywords"{tuple_delimiter}"market downturn, investor sentiment, commodities, Federal Reserve, stock performance"){completion_delimiter}
-#############################""",
- """Example 3:
+---Output---
+(entity{tuple_delimiter}Global Tech Index{tuple_delimiter}category{tuple_delimiter}The Global Tech Index tracks the performance of major technology stocks and experienced a 3.4% decline today.){record_delimiter}
+(entity{tuple_delimiter}Nexon Technologies{tuple_delimiter}organization{tuple_delimiter}Nexon Technologies is a tech company that saw its stock decline by 7.8% after disappointing earnings.){record_delimiter}
+(entity{tuple_delimiter}Omega Energy{tuple_delimiter}organization{tuple_delimiter}Omega Energy is an energy company that gained 2.1% in stock value due to rising oil prices.){record_delimiter}
+(entity{tuple_delimiter}Gold Futures{tuple_delimiter}product{tuple_delimiter}Gold futures rose by 1.5%, indicating increased investor interest in safe-haven assets.){record_delimiter}
+(entity{tuple_delimiter}Crude Oil{tuple_delimiter}product{tuple_delimiter}Crude oil prices rose to $87.60 per barrel due to supply constraints and strong demand.){record_delimiter}
+(entity{tuple_delimiter}Market Selloff{tuple_delimiter}category{tuple_delimiter}Market selloff refers to the significant decline in stock values due to investor concerns over interest rates and regulations.){record_delimiter}
+(entity{tuple_delimiter}Federal Reserve Policy Announcement{tuple_delimiter}category{tuple_delimiter}The Federal Reserve's upcoming policy announcement is expected to impact investor confidence and market stability.){record_delimiter}
+(entity{tuple_delimiter}3.4% Decline{tuple_delimiter}category{tuple_delimiter}The Global Tech Index experienced a 3.4% decline in midday trading.){record_delimiter}
+(relationship{tuple_delimiter}Global Tech Index{tuple_delimiter}Market Selloff{tuple_delimiter}market performance, investor sentiment{tuple_delimiter}The decline in the Global Tech Index is part of the broader market selloff driven by investor concerns.){record_delimiter}
+(relationship{tuple_delimiter}Nexon Technologies{tuple_delimiter}Global Tech Index{tuple_delimiter}company impact, index movement{tuple_delimiter}Nexon Technologies' stock decline contributed to the overall drop in the Global Tech Index.){record_delimiter}
+(relationship{tuple_delimiter}Gold Futures{tuple_delimiter}Market Selloff{tuple_delimiter}market reaction, safe-haven investment{tuple_delimiter}Gold prices rose as investors sought safe-haven assets during the market selloff.){record_delimiter}
+(relationship{tuple_delimiter}Federal Reserve Policy Announcement{tuple_delimiter}Market Selloff{tuple_delimiter}interest rate impact, financial regulation{tuple_delimiter}Speculation over Federal Reserve policy changes contributed to market volatility and investor selloff.){record_delimiter}
+{completion_delimiter}
-Entity_types: [economic_policy, athlete, event, location, record, organization, equipment]
+""",
+ """[Example 3]
+
+---Input---
+Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text:
```
At the World Athletics Championship in Tokyo, Noah Carter broke the 100m sprint record using cutting-edge carbon-fiber spikes.
```
-Output:
-("entity"{tuple_delimiter}"World Athletics Championship"{tuple_delimiter}"event"{tuple_delimiter}"The World Athletics Championship is a global sports competition featuring top athletes in track and field."){record_delimiter}
-("entity"{tuple_delimiter}"Tokyo"{tuple_delimiter}"location"{tuple_delimiter}"Tokyo is the host city of the World Athletics Championship."){record_delimiter}
-("entity"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"athlete"{tuple_delimiter}"Noah Carter is a sprinter who set a new record in the 100m sprint at the World Athletics Championship."){record_delimiter}
-("entity"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"record"{tuple_delimiter}"The 100m sprint record is a benchmark in athletics, recently broken by Noah Carter."){record_delimiter}
-("entity"{tuple_delimiter}"Carbon-Fiber Spikes"{tuple_delimiter}"equipment"{tuple_delimiter}"Carbon-fiber spikes are advanced sprinting shoes that provide enhanced speed and traction."){record_delimiter}
-("entity"{tuple_delimiter}"World Athletics Federation"{tuple_delimiter}"organization"{tuple_delimiter}"The World Athletics Federation is the governing body overseeing the World Athletics Championship and record validations."){record_delimiter}
-("relationship"{tuple_delimiter}"World Athletics Championship"{tuple_delimiter}"Tokyo"{tuple_delimiter}"The World Athletics Championship is being hosted in Tokyo."{tuple_delimiter}"event location, international competition"{tuple_delimiter}8){record_delimiter}
-("relationship"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"Noah Carter set a new 100m sprint record at the championship."{tuple_delimiter}"athlete achievement, record-breaking"{tuple_delimiter}10){record_delimiter}
-("relationship"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"Carbon-Fiber Spikes"{tuple_delimiter}"Noah Carter used carbon-fiber spikes to enhance performance during the race."{tuple_delimiter}"athletic equipment, performance boost"{tuple_delimiter}7){record_delimiter}
-("relationship"{tuple_delimiter}"World Athletics Federation"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"The World Athletics Federation is responsible for validating and recognizing new sprint records."{tuple_delimiter}"sports regulation, record certification"{tuple_delimiter}9){record_delimiter}
-("content_keywords"{tuple_delimiter}"athletics, sprinting, record-breaking, sports technology, competition"){completion_delimiter}
-#############################""",
-]
+---Output---
+(entity{tuple_delimiter}World Athletics Championship{tuple_delimiter}event{tuple_delimiter}The World Athletics Championship is a global sports competition featuring top athletes in track and field.){record_delimiter}
+(entity{tuple_delimiter}Tokyo{tuple_delimiter}location{tuple_delimiter}Tokyo is the host city of the World Athletics Championship.){record_delimiter}
+(entity{tuple_delimiter}Noah Carter{tuple_delimiter}person{tuple_delimiter}Noah Carter is a sprinter who set a new record in the 100m sprint at the World Athletics Championship.){record_delimiter}
+(entity{tuple_delimiter}100m Sprint Record{tuple_delimiter}category{tuple_delimiter}The 100m sprint record is a benchmark in athletics, recently broken by Noah Carter.){record_delimiter}
+(entity{tuple_delimiter}Carbon-Fiber Spikes{tuple_delimiter}equipment{tuple_delimiter}Carbon-fiber spikes are advanced sprinting shoes that provide enhanced speed and traction.){record_delimiter}
+(entity{tuple_delimiter}World Athletics Federation{tuple_delimiter}organization{tuple_delimiter}The World Athletics Federation is the governing body overseeing the World Athletics Championship and record validations.){record_delimiter}
+(relationship{tuple_delimiter}World Athletics Championship{tuple_delimiter}Tokyo{tuple_delimiter}event location, international competition{tuple_delimiter}The World Athletics Championship is being hosted in Tokyo.){record_delimiter}
+(relationship{tuple_delimiter}Noah Carter{tuple_delimiter}100m Sprint Record{tuple_delimiter}athlete achievement, record-breaking{tuple_delimiter}Noah Carter set a new 100m sprint record at the championship.){record_delimiter}
+(relationship{tuple_delimiter}Noah Carter{tuple_delimiter}Carbon-Fiber Spikes{tuple_delimiter}athletic equipment, performance boost{tuple_delimiter}Noah Carter used carbon-fiber spikes to enhance performance during the race.){record_delimiter}
+(relationship{tuple_delimiter}Noah Carter{tuple_delimiter}World Athletics Championship{tuple_delimiter}athlete participation, competition{tuple_delimiter}Noah Carter is competing at the World Athletics Championship.){record_delimiter}
+{completion_delimiter}
-PROMPTS[
- "summarize_entity_descriptions"
-] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
-Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
-Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
-If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
-Make sure it is written in third person, and include the entity names so we the have full context.
-Use {language} as output language.
+""",
+ """[Example 4]
-#######
----Data---
-Entities: {entity_name}
-Description List: {description_list}
-#######
-Output:
-"""
-
-PROMPTS["entity_continue_extraction"] = """
-MANY entities and relationships were missed in the last extraction. Please find only the missing entities and relationships from previous text.
-
----Remember Steps---
-
-1. Identify all entities. For each identified entity, extract the following information:
-- entity_name: Name of the entity, use same language as input text. If English, capitalized the name
-- entity_type: One of the following types: [{entity_types}]
-- entity_description: Provide a comprehensive description of the entity's attributes and activities *based solely on the information present in the input text*. **Do not infer or hallucinate information not explicitly stated.** If the text provides insufficient information to create a comprehensive description, state "Description not available in text."
-Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter})
-
-2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
-For each pair of related entities, extract the following information:
-- source_entity: name of the source entity, as identified in step 1
-- target_entity: name of the target entity, as identified in step 1
-- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
-- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
-- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
-Format each relationship as ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter})
-
-3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
-Format the content-level key words as ("content_keywords"{tuple_delimiter})
-
-4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
-
-5. When finished, output {completion_delimiter}
+---Input---
+Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
+Text:
+```
+在北京举行的人工智能大会上,腾讯公司的首席技术官张伟发布了最新的大语言模型"腾讯智言",该模型在自然语言处理方面取得了重大突破。
+```
---Output---
+(entity{tuple_delimiter}人工智能大会{tuple_delimiter}event{tuple_delimiter}人工智能大会是在北京举行的技术会议,专注于人工智能领域的最新发展。){record_delimiter}
+(entity{tuple_delimiter}北京{tuple_delimiter}location{tuple_delimiter}北京是人工智能大会的举办城市。){record_delimiter}
+(entity{tuple_delimiter}腾讯公司{tuple_delimiter}organization{tuple_delimiter}腾讯公司是参与人工智能大会的科技企业,发布了新的语言模型产品。){record_delimiter}
+(entity{tuple_delimiter}张伟{tuple_delimiter}person{tuple_delimiter}张伟是腾讯公司的首席技术官,在大会上发布了新产品。){record_delimiter}
+(entity{tuple_delimiter}腾讯智言{tuple_delimiter}product{tuple_delimiter}腾讯智言是腾讯公司发布的大语言模型产品,在自然语言处理方面有重大突破。){record_delimiter}
+(entity{tuple_delimiter}自然语言处理技术{tuple_delimiter}technology{tuple_delimiter}自然语言处理技术是腾讯智言模型取得重大突破的技术领域。){record_delimiter}
+(relationship{tuple_delimiter}人工智能大会{tuple_delimiter}北京{tuple_delimiter}会议地点, 举办关系{tuple_delimiter}人工智能大会在北京举行。){record_delimiter}
+(relationship{tuple_delimiter}张伟{tuple_delimiter}腾讯公司{tuple_delimiter}雇佣关系, 高管职位{tuple_delimiter}张伟担任腾讯公司的首席技术官。){record_delimiter}
+(relationship{tuple_delimiter}张伟{tuple_delimiter}腾讯智言{tuple_delimiter}产品发布, 技术展示{tuple_delimiter}张伟在大会上发布了腾讯智言大语言模型。){record_delimiter}
+(relationship{tuple_delimiter}腾讯智言{tuple_delimiter}自然语言处理技术{tuple_delimiter}技术应用, 突破创新{tuple_delimiter}腾讯智言在自然语言处理技术方面取得了重大突破。){record_delimiter}
+{completion_delimiter}
-Add new entities and relations below using the same format, and do not include entities and relations that have been previously extracted. :\n
-""".strip()
+""",
+]
+PROMPTS["summarize_entity_descriptions"] = """---Role---
+You are a Knowledge Graph Specialist responsible for data curation and synthesis.
+
+---Task---
+Your task is to synthesize a list of descriptions of a given entity or relation into a single, comprehensive, and cohesive summary.
+
+---Instructions---
+1. **Comprehensiveness:** The summary must integrate key information from all provided descriptions. Do not omit important facts.
+2. **Context:** The summary must explicitly mention the name of the entity or relation for full context.
+3. **Conflict:** In case of conflicting or inconsistent descriptions, determine if they originate from multiple, distinct entities or relationships that share the same name. If so, summarize each entity or relationship separately and then consolidate all summaries.
+4. **Style:** The output must be written from an objective, third-person perspective.
+5. **Length:** Maintain depth and completeness while ensuring the summary's length not exceed {summary_length} tokens.
+6. **Language:** The entire output must be written in {language}.
+
+---Data---
+{description_type} Name: {description_name}
+Description List:
+{description_list}
+
+---Output---
+"""
+
+PROMPTS["entity_continue_extraction"] = """---Task---
+Identify any missed entities or relationships in the last extraction task.
+
+---Instructions---
+1. Output the entities and realtionships in the same format as previous extraction task.
+2. Do not include entities and relations that have been previously extracted.
+3. If the entity doesn't clearly fit in any of`Entity_types` provided, classify it as "Other".
+4. Return identified entities and relationships in {language}.
+5. Output `{completion_delimiter}` when all the entities and relationships are extracted.
+
+---Output---
+"""
+
+# TODO: Deprecated
PROMPTS["entity_if_loop_extraction"] = """
---Goal---'
-It appears some entities may have still been missed.
+Check if it appears some entities may have still been missed. Output "Yes" if so, otherwise "No".
---Output---
-
-Answer ONLY by `YES` OR `NO` if there are still entities that need to be added.
-""".strip()
+Output:"""
PROMPTS["fail_response"] = (
"Sorry, I'm not able to provide an answer to that question.[no-context]"
@@ -211,7 +222,7 @@ Generate a concise response based on Knowledge Base and follow Response Rules, c
---Knowledge Graph and Document Chunks---
{context_data}
----RESPONSE GUIDELINES---
+---Response Guidelines---
**1. Content & Adherence:**
- Strictly adhere to the provided context from the Knowledge Base. Do not invent, assume, or include any information not present in the source data.
- If the answer cannot be found in the provided context, state that you do not have enough information to answer.
@@ -233,8 +244,8 @@ Generate a concise response based on Knowledge Base and follow Response Rules, c
---USER CONTEXT---
- Additional user prompt: {user_prompt}
-
-Response:"""
+---Response---
+"""
PROMPTS["keywords_extraction"] = """---Role---
You are an expert keyword extractor, specializing in analyzing user queries for a Retrieval-Augmented Generation (RAG) system. Your purpose is to identify both high-level and low-level keywords in the user's query that will be used for effective document retrieval.
@@ -257,7 +268,7 @@ Given a user query, your task is to extract two distinct types of keywords:
User Query: {query}
---Output---
-"""
+Output:"""
PROMPTS["keywords_extraction_examples"] = [
"""Example 1:
@@ -327,5 +338,5 @@ Generate a concise response based on Document Chunks and follow Response Rules,
---USER CONTEXT---
- Additional user prompt: {user_prompt}
-
-Response:"""
+---Response---
+Output:"""
diff --git a/lightrag/rerank.py b/lightrag/rerank.py
index 5ed1ca68..35551f5a 100644
--- a/lightrag/rerank.py
+++ b/lightrag/rerank.py
@@ -2,270 +2,199 @@ from __future__ import annotations
import os
import aiohttp
-from typing import Callable, Any, List, Dict, Optional
-from pydantic import BaseModel, Field
-
+from typing import Any, List, Dict, Optional
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
from .utils import logger
+from dotenv import load_dotenv
-class RerankModel(BaseModel):
- """
- Wrapper for rerank functions that can be used with LightRAG.
-
- Example usage:
- ```python
- from lightrag.rerank import RerankModel, jina_rerank
-
- # Create rerank model
- rerank_model = RerankModel(
- rerank_func=jina_rerank,
- kwargs={
- "model": "BAAI/bge-reranker-v2-m3",
- "api_key": "your_api_key_here",
- "base_url": "https://api.jina.ai/v1/rerank"
- }
- )
-
- # Use in LightRAG
- rag = LightRAG(
- rerank_model_func=rerank_model.rerank,
- # ... other configurations
- )
-
- # Query with rerank enabled (default)
- result = await rag.aquery(
- "your query",
- param=QueryParam(enable_rerank=True)
- )
- ```
-
- Or define a custom function directly:
- ```python
- async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
- return await jina_rerank(
- query=query,
- documents=documents,
- model="BAAI/bge-reranker-v2-m3",
- api_key="your_api_key_here",
- top_n=top_n or 10,
- **kwargs
- )
-
- rag = LightRAG(
- rerank_model_func=my_rerank_func,
- # ... other configurations
- )
-
- # Control rerank per query
- result = await rag.aquery(
- "your query",
- param=QueryParam(enable_rerank=True) # Enable rerank for this query
- )
- ```
- """
-
- rerank_func: Callable[[Any], List[Dict]]
- kwargs: Dict[str, Any] = Field(default_factory=dict)
-
- async def rerank(
- self,
- query: str,
- documents: List[Dict[str, Any]],
- top_n: Optional[int] = None,
- **extra_kwargs,
- ) -> List[Dict[str, Any]]:
- """Rerank documents using the configured model function."""
- # Merge extra kwargs with model kwargs
- kwargs = {**self.kwargs, **extra_kwargs}
- return await self.rerank_func(
- query=query, documents=documents, top_n=top_n, **kwargs
- )
-
-
-class MultiRerankModel(BaseModel):
- """Multiple rerank models for different modes/scenarios."""
-
- # Primary rerank model (used if mode-specific models are not defined)
- rerank_model: Optional[RerankModel] = None
-
- # Mode-specific rerank models
- entity_rerank_model: Optional[RerankModel] = None
- relation_rerank_model: Optional[RerankModel] = None
- chunk_rerank_model: Optional[RerankModel] = None
-
- async def rerank(
- self,
- query: str,
- documents: List[Dict[str, Any]],
- mode: str = "default",
- top_n: Optional[int] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """Rerank using the appropriate model based on mode."""
-
- # Select model based on mode
- if mode == "entity" and self.entity_rerank_model:
- model = self.entity_rerank_model
- elif mode == "relation" and self.relation_rerank_model:
- model = self.relation_rerank_model
- elif mode == "chunk" and self.chunk_rerank_model:
- model = self.chunk_rerank_model
- elif self.rerank_model:
- model = self.rerank_model
- else:
- logger.warning(f"No rerank model available for mode: {mode}")
- return documents
-
- return await model.rerank(query, documents, top_n, **kwargs)
+# use the .env that is inside the current folder
+# allows to use different .env file for each lightrag instance
+# the OS environment variables take precedence over the .env file
+load_dotenv(dotenv_path=".env", override=False)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=(
+ retry_if_exception_type(aiohttp.ClientError)
+ | retry_if_exception_type(aiohttp.ClientResponseError)
+ ),
+)
async def generic_rerank_api(
query: str,
- documents: List[Dict[str, Any]],
+ documents: List[str],
model: str,
base_url: str,
- api_key: str,
+ api_key: Optional[str],
top_n: Optional[int] = None,
- **kwargs,
+ return_documents: Optional[bool] = None,
+ extra_body: Optional[Dict[str, Any]] = None,
+ response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
+ request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
) -> List[Dict[str, Any]]:
"""
- Generic rerank function that works with Jina/Cohere compatible APIs.
+ Generic rerank API call for Jina/Cohere/Aliyun models.
Args:
query: The search query
- documents: List of documents to rerank
- model: Model identifier
+ documents: List of strings to rerank
+ model: Model name to use
base_url: API endpoint URL
- api_key: API authentication key
+ api_key: API key for authentication
top_n: Number of top results to return
- **kwargs: Additional API-specific parameters
+ return_documents: Whether to return document text (Jina only)
+ extra_body: Additional body parameters
+ response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
Returns:
- List of reranked documents with relevance scores
+ List of dictionary of ["index": int, "relevance_score": float]
"""
- if not api_key:
- logger.warning("No API key provided for rerank service")
- return documents
+ if not base_url:
+ raise ValueError("Base URL is required")
- if not documents:
- return documents
+ headers = {"Content-Type": "application/json"}
+ if api_key is not None:
+ headers["Authorization"] = f"Bearer {api_key}"
- # Prepare documents for reranking - handle both text and dict formats
- prepared_docs = []
- for doc in documents:
- if isinstance(doc, dict):
- # Use 'content' field if available, otherwise use 'text' or convert to string
- text = doc.get("content") or doc.get("text") or str(doc)
- else:
- text = str(doc)
- prepared_docs.append(text)
+ # Build request payload based on request format
+ if request_format == "aliyun":
+ # Aliyun format: nested input/parameters structure
+ payload = {
+ "model": model,
+ "input": {
+ "query": query,
+ "documents": documents,
+ },
+ "parameters": {},
+ }
- # Prepare request
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+ # Add optional parameters to parameters object
+ if top_n is not None:
+ payload["parameters"]["top_n"] = top_n
- data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
+ if return_documents is not None:
+ payload["parameters"]["return_documents"] = return_documents
- if top_n is not None:
- data["top_n"] = min(top_n, len(prepared_docs))
+ # Add extra parameters to parameters object
+ if extra_body:
+ payload["parameters"].update(extra_body)
+ else:
+ # Standard format for Jina/Cohere
+ payload = {
+ "model": model,
+ "query": query,
+ "documents": documents,
+ }
- try:
- async with aiohttp.ClientSession() as session:
- async with session.post(base_url, headers=headers, json=data) as response:
- if response.status != 200:
- error_text = await response.text()
- logger.error(f"Rerank API error {response.status}: {error_text}")
- return documents
+ # Add optional parameters
+ if top_n is not None:
+ payload["top_n"] = top_n
- result = await response.json()
+ # Only Jina API supports return_documents parameter
+ if return_documents is not None:
+ payload["return_documents"] = return_documents
- # Extract reranked results
- if "results" in result:
- # Standard format: results contain index and relevance_score
- reranked_docs = []
- for item in result["results"]:
- if "index" in item:
- doc_idx = item["index"]
- if 0 <= doc_idx < len(documents):
- reranked_doc = documents[doc_idx].copy()
- if "relevance_score" in item:
- reranked_doc["rerank_score"] = item[
- "relevance_score"
- ]
- reranked_docs.append(reranked_doc)
- return reranked_docs
- else:
- logger.warning("Unexpected rerank API response format")
- return documents
+ # Add extra parameters
+ if extra_body:
+ payload.update(extra_body)
- except Exception as e:
- logger.error(f"Error during reranking: {e}")
- return documents
-
-
-async def jina_rerank(
- query: str,
- documents: List[Dict[str, Any]],
- model: str = "BAAI/bge-reranker-v2-m3",
- top_n: Optional[int] = None,
- base_url: str = "https://api.jina.ai/v1/rerank",
- api_key: Optional[str] = None,
- **kwargs,
-) -> List[Dict[str, Any]]:
- """
- Rerank documents using Jina AI API.
-
- Args:
- query: The search query
- documents: List of documents to rerank
- model: Jina rerank model name
- top_n: Number of top results to return
- base_url: Jina API endpoint
- api_key: Jina API key
- **kwargs: Additional parameters
-
- Returns:
- List of reranked documents with relevance scores
- """
- if api_key is None:
- api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
-
- return await generic_rerank_api(
- query=query,
- documents=documents,
- model=model,
- base_url=base_url,
- api_key=api_key,
- top_n=top_n,
- **kwargs,
+ logger.debug(
+ f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
)
+ async with aiohttp.ClientSession() as session:
+ async with session.post(base_url, headers=headers, json=payload) as response:
+ if response.status != 200:
+ error_text = await response.text()
+ content_type = response.headers.get("content-type", "").lower()
+ is_html_error = (
+ error_text.strip().startswith("")
+ or "text/html" in content_type
+ )
+ if is_html_error:
+ if response.status == 502:
+ clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
+ elif response.status == 503:
+ clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
+ elif response.status == 504:
+ clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
+ else:
+ clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
+ else:
+ clean_error = error_text
+ logger.error(f"Rerank API error {response.status}: {clean_error}")
+ raise aiohttp.ClientResponseError(
+ request_info=response.request_info,
+ history=response.history,
+ status=response.status,
+ message=f"Rerank API error: {clean_error}",
+ )
+
+ response_json = await response.json()
+
+ if response_format == "aliyun":
+ # Aliyun format: {"output": {"results": [...]}}
+ results = response_json.get("output", {}).get("results", [])
+ if not isinstance(results, list):
+ logger.warning(
+ f"Expected 'output.results' to be list, got {type(results)}: {results}"
+ )
+ results = []
+
+ elif response_format == "standard":
+ # Standard format: {"results": [...]}
+ results = response_json.get("results", [])
+ if not isinstance(results, list):
+ logger.warning(
+ f"Expected 'results' to be list, got {type(results)}: {results}"
+ )
+ results = []
+ else:
+ raise ValueError(f"Unsupported response format: {response_format}")
+ if not results:
+ logger.warning("Rerank API returned empty results")
+ return []
+
+ # Standardize return format
+ return [
+ {"index": result["index"], "relevance_score": result["relevance_score"]}
+ for result in results
+ ]
+
async def cohere_rerank(
query: str,
- documents: List[Dict[str, Any]],
- model: str = "rerank-english-v2.0",
+ documents: List[str],
top_n: Optional[int] = None,
- base_url: str = "https://api.cohere.ai/v1/rerank",
api_key: Optional[str] = None,
- **kwargs,
+ model: str = "rerank-v3.5",
+ base_url: str = "https://api.cohere.com/v2/rerank",
+ extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Cohere API.
Args:
query: The search query
- documents: List of documents to rerank
- model: Cohere rerank model name
+ documents: List of strings to rerank
top_n: Number of top results to return
- base_url: Cohere API endpoint
- api_key: Cohere API key
- **kwargs: Additional parameters
+ api_key: API key
+ model: rerank model name
+ base_url: API endpoint
+ extra_body: Additional body for http request(reserved for extra params)
Returns:
- List of reranked documents with relevance scores
+ List of dictionary of ["index": int, "relevance_score": float]
"""
if api_key is None:
- api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
+ api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api(
query=query,
@@ -274,24 +203,39 @@ async def cohere_rerank(
base_url=base_url,
api_key=api_key,
top_n=top_n,
- **kwargs,
+ return_documents=None, # Cohere doesn't support this parameter
+ extra_body=extra_body,
+ response_format="standard",
)
-# Convenience function for custom API endpoints
-async def custom_rerank(
+async def jina_rerank(
query: str,
- documents: List[Dict[str, Any]],
- model: str,
- base_url: str,
- api_key: str,
+ documents: List[str],
top_n: Optional[int] = None,
- **kwargs,
+ api_key: Optional[str] = None,
+ model: str = "jina-reranker-v2-base-multilingual",
+ base_url: str = "https://api.jina.ai/v1/rerank",
+ extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
- Rerank documents using a custom API endpoint.
- This is useful for self-hosted or custom rerank services.
+ Rerank documents using Jina AI API.
+
+ Args:
+ query: The search query
+ documents: List of strings to rerank
+ top_n: Number of top results to return
+ api_key: API key
+ model: rerank model name
+ base_url: API endpoint
+ extra_body: Additional body for http request(reserved for extra params)
+
+ Returns:
+ List of dictionary of ["index": int, "relevance_score": float]
"""
+ if api_key is None:
+ api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
+
return await generic_rerank_api(
query=query,
documents=documents,
@@ -299,26 +243,112 @@ async def custom_rerank(
base_url=base_url,
api_key=api_key,
top_n=top_n,
- **kwargs,
+ return_documents=False,
+ extra_body=extra_body,
+ response_format="standard",
)
+async def ali_rerank(
+ query: str,
+ documents: List[str],
+ top_n: Optional[int] = None,
+ api_key: Optional[str] = None,
+ model: str = "gte-rerank-v2",
+ base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
+ extra_body: Optional[Dict[str, Any]] = None,
+) -> List[Dict[str, Any]]:
+ """
+ Rerank documents using Aliyun DashScope API.
+
+ Args:
+ query: The search query
+ documents: List of strings to rerank
+ top_n: Number of top results to return
+ api_key: Aliyun API key
+ model: rerank model name
+ base_url: API endpoint
+ extra_body: Additional body for http request(reserved for extra params)
+
+ Returns:
+ List of dictionary of ["index": int, "relevance_score": float]
+ """
+ if api_key is None:
+ api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
+
+ return await generic_rerank_api(
+ query=query,
+ documents=documents,
+ model=model,
+ base_url=base_url,
+ api_key=api_key,
+ top_n=top_n,
+ return_documents=False, # Aliyun doesn't need this parameter
+ extra_body=extra_body,
+ response_format="aliyun",
+ request_format="aliyun",
+ )
+
+
+"""Please run this test as a module:
+python -m lightrag.rerank
+"""
if __name__ == "__main__":
import asyncio
async def main():
- # Example usage
+ # Example usage - documents should be strings, not dictionaries
docs = [
- {"content": "The capital of France is Paris."},
- {"content": "Tokyo is the capital of Japan."},
- {"content": "London is the capital of England."},
+ "The capital of France is Paris.",
+ "Tokyo is the capital of Japan.",
+ "London is the capital of England.",
]
query = "What is the capital of France?"
- result = await jina_rerank(
- query=query, documents=docs, top_n=2, api_key="your-api-key-here"
- )
- print(result)
+ # Test Jina rerank
+ try:
+ print("=== Jina Rerank ===")
+ result = await jina_rerank(
+ query=query,
+ documents=docs,
+ top_n=2,
+ )
+ print("Results:")
+ for item in result:
+ print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
+ print(f"Document: {docs[item['index']]}")
+ except Exception as e:
+ print(f"Jina Error: {e}")
+
+ # Test Cohere rerank
+ try:
+ print("\n=== Cohere Rerank ===")
+ result = await cohere_rerank(
+ query=query,
+ documents=docs,
+ top_n=2,
+ )
+ print("Results:")
+ for item in result:
+ print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
+ print(f"Document: {docs[item['index']]}")
+ except Exception as e:
+ print(f"Cohere Error: {e}")
+
+ # Test Aliyun rerank
+ try:
+ print("\n=== Aliyun Rerank ===")
+ result = await ali_rerank(
+ query=query,
+ documents=docs,
+ top_n=2,
+ )
+ print("Results:")
+ for item in result:
+ print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
+ print(f"Document: {docs[item['index']]}")
+ except Exception as e:
+ print(f"Aliyun Error: {e}")
asyncio.run(main())
diff --git a/lightrag/tools/check_initialization.py b/lightrag/tools/check_initialization.py
new file mode 100644
index 00000000..6bcb17e3
--- /dev/null
+++ b/lightrag/tools/check_initialization.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+"""
+Diagnostic tool to check LightRAG initialization status.
+
+This tool helps developers verify that their LightRAG instance is properly
+initialized before use, preventing common initialization errors.
+
+Usage:
+ python -m lightrag.tools.check_initialization
+"""
+
+import asyncio
+import sys
+from pathlib import Path
+
+# Add parent directory to path for imports
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from lightrag import LightRAG
+from lightrag.base import StoragesStatus
+
+
+async def check_lightrag_setup(rag_instance: LightRAG, verbose: bool = False) -> bool:
+ """
+ Check if a LightRAG instance is properly initialized.
+
+ Args:
+ rag_instance: The LightRAG instance to check
+ verbose: If True, print detailed diagnostic information
+
+ Returns:
+ True if properly initialized, False otherwise
+ """
+ issues = []
+ warnings = []
+
+ print("🔍 Checking LightRAG initialization status...\n")
+
+ # Check storage initialization status
+ if not hasattr(rag_instance, "_storages_status"):
+ issues.append("LightRAG instance missing _storages_status attribute")
+ elif rag_instance._storages_status != StoragesStatus.INITIALIZED:
+ issues.append(
+ f"Storages not initialized (status: {rag_instance._storages_status.name})"
+ )
+ else:
+ print("✅ Storage status: INITIALIZED")
+
+ # Check individual storage components
+ storage_components = [
+ ("full_docs", "Document storage"),
+ ("text_chunks", "Text chunks storage"),
+ ("entities_vdb", "Entity vector database"),
+ ("relationships_vdb", "Relationship vector database"),
+ ("chunks_vdb", "Chunks vector database"),
+ ("doc_status", "Document status tracker"),
+ ("llm_response_cache", "LLM response cache"),
+ ("full_entities", "Entity storage"),
+ ("full_relations", "Relation storage"),
+ ("chunk_entity_relation_graph", "Graph storage"),
+ ]
+
+ if verbose:
+ print("\n📦 Storage Components:")
+
+ for component, description in storage_components:
+ if not hasattr(rag_instance, component):
+ issues.append(f"Missing storage component: {component} ({description})")
+ else:
+ storage = getattr(rag_instance, component)
+ if storage is None:
+ warnings.append(f"Storage {component} is None (might be optional)")
+ elif hasattr(storage, "_storage_lock"):
+ if storage._storage_lock is None:
+ issues.append(f"Storage {component} not initialized (lock is None)")
+ elif verbose:
+ print(f" ✅ {description}: Ready")
+ elif verbose:
+ print(f" ✅ {description}: Ready")
+
+ # Check pipeline status
+ try:
+ from lightrag.kg.shared_storage import get_namespace_data
+
+ get_namespace_data("pipeline_status")
+ print("✅ Pipeline status: INITIALIZED")
+ except KeyError:
+ issues.append(
+ "Pipeline status not initialized - call initialize_pipeline_status()"
+ )
+ except Exception as e:
+ issues.append(f"Error checking pipeline status: {str(e)}")
+
+ # Print results
+ print("\n" + "=" * 50)
+
+ if issues:
+ print("❌ Issues found:\n")
+ for issue in issues:
+ print(f" • {issue}")
+
+ print("\n📝 To fix, run this initialization sequence:\n")
+ print(" await rag.initialize_storages()")
+ print(" from lightrag.kg.shared_storage import initialize_pipeline_status")
+ print(" await initialize_pipeline_status()")
+ print(
+ "\n📚 Documentation: https://github.com/HKUDS/LightRAG#important-initialization-requirements"
+ )
+
+ if warnings and verbose:
+ print("\n⚠️ Warnings (might be normal):")
+ for warning in warnings:
+ print(f" • {warning}")
+
+ return False
+ else:
+ print("✅ LightRAG is properly initialized and ready to use!")
+
+ if warnings and verbose:
+ print("\n⚠️ Warnings (might be normal):")
+ for warning in warnings:
+ print(f" • {warning}")
+
+ return True
+
+
+async def demo():
+ """Demonstrate the diagnostic tool with a test instance."""
+ from lightrag.llm.openai import openai_embed, gpt_4o_mini_complete
+ from lightrag.kg.shared_storage import initialize_pipeline_status
+
+ print("=" * 50)
+ print("LightRAG Initialization Diagnostic Tool")
+ print("=" * 50)
+
+ # Create test instance
+ rag = LightRAG(
+ working_dir="./test_diagnostic",
+ embedding_func=openai_embed,
+ llm_model_func=gpt_4o_mini_complete,
+ )
+
+ print("\n🔴 BEFORE initialization:\n")
+ await check_lightrag_setup(rag, verbose=True)
+
+ print("\n" + "=" * 50)
+ print("\n🔄 Initializing...\n")
+ await rag.initialize_storages()
+ await initialize_pipeline_status()
+
+ print("\n🟢 AFTER initialization:\n")
+ await check_lightrag_setup(rag, verbose=True)
+
+ # Cleanup
+ import shutil
+
+ shutil.rmtree("./test_diagnostic", ignore_errors=True)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Check LightRAG initialization status")
+ parser.add_argument(
+ "--demo", action="store_true", help="Run a demonstration with a test instance"
+ )
+ parser.add_argument(
+ "--verbose",
+ "-v",
+ action="store_true",
+ help="Show detailed diagnostic information",
+ )
+
+ args = parser.parse_args()
+
+ if args.demo:
+ asyncio.run(demo())
+ else:
+ print("Run with --demo to see the diagnostic tool in action")
+ print("Or import this module and use check_lightrag_setup() with your instance")
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 979517b5..8ebdf9a2 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -14,7 +14,7 @@ from dataclasses import dataclass
from datetime import datetime
from functools import wraps
from hashlib import md5
-from typing import Any, Protocol, Callable, TYPE_CHECKING, List
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional
import numpy as np
from dotenv import load_dotenv
@@ -27,22 +27,80 @@ from lightrag.constants import (
DEFAULT_MAX_FILE_PATH_LENGTH,
)
+# Initialize logger with basic configuration
+logger = logging.getLogger("lightrag")
+logger.propagate = False # prevent log message send to root logger
+logger.setLevel(logging.INFO)
+
+# Add console handler if no handlers exist
+if not logger.handlers:
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging.INFO)
+ formatter = logging.Formatter("%(levelname)s: %(message)s")
+ console_handler.setFormatter(formatter)
+ logger.addHandler(console_handler)
+
+# Set httpx logging level to WARNING
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
# Global import for pypinyin with startup-time logging
try:
import pypinyin
_PYPINYIN_AVAILABLE = True
- logger = logging.getLogger("lightrag")
- logger.info("pypinyin loaded successfully for Chinese pinyin sorting")
+ # logger.info("pypinyin loaded successfully for Chinese pinyin sorting")
except ImportError:
pypinyin = None
_PYPINYIN_AVAILABLE = False
- logger = logging.getLogger("lightrag")
logger.warning(
"pypinyin is not installed. Chinese pinyin sorting will use simple string sorting."
)
+async def safe_vdb_operation_with_exception(
+ operation: Callable,
+ operation_name: str,
+ entity_name: str = "",
+ max_retries: int = 3,
+ retry_delay: float = 0.2,
+ logger_func: Optional[Callable] = None,
+) -> None:
+ """
+ Safely execute vector database operations with retry mechanism and exception handling.
+
+ This function ensures that VDB operations are executed with proper error handling
+ and retry logic. If all retries fail, it raises an exception to maintain data consistency.
+
+ Args:
+ operation: The async operation to execute
+ operation_name: Operation name for logging purposes
+ entity_name: Entity name for logging purposes
+ max_retries: Maximum number of retry attempts
+ retry_delay: Delay between retries in seconds
+ logger_func: Logger function to use for error messages
+
+ Raises:
+ Exception: When operation fails after all retry attempts
+ """
+ log_func = logger_func or logger.warning
+
+ for attempt in range(max_retries):
+ try:
+ await operation()
+ return # Success, return immediately
+ except Exception as e:
+ if attempt >= max_retries - 1:
+ error_msg = f"VDB {operation_name} failed for {entity_name} after {max_retries} attempts: {e}"
+ log_func(error_msg)
+ raise Exception(error_msg) from e
+ else:
+ log_func(
+ f"VDB {operation_name} attempt {attempt + 1} failed for {entity_name}: {e}, retrying..."
+ )
+ if retry_delay > 0:
+ await asyncio.sleep(retry_delay)
+
+
def get_env_value(
env_key: str, default: any, value_type: type = str, special_none: bool = False
) -> any:
@@ -68,6 +126,27 @@ def get_env_value(
if value_type is bool:
return value.lower() in ("true", "1", "yes", "t", "on")
+
+ # Handle list type with JSON parsing
+ if value_type is list:
+ try:
+ import json
+
+ parsed_value = json.loads(value)
+ # Ensure the parsed value is actually a list
+ if isinstance(parsed_value, list):
+ return parsed_value
+ else:
+ logger.warning(
+ f"Environment variable {env_key} is not a valid JSON list, using default"
+ )
+ return default
+ except (json.JSONDecodeError, ValueError) as e:
+ logger.warning(
+ f"Failed to parse {env_key} as JSON list: {e}, using default"
+ )
+ return default
+
try:
return value_type(value)
except (ValueError, TypeError):
@@ -121,15 +200,6 @@ def set_verbose_debug(enabled: bool):
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
-# Initialize logger
-logger = logging.getLogger("lightrag")
-logger.propagate = False # prevent log message send to root loggger
-# Let the main application configure the handlers
-logger.setLevel(logging.INFO)
-
-# Set httpx logging level to WARNING
-logging.getLogger("httpx").setLevel(logging.WARNING)
-
class LightragPathFilter(logging.Filter):
"""Filter for lightrag logger to filter out frequent path access logs"""
@@ -254,6 +324,18 @@ class UnlimitedSemaphore:
pass
+@dataclass
+class TaskState:
+ """Task state tracking for priority queue management"""
+
+ future: asyncio.Future
+ start_time: float
+ execution_start_time: float = None
+ worker_started: bool = False
+ cancellation_requested: bool = False
+ cleanup_done: bool = False
+
+
@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -323,20 +405,60 @@ def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
return None
-# Custom exception class
+# Custom exception classes
class QueueFullError(Exception):
"""Raised when the queue is full and the wait times out"""
pass
-def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
+class WorkerTimeoutError(Exception):
+ """Worker-level timeout exception with specific timeout information"""
+
+ def __init__(self, timeout_value: float, timeout_type: str = "execution"):
+ self.timeout_value = timeout_value
+ self.timeout_type = timeout_type
+ super().__init__(f"Worker {timeout_type} timeout after {timeout_value}s")
+
+
+class HealthCheckTimeoutError(Exception):
+ """Health Check-level timeout exception"""
+
+ def __init__(self, timeout_value: float, execution_duration: float):
+ self.timeout_value = timeout_value
+ self.execution_duration = execution_duration
+ super().__init__(
+ f"Task forcefully terminated due to execution timeout (>{timeout_value}s, actual: {execution_duration:.1f}s)"
+ )
+
+
+def priority_limit_async_func_call(
+ max_size: int,
+ llm_timeout: float = None,
+ max_execution_timeout: float = None,
+ max_task_duration: float = None,
+ max_queue_size: int = 1000,
+ cleanup_timeout: float = 2.0,
+ queue_name: str = "limit_async",
+):
"""
- Enhanced priority-limited asynchronous function call decorator
+ Enhanced priority-limited asynchronous function call decorator with robust timeout handling
+
+ This decorator provides a comprehensive solution for managing concurrent LLM requests with:
+ - Multi-layer timeout protection (LLM -> Worker -> Health Check -> User)
+ - Task state tracking to prevent race conditions
+ - Enhanced health check system with stuck task detection
+ - Proper resource cleanup and error recovery
Args:
max_size: Maximum number of concurrent calls
max_queue_size: Maximum queue capacity to prevent memory overflow
+ llm_timeout: LLM provider timeout (from global config), used to calculate other timeouts
+ max_execution_timeout: Maximum time for worker to execute function (defaults to llm_timeout + 30s)
+ max_task_duration: Maximum time before health check intervenes (defaults to llm_timeout + 60s)
+ cleanup_timeout: Maximum time to wait for cleanup operations (defaults to 2.0s)
+ queue_name: Optional queue name for logging identification (defaults to "limit_async")
+
Returns:
Decorator function
"""
@@ -345,108 +467,197 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
# Ensure func is callable
if not callable(func):
raise TypeError(f"Expected a callable object, got {type(func)}")
+
+ # Calculate timeout hierarchy if llm_timeout is provided (Dynamic Timeout Calculation)
+ if llm_timeout is not None:
+ nonlocal max_execution_timeout, max_task_duration
+ if max_execution_timeout is None:
+ max_execution_timeout = (
+ llm_timeout + 30
+ ) # LLM timeout + 30s buffer for network delays
+ if max_task_duration is None:
+ max_task_duration = (
+ llm_timeout + 60
+ ) # LLM timeout + 1min buffer for execution phase
+
queue = asyncio.PriorityQueue(maxsize=max_queue_size)
tasks = set()
initialization_lock = asyncio.Lock()
counter = 0
shutdown_event = asyncio.Event()
- initialized = False # Global initialization flag
+ initialized = False
worker_health_check_task = None
- # Track active future objects for cleanup
+ # Enhanced task state management
+ task_states = {} # task_id -> TaskState
+ task_states_lock = asyncio.Lock()
active_futures = weakref.WeakSet()
- reinit_count = 0 # Reinitialization counter to track system health
+ reinit_count = 0
- # Worker function to process tasks in the queue
async def worker():
- """Worker that processes tasks in the priority queue"""
+ """Enhanced worker that processes tasks with proper timeout and state management"""
try:
while not shutdown_event.is_set():
try:
- # Use timeout to get tasks, allowing periodic checking of shutdown signal
+ # Get task from queue with timeout for shutdown checking
try:
(
priority,
count,
- future,
+ task_id,
args,
kwargs,
) = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError:
- # Timeout is just to check shutdown signal, continue to next iteration
continue
- # If future is cancelled, skip execution
- if future.cancelled():
+ # Get task state and mark worker as started
+ async with task_states_lock:
+ if task_id not in task_states:
+ queue.task_done()
+ continue
+ task_state = task_states[task_id]
+ task_state.worker_started = True
+ # Record execution start time when worker actually begins processing
+ task_state.execution_start_time = (
+ asyncio.get_event_loop().time()
+ )
+
+ # Check if task was cancelled before worker started
+ if (
+ task_state.cancellation_requested
+ or task_state.future.cancelled()
+ ):
+ async with task_states_lock:
+ task_states.pop(task_id, None)
queue.task_done()
continue
try:
- # Execute function
- result = await func(*args, **kwargs)
- # If future is not done, set the result
- if not future.done():
- future.set_result(result)
- except asyncio.CancelledError:
- if not future.done():
- future.cancel()
- logger.debug("limit_async: Task cancelled during execution")
- except Exception as e:
- logger.error(
- f"limit_async: Error in decorated function: {str(e)}"
- )
- if not future.done():
- future.set_exception(e)
- finally:
- queue.task_done()
- except Exception as e:
- # Catch all exceptions in worker loop to prevent worker termination
- logger.error(f"limit_async: Critical error in worker: {str(e)}")
- await asyncio.sleep(0.1) # Prevent high CPU usage
- finally:
- logger.debug("limit_async: Worker exiting")
+ # Execute function with timeout protection
+ if max_execution_timeout is not None:
+ result = await asyncio.wait_for(
+ func(*args, **kwargs), timeout=max_execution_timeout
+ )
+ else:
+ result = await func(*args, **kwargs)
- async def health_check():
- """Periodically check worker health status and recover"""
+ # Set result if future is still valid
+ if not task_state.future.done():
+ task_state.future.set_result(result)
+
+ except asyncio.TimeoutError:
+ # Worker-level timeout (max_execution_timeout exceeded)
+ logger.warning(
+ f"{queue_name}: Worker timeout for task {task_id} after {max_execution_timeout}s"
+ )
+ if not task_state.future.done():
+ task_state.future.set_exception(
+ WorkerTimeoutError(
+ max_execution_timeout, "execution"
+ )
+ )
+ except asyncio.CancelledError:
+ # Task was cancelled during execution
+ if not task_state.future.done():
+ task_state.future.cancel()
+ logger.debug(
+ f"{queue_name}: Task {task_id} cancelled during execution"
+ )
+ except Exception as e:
+ # Function execution error
+ logger.error(
+ f"{queue_name}: Error in decorated function for task {task_id}: {str(e)}"
+ )
+ if not task_state.future.done():
+ task_state.future.set_exception(e)
+ finally:
+ # Clean up task state
+ async with task_states_lock:
+ task_states.pop(task_id, None)
+ queue.task_done()
+
+ except Exception as e:
+ # Critical error in worker loop
+ logger.error(
+ f"{queue_name}: Critical error in worker: {str(e)}"
+ )
+ await asyncio.sleep(0.1)
+ finally:
+ logger.debug(f"{queue_name}: Worker exiting")
+
+ async def enhanced_health_check():
+ """Enhanced health check with stuck task detection and recovery"""
nonlocal initialized
try:
while not shutdown_event.is_set():
await asyncio.sleep(5) # Check every 5 seconds
- # No longer acquire lock, directly operate on task set
- # Use a copy of the task set to avoid concurrent modification
+ current_time = asyncio.get_event_loop().time()
+
+ # Detect and handle stuck tasks based on execution start time
+ if max_task_duration is not None:
+ stuck_tasks = []
+ async with task_states_lock:
+ for task_id, task_state in list(task_states.items()):
+ # Only check tasks that have started execution
+ if (
+ task_state.worker_started
+ and task_state.execution_start_time is not None
+ and current_time - task_state.execution_start_time
+ > max_task_duration
+ ):
+ stuck_tasks.append(
+ (
+ task_id,
+ current_time
+ - task_state.execution_start_time,
+ )
+ )
+
+ # Force cleanup of stuck tasks
+ for task_id, execution_duration in stuck_tasks:
+ logger.warning(
+ f"{queue_name}: Detected stuck task {task_id} (execution time: {execution_duration:.1f}s), forcing cleanup"
+ )
+ async with task_states_lock:
+ if task_id in task_states:
+ task_state = task_states[task_id]
+ if not task_state.future.done():
+ task_state.future.set_exception(
+ HealthCheckTimeoutError(
+ max_task_duration, execution_duration
+ )
+ )
+ task_states.pop(task_id, None)
+
+ # Worker recovery logic
current_tasks = set(tasks)
done_tasks = {t for t in current_tasks if t.done()}
tasks.difference_update(done_tasks)
- # Calculate active tasks count
active_tasks_count = len(tasks)
workers_needed = max_size - active_tasks_count
if workers_needed > 0:
logger.info(
- f"limit_async: Creating {workers_needed} new workers"
+ f"{queue_name}: Creating {workers_needed} new workers"
)
new_tasks = set()
for _ in range(workers_needed):
task = asyncio.create_task(worker())
new_tasks.add(task)
task.add_done_callback(tasks.discard)
- # Update task set in one operation
tasks.update(new_tasks)
+
except Exception as e:
- logger.error(f"limit_async: Error in health check: {str(e)}")
+ logger.error(f"{queue_name}: Error in enhanced health check: {str(e)}")
finally:
- logger.debug("limit_async: Health check task exiting")
+ logger.debug(f"{queue_name}: Enhanced health check task exiting")
initialized = False
async def ensure_workers():
- """Ensure worker threads and health check system are available
-
- This function checks if the worker system is already initialized.
- If not, it performs a one-time initialization of all worker threads
- and starts the health check system.
- """
+ """Ensure worker system is initialized with enhanced error handling"""
nonlocal initialized, worker_health_check_task, tasks, reinit_count
if initialized:
@@ -456,45 +667,56 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
if initialized:
return
- # Increment reinitialization counter if this is not the first initialization
if reinit_count > 0:
reinit_count += 1
logger.warning(
- f"limit_async: Reinitializing needed (count: {reinit_count})"
+ f"{queue_name}: Reinitializing system (count: {reinit_count})"
)
else:
- reinit_count = 1 # First initialization
+ reinit_count = 1
- # Check for completed tasks and remove them from the task set
+ # Clean up completed tasks
current_tasks = set(tasks)
done_tasks = {t for t in current_tasks if t.done()}
tasks.difference_update(done_tasks)
- # Log active tasks count during reinitialization
active_tasks_count = len(tasks)
if active_tasks_count > 0 and reinit_count > 1:
logger.warning(
- f"limit_async: {active_tasks_count} tasks still running during reinitialization"
+ f"{queue_name}: {active_tasks_count} tasks still running during reinitialization"
)
- # Create initial worker tasks, only adding the number needed
+ # Create worker tasks
workers_needed = max_size - active_tasks_count
for _ in range(workers_needed):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
- # Start health check
- worker_health_check_task = asyncio.create_task(health_check())
+ # Start enhanced health check
+ worker_health_check_task = asyncio.create_task(enhanced_health_check())
initialized = True
- logger.info(f"limit_async: {workers_needed} new workers initialized")
+ # Log dynamic timeout configuration
+ timeout_info = []
+ if llm_timeout is not None:
+ timeout_info.append(f"Func: {llm_timeout}s")
+ if max_execution_timeout is not None:
+ timeout_info.append(f"Worker: {max_execution_timeout}s")
+ if max_task_duration is not None:
+ timeout_info.append(f"Health Check: {max_task_duration}s")
+
+ timeout_str = (
+ f" (Timeouts: {', '.join(timeout_info)})" if timeout_info else ""
+ )
+ logger.info(
+ f"{queue_name}: {workers_needed} new workers initialized {timeout_str}"
+ )
async def shutdown():
- """Gracefully shut down all workers and the queue"""
- logger.info("limit_async: Shutting down priority queue workers")
+ """Gracefully shut down all workers and cleanup resources"""
+ logger.info(f"{queue_name}: Shutting down priority queue workers")
- # Set the shutdown event
shutdown_event.set()
# Cancel all active futures
@@ -502,15 +724,22 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
if not future.done():
future.cancel()
- # Wait for the queue to empty
+ # Cancel all pending tasks
+ async with task_states_lock:
+ for task_id, task_state in list(task_states.items()):
+ if not task_state.future.done():
+ task_state.future.cancel()
+ task_states.clear()
+
+ # Wait for queue to empty with timeout
try:
await asyncio.wait_for(queue.join(), timeout=5.0)
except asyncio.TimeoutError:
logger.warning(
- "limit_async: Timeout waiting for queue to empty during shutdown"
+ f"{queue_name}: Timeout waiting for queue to empty during shutdown"
)
- # Cancel all worker tasks
+ # Cancel worker tasks
for task in list(tasks):
if not task.done():
task.cancel()
@@ -519,7 +748,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
- # Cancel the health check task
+ # Cancel health check task
if worker_health_check_task and not worker_health_check_task.done():
worker_health_check_task.cancel()
try:
@@ -527,84 +756,120 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
except asyncio.CancelledError:
pass
- logger.info("limit_async: Priority queue workers shutdown complete")
+ logger.info(f"{queue_name}: Priority queue workers shutdown complete")
@wraps(func)
async def wait_func(
*args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
):
"""
- Execute the function with priority-based concurrency control
+ Execute function with enhanced priority-based concurrency control and timeout handling
+
Args:
*args: Positional arguments passed to the function
_priority: Call priority (lower values have higher priority)
- _timeout: Maximum time to wait for function completion (in seconds)
+ _timeout: Maximum time to wait for completion (in seconds, none means determinded by max_execution_timeout of the queue)
_queue_timeout: Maximum time to wait for entering the queue (in seconds)
**kwargs: Keyword arguments passed to the function
+
Returns:
The result of the function call
+
Raises:
- TimeoutError: If the function call times out
+ TimeoutError: If the function call times out at any level
QueueFullError: If the queue is full and waiting times out
Any exception raised by the decorated function
"""
- # Ensure worker system is initialized
await ensure_workers()
- # Create a future for the result
+ # Generate unique task ID
+ task_id = f"{id(asyncio.current_task())}_{asyncio.get_event_loop().time()}"
future = asyncio.Future()
- active_futures.add(future)
- nonlocal counter
- async with initialization_lock:
- current_count = counter # Use local variable to avoid race conditions
- counter += 1
+ # Create task state
+ task_state = TaskState(
+ future=future, start_time=asyncio.get_event_loop().time()
+ )
- # Try to put the task into the queue, supporting timeout
try:
- if _queue_timeout is not None:
- # Use timeout to wait for queue space
- try:
+ # Register task state
+ async with task_states_lock:
+ task_states[task_id] = task_state
+
+ active_futures.add(future)
+
+ # Get counter for FIFO ordering
+ nonlocal counter
+ async with initialization_lock:
+ current_count = counter
+ counter += 1
+
+ # Queue the task with timeout handling
+ try:
+ if _queue_timeout is not None:
await asyncio.wait_for(
- # current_count is used to ensure FIFO order
- queue.put((_priority, current_count, future, args, kwargs)),
+ queue.put(
+ (_priority, current_count, task_id, args, kwargs)
+ ),
timeout=_queue_timeout,
)
- except asyncio.TimeoutError:
- raise QueueFullError(
- f"Queue full, timeout after {_queue_timeout} seconds"
+ else:
+ await queue.put(
+ (_priority, current_count, task_id, args, kwargs)
)
- else:
- # No timeout, may wait indefinitely
- # current_count is used to ensure FIFO order
- await queue.put((_priority, current_count, future, args, kwargs))
- except Exception as e:
- # Clean up the future
- if not future.done():
- future.set_exception(e)
- active_futures.discard(future)
- raise
+ except asyncio.TimeoutError:
+ raise QueueFullError(
+ f"{queue_name}: Queue full, timeout after {_queue_timeout} seconds"
+ )
+ except Exception as e:
+ # Clean up on queue error
+ if not future.done():
+ future.set_exception(e)
+ raise
- try:
- # Wait for the result, optional timeout
- if _timeout is not None:
- try:
+ # Wait for result with timeout handling
+ try:
+ if _timeout is not None:
return await asyncio.wait_for(future, _timeout)
- except asyncio.TimeoutError:
- # Cancel the future
- if not future.done():
- future.cancel()
- raise TimeoutError(
- f"limit_async: Task timed out after {_timeout} seconds"
- )
- else:
- # Wait for the result without timeout
- return await future
- finally:
- # Clean up the future reference
- active_futures.discard(future)
+ else:
+ return await future
+ except asyncio.TimeoutError:
+ # This is user-level timeout (asyncio.wait_for caused)
+ # Mark cancellation request
+ async with task_states_lock:
+ if task_id in task_states:
+ task_states[task_id].cancellation_requested = True
- # Add the shutdown method to the decorated function
+ # Cancel future
+ if not future.done():
+ future.cancel()
+
+ # Wait for worker cleanup with timeout
+ cleanup_start = asyncio.get_event_loop().time()
+ while (
+ task_id in task_states
+ and asyncio.get_event_loop().time() - cleanup_start
+ < cleanup_timeout
+ ):
+ await asyncio.sleep(0.1)
+
+ raise TimeoutError(
+ f"{queue_name}: User timeout after {_timeout} seconds"
+ )
+ except WorkerTimeoutError as e:
+ # This is Worker-level timeout, directly propagate exception information
+ raise TimeoutError(f"{queue_name}: {str(e)}")
+ except HealthCheckTimeoutError as e:
+ # This is Health Check-level timeout, directly propagate exception information
+ raise TimeoutError(f"{queue_name}: {str(e)}")
+
+ finally:
+ # Ensure cleanup
+ active_futures.discard(future)
+ async with task_states_lock:
+ task_states.pop(task_id, None)
+
+ # Add shutdown method to decorated function
wait_func.shutdown = shutdown
return wait_func
@@ -735,19 +1000,6 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
return [r.strip() for r in results if r.strip()]
-# Refer the utils functions of the official GraphRAG implementation:
-# https://github.com/microsoft/graphrag
-def clean_str(input: Any) -> str:
- """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
- # If we get non-string input, just give it back
- if not isinstance(input, str):
- return input
-
- result = html.unescape(input.strip())
- # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
- return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
-
-
def is_float_regex(value: str) -> bool:
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
@@ -1386,8 +1638,11 @@ async def update_chunk_cache_list(
def remove_think_tags(text: str) -> str:
- """Remove tags from the text"""
- return re.sub(r"^(.*?|)", "", text, flags=re.DOTALL).strip()
+ """Remove ... tags from the text
+ Remove orphon ... tags from the text also"""
+ return re.sub(
+ r"^(.*?|.*)", "", text, flags=re.DOTALL
+ ).strip()
async def use_llm_func_with_cache(
@@ -1471,6 +1726,7 @@ async def use_llm_func_with_cache(
kwargs["max_tokens"] = max_tokens
res: str = await use_llm_func(safe_input_text, **kwargs)
+
res = remove_think_tags(res)
if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
@@ -1498,8 +1754,14 @@ async def use_llm_func_with_cache(
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
- logger.info(f"Call LLM function with query text length: {len(safe_input_text)}")
- res = await use_llm_func(safe_input_text, **kwargs)
+ try:
+ res = await use_llm_func(safe_input_text, **kwargs)
+ except Exception as e:
+ # Add [LLM func] prefix to error message
+ error_msg = f"[LLM func] {str(e)}"
+ # Re-raise with the same exception type but modified message
+ raise type(e)(error_msg) from e
+
return remove_think_tags(res)
@@ -1519,29 +1781,82 @@ def get_content_summary(content: str, max_length: int = 250) -> str:
return content[:max_length] + "..."
-def normalize_extracted_info(name: str, is_entity=False) -> str:
+def sanitize_and_normalize_extracted_text(
+ input_text: str, remove_inner_quotes=False
+) -> str:
+ """Santitize and normalize extracted text
+ Args:
+ input_text: text string to be processed
+ is_name: whether the input text is a entity or relation name
+
+ Returns:
+ Santitized and normalized text string
+ """
+ safe_input_text = sanitize_text_for_encoding(input_text)
+ if safe_input_text:
+ normalized_text = normalize_extracted_info(
+ safe_input_text, remove_inner_quotes=remove_inner_quotes
+ )
+ return normalized_text
+ return ""
+
+
+def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
"""Normalize entity/relation names and description with the following rules:
- 1. Remove spaces between Chinese characters
- 2. Remove spaces between Chinese characters and English letters/numbers
- 3. Preserve spaces within English text and numbers
- 4. Replace Chinese parentheses with English parentheses
- 5. Replace Chinese dash with English dash
- 6. Remove English quotation marks from the beginning and end of the text
- 7. Remove English quotation marks in and around chinese
- 8. Remove Chinese quotation marks
+ - Clean HTML tags (paragraph and line break tags)
+ - Convert Chinese symbols to English symbols
+ - Remove spaces between Chinese characters
+ - Remove spaces between Chinese characters and English letters/numbers
+ - Preserve spaces within English text and numbers
+ - Replace Chinese parentheses with English parentheses
+ - Replace Chinese dash with English dash
+ - Remove English quotation marks from the beginning and end of the text
+ - Remove English quotation marks in and around chinese
+ - Remove Chinese quotation marks
+ - Filter out short numeric-only text (length < 3 and only digits/dots)
+ - remove_inner_quotes = True
+ remove Chinese quotes
+ remove English queotes in and around chinese
+ Convert non-breaking spaces to regular spaces
+ Convert narrow non-breaking spaces after non-digits to regular spaces
Args:
name: Entity name to normalize
+ is_entity: Whether this is an entity name (affects quote handling)
Returns:
Normalized entity name
"""
+ # Clean HTML tags - remove paragraph and line break tags
+ name = re.sub(r"
||
", "", name, flags=re.IGNORECASE)
+ name = re.sub(r"|
|
", "", name, flags=re.IGNORECASE)
+
+ # Chinese full-width letters to half-width (A-Z, a-z)
+ name = name.translate(
+ str.maketrans(
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ )
+ )
+
+ # Chinese full-width numbers to half-width
+ name = name.translate(str.maketrans("0123456789", "0123456789"))
+
+ # Chinese full-width symbols to half-width
+ name = name.replace("-", "-") # Chinese minus
+ name = name.replace("+", "+") # Chinese plus
+ name = name.replace("/", "/") # Chinese slash
+ name = name.replace("*", "*") # Chinese asterisk
+
# Replace Chinese parentheses with English parentheses
name = name.replace("(", "(").replace(")", ")")
- # Replace Chinese dash with English dash
+ # Replace Chinese dash with English dash (additional patterns)
name = name.replace("—", "-").replace("-", "-")
+ # Chinese full-width space to regular space (after other replacements)
+ name = name.replace(" ", " ")
+
# Use regex to remove spaces between Chinese characters
# Regex explanation:
# (?<=[\u4e00-\u9fa5]): Positive lookbehind for Chinese character
@@ -1557,18 +1872,60 @@ def normalize_extracted_info(name: str, is_entity=False) -> str:
r"(?<=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])\s+(?=[\u4e00-\u9fa5])", "", name
)
- # Remove English quotation marks from the beginning and end
- if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
- name = name[1:-1]
- if len(name) >= 2 and name.startswith("'") and name.endswith("'"):
- name = name[1:-1]
+ # Remove outer quotes
+ if len(name) >= 2:
+ # Handle double quotes
+ if name.startswith('"') and name.endswith('"'):
+ inner_content = name[1:-1]
+ if '"' not in inner_content: # No double quotes inside
+ name = inner_content
- if is_entity:
- # remove Chinese quotes
+ # Handle single quotes
+ if name.startswith("'") and name.endswith("'"):
+ inner_content = name[1:-1]
+ if "'" not in inner_content: # No single quotes inside
+ name = inner_content
+
+ # Handle Chinese-style double quotes
+ if name.startswith("“") and name.endswith("”"):
+ inner_content = name[1:-1]
+ if "“" not in inner_content and "”" not in inner_content:
+ name = inner_content
+ if name.startswith("‘") and name.endswith("’"):
+ inner_content = name[1:-1]
+ if "‘" not in inner_content and "’" not in inner_content:
+ name = inner_content
+
+ if remove_inner_quotes:
+ # Remove Chinese quotes
name = name.replace("“", "").replace("”", "").replace("‘", "").replace("’", "")
- # remove English queotes in and around chinese
+ # Remove English queotes in and around chinese
name = re.sub(r"['\"]+(?=[\u4e00-\u9fa5])", "", name)
name = re.sub(r"(?<=[\u4e00-\u9fa5])['\"]+", "", name)
+ # Convert non-breaking space to regular space
+ name = name.replace("\u00a0", " ")
+ # Convert narrow non-breaking space to regular space when after non-digits
+ name = re.sub(r"(?<=[^\d])\u202F", " ", name)
+
+ # Remove spaces from the beginning and end of the text
+ name = name.strip()
+
+ # Filter out pure numeric content with length < 3
+ if len(name) < 3 and re.match(r"^[0-9]+$", name):
+ return ""
+
+ def should_filter_by_dots(text):
+ """
+ Check if the string consists only of dots and digits, with at least one dot
+ Filter cases include: 1.2.3, 12.3, .123, 123., 12.3., .1.23 etc.
+ """
+ return all(c.isdigit() or c == "." for c in text) and "." in text
+
+ if len(name) < 6 and should_filter_by_dots(name):
+ # Filter out mixed numeric and dot content with length < 6
+ return ""
+ # Filter out mixed numeric and dot content with length < 6, requiring at least one dot
+ return ""
return name
@@ -1577,9 +1934,11 @@ def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str:
"""Sanitize text to ensure safe UTF-8 encoding by removing or replacing problematic characters.
This function handles:
- - Surrogate characters (the main cause of the encoding error)
+ - Surrogate characters (the main cause of encoding errors)
- Other invalid Unicode sequences
- Control characters that might cause issues
+ - Unescape HTML escapes
+ - Remove control characters
- Whitespace trimming
Args:
@@ -1588,10 +1947,10 @@ def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str:
Returns:
Sanitized text that can be safely encoded as UTF-8
- """
- if not isinstance(text, str):
- return str(text)
+ Raises:
+ ValueError: When text contains uncleanable encoding issues that cannot be safely processed
+ """
if not text:
return text
@@ -1624,7 +1983,7 @@ def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str:
else:
sanitized += char
- # Additional cleanup: remove null bytes and other control characters that might cause issues
+ # Additional cleanup: remove null bytes and other control characters that might cause issues
# (but preserve common whitespace like \t, \n, \r)
sanitized = re.sub(
r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", replacement_char, sanitized
@@ -1633,37 +1992,30 @@ def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str:
# Test final encoding to ensure it's safe
sanitized.encode("utf-8")
- return sanitized
+ # Unescape HTML escapes
+ sanitized = html.unescape(sanitized)
+
+ # Remove control characters
+ sanitized = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", sanitized)
+
+ return sanitized.strip()
except UnicodeEncodeError as e:
- logger.warning(
- f"Text sanitization: UnicodeEncodeError encountered, applying aggressive cleaning: {str(e)[:100]}"
- )
-
- # Aggressive fallback: encode with error handling
- try:
- # Use 'replace' error handling to substitute problematic characters
- safe_bytes = text.encode("utf-8", errors="replace")
- sanitized = safe_bytes.decode("utf-8")
-
- # Additional cleanup
- sanitized = re.sub(
- r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", replacement_char, sanitized
- )
-
- return sanitized
-
- except Exception as fallback_error:
- logger.error(
- f"Text sanitization: Aggressive fallback failed: {str(fallback_error)}"
- )
- # Last resort: return a safe placeholder
- return f"[TEXT_ENCODING_ERROR: {len(text)} characters]"
+ # Critical change: Don't return placeholder, raise exception for caller to handle
+ error_msg = f"Text contains uncleanable UTF-8 encoding issues: {str(e)[:100]}"
+ logger.error(f"Text sanitization failed: {error_msg}")
+ raise ValueError(error_msg) from e
except Exception as e:
logger.error(f"Text sanitization: Unexpected error: {str(e)}")
- # Return original text if no encoding issues detected
- return text
+ # For other exceptions, if no encoding issues detected, return original text
+ try:
+ text.encode("utf-8")
+ return text
+ except UnicodeEncodeError:
+ raise ValueError(
+ f"Text sanitization failed with unexpected error: {str(e)}"
+ ) from e
def check_storage_env_vars(storage_name: str) -> None:
@@ -1774,6 +2126,7 @@ async def pick_by_vector_similarity(
num_of_chunks: int,
entity_info: list[dict[str, Any]],
embedding_func: callable,
+ query_embedding=None,
) -> list[str]:
"""
Vector similarity-based text chunk selection algorithm.
@@ -1818,11 +2171,19 @@ async def pick_by_vector_similarity(
all_chunk_ids = list(all_chunk_ids)
try:
- # Get query embedding
- query_embedding = await embedding_func([query])
- query_embedding = query_embedding[
- 0
- ] # Extract first embedding from batch result
+ # Use pre-computed query embedding if provided, otherwise compute it
+ if query_embedding is None:
+ query_embedding = await embedding_func([query])
+ query_embedding = query_embedding[
+ 0
+ ] # Extract first embedding from batch result
+ logger.debug(
+ "Computed query embedding for vector similarity chunk selection"
+ )
+ else:
+ logger.debug(
+ "Using pre-computed query embedding for vector similarity chunk selection"
+ )
# Get chunk embeddings from vector database
chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)
@@ -1969,17 +2330,50 @@ async def apply_rerank_if_enabled(
return retrieved_docs
try:
- # Apply reranking - let rerank_model_func handle top_k internally
- reranked_docs = await rerank_func(
+ # Extract document content for reranking
+ document_texts = []
+ for doc in retrieved_docs:
+ # Try multiple possible content fields
+ content = (
+ doc.get("content")
+ or doc.get("text")
+ or doc.get("chunk_content")
+ or doc.get("document")
+ or str(doc)
+ )
+ document_texts.append(content)
+
+ # Call the new rerank function that returns index-based results
+ rerank_results = await rerank_func(
query=query,
- documents=retrieved_docs,
+ documents=document_texts,
top_n=top_n,
)
- if reranked_docs and len(reranked_docs) > 0:
- if len(reranked_docs) > top_n:
- reranked_docs = reranked_docs[:top_n]
- logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
- return reranked_docs
+
+ # Process rerank results based on return format
+ if rerank_results and len(rerank_results) > 0:
+ # Check if results are in the new index-based format
+ if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]:
+ # New format: [{"index": 0, "relevance_score": 0.85}, ...]
+ reranked_docs = []
+ for result in rerank_results:
+ index = result["index"]
+ relevance_score = result["relevance_score"]
+
+ # Get original document and add rerank score
+ if 0 <= index < len(retrieved_docs):
+ doc = retrieved_docs[index].copy()
+ doc["rerank_score"] = relevance_score
+ reranked_docs.append(doc)
+
+ logger.info(
+ f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks"
+ )
+ return reranked_docs
+ else:
+ # Legacy format: assume it's already reranked documents
+ logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks")
+ return rerank_results[:top_n] if top_n else rerank_results
else:
logger.warning("Rerank returned empty results, using original chunks")
return retrieved_docs
@@ -2018,13 +2412,6 @@ async def process_chunks_unified(
# 1. Apply reranking if enabled and query is provided
if query_param.enable_rerank and query and unique_chunks:
- # 保存 chunk_id 字段,因为 rerank 可能会丢失这个字段
- chunk_ids = {}
- for chunk in unique_chunks:
- chunk_id = chunk.get("chunk_id")
- if chunk_id:
- chunk_ids[id(chunk)] = chunk_id
-
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled(
query=query,
@@ -2034,11 +2421,6 @@ async def process_chunks_unified(
top_n=rerank_top_k,
)
- # 恢复 chunk_id 字段
- for chunk in unique_chunks:
- if id(chunk) in chunk_ids:
- chunk["chunk_id"] = chunk_ids[id(chunk)]
-
# 2. Filter by minimum rerank score if reranking is enabled
if query_param.enable_rerank and unique_chunks:
min_rerank_score = global_config.get("min_rerank_score", 0.5)
@@ -2086,13 +2468,6 @@ async def process_chunks_unified(
original_count = len(unique_chunks)
- # Keep chunk_id field, cause truncate_list_by_token_size will lose it
- chunk_ids_map = {}
- for i, chunk in enumerate(unique_chunks):
- chunk_id = chunk.get("chunk_id")
- if chunk_id:
- chunk_ids_map[i] = chunk_id
-
unique_chunks = truncate_list_by_token_size(
unique_chunks,
key=lambda x: json.dumps(x, ensure_ascii=False),
@@ -2100,11 +2475,6 @@ async def process_chunks_unified(
tokenizer=tokenizer,
)
- # restore chunk_id feiled
- for i, chunk in enumerate(unique_chunks):
- if i in chunk_ids_map:
- chunk["chunk_id"] = chunk_ids_map[i]
-
logger.debug(
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"