Merge branch 'main' into add-Memgraph-graph-db

yangdx 2025-07-04 23:53:07 +08:00
commit a69194c079
28 changed files with 3042 additions and 793 deletions

View file

@ -58,6 +58,8 @@ SUMMARY_LANGUAGE=English
# FORCE_LLM_SUMMARY_ON_MERGE=6
### Max tokens for entity/relations description after merge
# MAX_TOKEN_SUMMARY=500
### Maximum number of entity extraction attempts for ambiguous content
# MAX_GLEANING=1
### Number of parallel processing documents (less than MAX_ASYNC/2 is recommended)
# MAX_PARALLEL_INSERT=2
@ -112,15 +114,6 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
### TiDB Configuration (Deprecated)
# TIDB_HOST=localhost
# TIDB_PORT=4000
# TIDB_USER=your_username
# TIDB_PASSWORD='your_password'
# TIDB_DATABASE=your_database
### separating all data from different LightRAG instances (deprecating)
# TIDB_WORKSPACE=default
### PostgreSQL Configuration
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
@ -128,7 +121,7 @@ POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database
POSTGRES_MAX_CONNECTIONS=12
### separating all data from different LightRAG instances (deprecating)
### separating all data from different LightRAG instances
# POSTGRES_WORKSPACE=default
### Neo4j Configuration
@ -144,14 +137,15 @@ NEO4J_PASSWORD='your_password'
# AGE_POSTGRES_PORT=8529
# AGE Graph Name (applies to PostgreSQL and independent AGM)
### AGE_GRAPH_NAME is precated
### AGE_GRAPH_NAME is deprecated
# AGE_GRAPH_NAME=lightrag
### MongoDB Configuration
MONGO_URI=mongodb://root:root@localhost:27017/
MONGO_DATABASE=LightRAG
### separating all data from different LightRAG instances (deprecating)
# MONGODB_GRAPH=false
### separating all data from different LightRAG instances
# MONGODB_WORKSPACE=default
### Milvus Configuration
MILVUS_URI=http://localhost:19530

View file

@ -11,9 +11,74 @@ This example shows how to:
import os
import argparse
import asyncio
import logging
import logging.config
from pathlib import Path
# Add project root directory to Python path
import sys
sys.path.append(str(Path(__file__).parent.parent))
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
from raganything.raganything import RAGAnything
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
from raganything import RAGAnything, RAGAnythingConfig
def configure_logging():
"""Configure logging for the application"""
# Get log directory path from environment variable or use current directory
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "raganything_example.log"))
print(f"\nRAGAnything example log file: {log_file_path}\n")
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
logging.config.dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(levelname)s: %(message)s",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"file": {
"formatter": "detailed",
"class": "logging.handlers.RotatingFileHandler",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf-8",
},
},
"loggers": {
"lightrag": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
},
}
)
# Set the logger level to INFO
logger.setLevel(logging.INFO)
# Enable verbose debug if needed
set_verbose_debug(os.getenv("VERBOSE", "false").lower() == "true")
async def process_with_rag(
@ -31,15 +96,21 @@ async def process_with_rag(
output_dir: Output directory for RAG results
api_key: OpenAI API key
base_url: Optional base URL for API
working_dir: Working directory for RAG storage
"""
try:
# Initialize RAGAnything
rag = RAGAnything(
working_dir=working_dir,
llm_model_func=lambda prompt,
system_prompt=None,
history_messages=[],
**kwargs: openai_complete_if_cache(
# Create RAGAnything configuration
config = RAGAnythingConfig(
working_dir=working_dir or "./rag_storage",
mineru_parse_method="auto",
enable_image_processing=True,
enable_table_processing=True,
enable_equation_processing=True,
)
# Define LLM model function
def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
return openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
@ -47,81 +118,123 @@ async def process_with_rag(
api_key=api_key,
base_url=base_url,
**kwargs,
),
vision_model_func=lambda prompt,
system_prompt=None,
history_messages=[],
image_data=None,
**kwargs: openai_complete_if_cache(
"gpt-4o",
"",
system_prompt=None,
history_messages=[],
messages=[
{"role": "system", "content": system_prompt}
if system_prompt
else None,
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_data}"
},
},
],
}
if image_data
else {"role": "user", "content": prompt},
],
api_key=api_key,
base_url=base_url,
**kwargs,
)
if image_data
else openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=api_key,
base_url=base_url,
**kwargs,
),
embedding_func=EmbeddingFunc(
embedding_dim=3072,
max_token_size=8192,
func=lambda texts: openai_embed(
texts,
model="text-embedding-3-large",
# Define vision model function for image processing
def vision_model_func(
prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs
):
if image_data:
return openai_complete_if_cache(
"gpt-4o",
"",
system_prompt=None,
history_messages=[],
messages=[
{"role": "system", "content": system_prompt}
if system_prompt
else None,
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_data}"
},
},
],
}
if image_data
else {"role": "user", "content": prompt},
],
api_key=api_key,
base_url=base_url,
),
**kwargs,
)
else:
return llm_model_func(prompt, system_prompt, history_messages, **kwargs)
# Define embedding function
embedding_func = EmbeddingFunc(
embedding_dim=3072,
max_token_size=8192,
func=lambda texts: openai_embed(
texts,
model="text-embedding-3-large",
api_key=api_key,
base_url=base_url,
),
)
# Initialize RAGAnything with new dataclass structure
rag = RAGAnything(
config=config,
llm_model_func=llm_model_func,
vision_model_func=vision_model_func,
embedding_func=embedding_func,
)
# Process document
await rag.process_document_complete(
file_path=file_path, output_dir=output_dir, parse_method="auto"
)
# Example queries
queries = [
# Example queries - demonstrating different query approaches
logger.info("\nQuerying processed document:")
# 1. Pure text queries using aquery()
text_queries = [
"What is the main content of the document?",
"Describe the images and figures in the document",
"Tell me about the experimental results and data tables",
"What are the key topics discussed?",
]
print("\nQuerying processed document:")
for query in queries:
print(f"\nQuery: {query}")
result = await rag.query_with_multimodal(query, mode="hybrid")
print(f"Answer: {result}")
for query in text_queries:
logger.info(f"\n[Text Query]: {query}")
result = await rag.aquery(query, mode="hybrid")
logger.info(f"Answer: {result}")
# 2. Multimodal query with specific multimodal content using aquery_with_multimodal()
logger.info(
"\n[Multimodal Query]: Analyzing performance data in context of document"
)
multimodal_result = await rag.aquery_with_multimodal(
"Compare this performance data with any similar results mentioned in the document",
multimodal_content=[
{
"type": "table",
"table_data": """Method,Accuracy,Processing_Time
RAGAnything,95.2%,120ms
Traditional_RAG,87.3%,180ms
Baseline,82.1%,200ms""",
"table_caption": "Performance comparison results",
}
],
mode="hybrid",
)
logger.info(f"Answer: {multimodal_result}")
# 3. Another multimodal query with equation content
logger.info("\n[Multimodal Query]: Mathematical formula analysis")
equation_result = await rag.aquery_with_multimodal(
"Explain this formula and relate it to any mathematical concepts in the document",
multimodal_content=[
{
"type": "equation",
"latex": "F1 = 2 \\cdot \\frac{precision \\cdot recall}{precision + recall}",
"equation_caption": "F1-score calculation formula",
}
],
mode="hybrid",
)
logger.info(f"Answer: {equation_result}")
except Exception as e:
print(f"Error processing with RAG: {str(e)}")
logger.error(f"Error processing with RAG: {str(e)}")
import traceback
logger.error(traceback.format_exc())
def main():
@ -135,12 +248,20 @@ def main():
"--output", "-o", default="./output", help="Output directory path"
)
parser.add_argument(
"--api-key", required=True, help="OpenAI API key for RAG processing"
"--api-key",
default=os.getenv("OPENAI_API_KEY"),
help="OpenAI API key (defaults to OPENAI_API_KEY env var)",
)
parser.add_argument("--base-url", help="Optional base URL for API")
args = parser.parse_args()
# Check if API key is provided
if not args.api_key:
logger.error("Error: OpenAI API key is required")
logger.error("Set OPENAI_API_KEY environment variable or use --api-key option")
return
# Create output directory if specified
if args.output:
os.makedirs(args.output, exist_ok=True)
@ -154,4 +275,12 @@ def main():
if __name__ == "__main__":
# Configure logging first
configure_logging()
print("RAGAnything Example")
print("=" * 30)
print("Processing document with multimodal RAG pipeline")
print("=" * 30)
main()

View file

@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
embedding_func=None,
)
# Get all cache data using the new flattened structure
all_data = await from_llm_response_cache.get_all()
# Convert flattened data to hierarchical structure for JsonKVStorage
kv = {}
for c_id in await from_llm_response_cache.all_keys():
print(f"Copying {c_id}")
workspace = c_id["workspace"]
mode = c_id["mode"]
_id = c_id["id"]
postgres_db.workspace = workspace
obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
if mode not in kv:
kv[mode] = {}
kv[mode][_id] = obj[_id]
print(f"Object {obj}")
for flattened_key, cache_entry in all_data.items():
# Parse flattened key: {mode}:{cache_type}:{hash}
parts = flattened_key.split(":", 2)
if len(parts) == 3:
mode, cache_type, hash_value = parts
if mode not in kv:
kv[mode] = {}
kv[mode][hash_value] = cache_entry
print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
else:
print(f"Skipping invalid key format: {flattened_key}")
await to_llm_response_cache.upsert(kv)
await to_llm_response_cache.index_done_callback()
print("Mission accomplished!")
@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
db=postgres_db,
)
for mode in await from_llm_response_cache.all_keys():
print(f"Copying {mode}")
caches = await from_llm_response_cache.get_by_id(mode)
for k, v in caches.items():
item = {mode: {k: v}}
print(f"\tCopying {item}")
await to_llm_response_cache.upsert(item)
# Get all cache data from JsonKVStorage (hierarchical structure)
all_data = await from_llm_response_cache.get_all()
# Convert hierarchical data to flattened structure for PGKVStorage
flattened_data = {}
for mode, mode_data in all_data.items():
print(f"Processing mode: {mode}")
for hash_value, cache_entry in mode_data.items():
# Determine cache_type from cache entry or use default
cache_type = cache_entry.get("cache_type", "extract")
# Create flattened key: {mode}:{cache_type}:{hash}
flattened_key = f"{mode}:{cache_type}:{hash_value}"
flattened_data[flattened_key] = cache_entry
print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
# Upsert the flattened data
await to_llm_response_cache.upsert(flattened_data)
print("Mission accomplished!")
if __name__ == "__main__":

View file

@ -1 +1 @@
__api_version__ = "0176"
__api_version__ = "0178"

View file

@ -62,6 +62,51 @@ router = APIRouter(
temp_prefix = "__tmp__"
def sanitize_filename(filename: str, input_dir: Path) -> str:
"""
Sanitize uploaded filename to prevent Path Traversal attacks.
Args:
filename: The original filename from the upload
input_dir: The target input directory
Returns:
str: Sanitized filename that is safe to use
Raises:
HTTPException: If the filename is unsafe or invalid
"""
# Basic validation
if not filename or not filename.strip():
raise HTTPException(status_code=400, detail="Filename cannot be empty")
# Remove path separators and traversal sequences
clean_name = filename.replace("/", "").replace("\\", "")
clean_name = clean_name.replace("..", "")
# Remove control characters and null bytes
clean_name = "".join(c for c in clean_name if ord(c) >= 32 and c != "\x7f")
# Remove leading/trailing whitespace and dots
clean_name = clean_name.strip().strip(".")
# Check if anything is left after sanitization
if not clean_name:
raise HTTPException(
status_code=400, detail="Invalid filename after sanitization"
)
# Verify the final path stays within the input directory
try:
final_path = (input_dir / clean_name).resolve()
if not final_path.is_relative_to(input_dir.resolve()):
raise HTTPException(status_code=400, detail="Unsafe filename detected")
except (OSError, ValueError):
raise HTTPException(status_code=400, detail="Invalid filename")
return clean_name
class ScanResponse(BaseModel):
"""Response model for document scanning operation
@ -783,7 +828,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
try:
new_files = doc_manager.scan_directory_for_new_files()
total_files = len(new_files)
logger.info(f"Found {total_files} new files to index.")
logger.info(f"Found {total_files} files to index.")
if not new_files:
return
@ -816,8 +861,13 @@ async def background_delete_documents(
successful_deletions = []
failed_deletions = []
# Set pipeline status to busy for deletion
# Double-check pipeline status before proceeding
async with pipeline_status_lock:
if pipeline_status.get("busy", False):
logger.warning("Error: Unexpected pipeline busy state, aborting deletion.")
return # Abort deletion operation
# Set pipeline status to busy for deletion
pipeline_status.update(
{
"busy": True,
@ -926,13 +976,26 @@ async def background_delete_documents(
async with pipeline_status_lock:
pipeline_status["history_messages"].append(error_msg)
finally:
# Final summary
# Final summary and check for pending requests
async with pipeline_status_lock:
pipeline_status["busy"] = False
completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed"
pipeline_status["latest_message"] = completion_msg
pipeline_status["history_messages"].append(completion_msg)
# Check if there are pending document indexing requests
has_pending_request = pipeline_status.get("request_pending", False)
# If there are pending requests, start document processing pipeline
if has_pending_request:
try:
logger.info(
"Processing pending document indexing requests after deletion"
)
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logger.error(f"Error processing pending documents after deletion: {e}")
def create_document_routes(
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
@ -986,18 +1049,21 @@ def create_document_routes(
HTTPException: If the file type is not supported (400) or other errors occur (500).
"""
try:
if not doc_manager.is_supported_file(file.filename):
# Sanitize filename to prevent Path Traversal attacks
safe_filename = sanitize_filename(file.filename, doc_manager.input_dir)
if not doc_manager.is_supported_file(safe_filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
file_path = doc_manager.input_dir / safe_filename
# Check if file already exists
if file_path.exists():
return InsertResponse(
status="duplicated",
message=f"File '{file.filename}' already exists in the input directory.",
message=f"File '{safe_filename}' already exists in the input directory.",
)
with open(file_path, "wb") as buffer:
@ -1008,7 +1074,7 @@ def create_document_routes(
return InsertResponse(
status="success",
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.",
)
except Exception as e:
logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")

View file

@ -234,7 +234,7 @@ class OllamaAPI:
@self.router.get("/version", dependencies=[Depends(combined_auth)])
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4")
return OllamaVersionResponse(version="0.9.3")
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
async def get_tags():
@ -244,9 +244,9 @@ class OllamaAPI:
{
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"details": {
"parent_model": "",
"format": "gguf",
@ -337,7 +337,10 @@ class OllamaAPI:
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
@ -377,6 +380,7 @@ class OllamaAPI:
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": f"\n\nError: {error_msg}",
"error": f"\n\nError: {error_msg}",
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
@ -385,6 +389,7 @@ class OllamaAPI:
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
@ -399,7 +404,10 @@ class OllamaAPI:
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": "",
"done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
@ -444,6 +452,8 @@ class OllamaAPI:
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": str(response_text),
"done": True,
"done_reason": "stop",
"context": [],
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
@ -557,6 +567,12 @@ class OllamaAPI:
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done_reason": "stop",
"done": True,
"total_duration": total_time,
"load_duration": 0,
@ -605,6 +621,7 @@ class OllamaAPI:
"content": f"\n\nError: {error_msg}",
"images": None,
},
"error": f"\n\nError: {error_msg}",
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
@ -613,6 +630,11 @@ class OllamaAPI:
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": "",
"images": None,
},
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
@ -633,6 +655,7 @@ class OllamaAPI:
"content": "",
"images": None,
},
"done_reason": "stop",
"done": True,
"total_duration": total_time,
"load_duration": 0,
@ -697,6 +720,7 @@ class OllamaAPI:
"content": str(response_text),
"images": None,
},
"done_reason": "stop",
"done": True,
"total_duration": total_time,
"load_duration": 0,

View file

@ -183,6 +183,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
elif response is None:
# Handle None response (e.g., when only_need_context=True but no context found)
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
else:
# If it's an async generator, send chunks one by one
try:
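For context, the streaming branch above emits newline-delimited JSON objects whose 'response' field carries either an answer fragment or, with this change, a fallback message when no context is found. A minimal client-side sketch of assembling such a stream; the helper name is illustrative, and it assumes one JSON object per line:

import json
from typing import Iterable

def collect_streamed_response(lines: Iterable[str]) -> str:
    # Concatenate the 'response' fragments from an NDJSON stream.
    parts = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        chunk = json.loads(line)
        if "response" in chunk:
            parts.append(chunk["response"])
    return "".join(parts)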

View file

@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC):
@dataclass
class BaseGraphStorage(StorageNameSpace, ABC):
"""All operations related to edges in graph should be undirected."""
embedding_func: EmbeddingFunc
@abstractmethod
@ -468,17 +470,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
# Default implementation iterates through all nodes, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_nodes = []
all_labels = await self.get_all_labels()
for label in all_labels:
node = await self.get_node(label)
if node and "source_id" in node:
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
all_nodes.append(node)
return all_nodes
@abstractmethod
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
@ -643,6 +634,8 @@ class DocProcessingStatus:
"""ISO format timestamp when document was last updated"""
chunks_count: int | None = None
"""Number of chunks after splitting, used for processing"""
chunks_list: list[str] | None = field(default_factory=list)
"""List of chunk IDs associated with this document, used for deletion"""
error: str | None = None
"""Error message if failed"""
metadata: dict[str, Any] = field(default_factory=dict)

View file

@ -7,6 +7,7 @@ consistency and makes maintenance easier.
"""
# Default values for environment variables
DEFAULT_MAX_GLEANING = 1
DEFAULT_MAX_TOKEN_SUMMARY = 500
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
DEFAULT_WOKERS = 2
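These defaults back the corresponding keys in env.example (e.g. MAX_GLEANING, MAX_TOKEN_SUMMARY). A hedged sketch of how a caller might resolve the effective values; the env_int helper is illustrative, not LightRAG's actual loader:

import os

from lightrag.constants import DEFAULT_MAX_GLEANING, DEFAULT_MAX_TOKEN_SUMMARY

def env_int(name: str, default: int) -> int:
    # Read an integer setting from the environment, falling back to the shipped default.
    return int(os.getenv(name, default))

max_gleaning = env_int("MAX_GLEANING", DEFAULT_MAX_GLEANING)  # entity extraction retry rounds
max_token_summary = env_int("MAX_TOKEN_SUMMARY", DEFAULT_MAX_TOKEN_SUMMARY)  # summary token budget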

View file

@ -26,11 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"MongoVectorDBStorage",
# "ChromaVectorDBStorage",
# "TiDBVectorDBStorage",
],
"required_methods": ["query", "upsert"],
@ -38,6 +38,7 @@ STORAGE_IMPLEMENTATIONS = {
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"RedisDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
@ -81,6 +82,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"RedisDocStatusStorage": ["REDIS_URI"],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
@ -98,6 +100,7 @@ STORAGES = {
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"RedisDocStatusStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
# "TiDBKVStorage": ".kg.tidb_impl",
# "TiDBVectorDBStorage": ".kg.tidb_impl",

View file

@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted
"""
try:
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
self._collection.delete(ids=ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"

View file

@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):
################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}
@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
# Get current time as UNIX timestamp
import time
@ -522,11 +520,6 @@ class TiDBVectorDBStorage(BaseVectorStorage):
}
await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param)
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs from the storage.

View file

@ -17,14 +17,13 @@ from .shared_storage import (
set_all_update_flags,
)
import faiss # type: ignore
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
if not pm.is_installed(FAISS_PACKAGE):
pm.install(FAISS_PACKAGE)
import faiss # type: ignore
@final
@dataclass

View file

@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage):
return
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock:
# Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items():
if "chunks_list" not in doc_data:
doc_data["chunks_list"] = []
self._data.update(data)
await set_all_update_flags(self.namespace)

View file

@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in loaded_data.values()
if isinstance(first_level_dict, dict)
# Migrate legacy cache structure if needed
if self.namespace.endswith("_cache"):
loaded_data = await self._migrate_legacy_cache_structure(
loaded_data
)
else:
# For non-cache namespaces, use the original count method
data_count = len(loaded_data)
self._data.update(loaded_data)
data_count = len(loaded_data)
logger.info(
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# # For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in data_dict.values()
if isinstance(first_level_dict, dict)
)
else:
# For non-cache namespaces, use the original count method
data_count = len(data_dict)
# Calculate data count - all data is now flattened
data_count = len(data_dict)
logger.debug(
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
@ -92,22 +78,49 @@ class JsonKVStorage(BaseKVStorage):
Dictionary containing all stored data
"""
async with self._storage_lock:
return dict(self._data)
result = {}
for key, value in self._data.items():
if value:
# Create a copy to avoid modifying the original data
data = dict(value)
# Ensure time fields are present, provide default values for old data
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
result[key] = data
else:
result[key] = value
return result
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock:
return self._data.get(id)
result = self._data.get(id)
if result:
# Create a copy to avoid modifying the original data
result = dict(result)
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
# Ensure _id field contains the clean ID
result["_id"] = id
return result
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
async with self._storage_lock:
return [
(
{k: v for k, v in self._data[id].items()}
if self._data.get(id, None)
else None
)
for id in ids
]
results = []
for id in ids:
data = self._data.get(id, None)
if data:
# Create a copy to avoid modifying the original data
result = {k: v for k, v in data.items()}
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
# Ensure _id field contains the clean ID
result["_id"] = id
results.append(result)
else:
results.append(None)
return results
async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._storage_lock:
@ -121,8 +134,29 @@ class JsonKVStorage(BaseKVStorage):
"""
if not data:
return
import time
current_time = int(time.time()) # Get current Unix timestamp
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock:
# Add timestamps to data based on whether key exists
for k, v in data.items():
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
# Add timestamps based on whether key exists
if k in self._data: # Key exists, only update update_time
v["update_time"] = current_time
else: # New key, set both create_time and update_time
v["create_time"] = current_time
v["update_time"] = current_time
v["_id"] = k
self._data.update(data)
await set_all_update_flags(self.namespace)
@ -150,14 +184,14 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
"""Delete specific records from storage by cache mode
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of cache mode to be drop from storage
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache was dropped successfully
@ -167,9 +201,29 @@ class JsonKVStorage(BaseKVStorage):
return False
try:
await self.delete(modes)
async with self._storage_lock:
keys_to_delete = []
modes_set = set(modes) # Convert to set for efficient lookup
for key in list(self._data.keys()):
# Parse flattened cache key: mode:cache_type:hash
parts = key.split(":", 2)
if len(parts) == 3 and parts[0] in modes_set:
keys_to_delete.append(key)
# Batch delete
for key in keys_to_delete:
self._data.pop(key, None)
if keys_to_delete:
await set_all_update_flags(self.namespace)
logger.info(
f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
)
return True
except Exception:
except Exception as e:
logger.error(f"Error dropping cache by modes: {e}")
return False
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@ -245,9 +299,58 @@ class JsonKVStorage(BaseKVStorage):
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
"""Migrate legacy nested cache structure to flattened structure
Args:
data: Original data dictionary that may contain legacy structure
Returns:
Migrated data dictionary with flattened cache keys
"""
from lightrag.utils import generate_cache_key
# Early return if data is empty
if not data:
return data
# Check first entry to see if it's already in new format
first_key = next(iter(data.keys()))
if ":" in first_key and len(first_key.split(":")) == 3:
# Already in flattened format, return as-is
return data
migrated_data = {}
migration_count = 0
for key, value in data.items():
# Check if this is a legacy nested cache structure
if isinstance(value, dict) and all(
isinstance(v, dict) and "return" in v for v in value.values()
):
# This looks like a legacy cache mode with nested structure
mode = key
for cache_hash, cache_entry in value.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
migrated_data[flattened_key] = cache_entry
migration_count += 1
else:
# Keep non-cache data or already flattened cache data as-is
migrated_data[key] = value
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure"
)
# Persist migrated data immediately
write_json(migrated_data, self._file_name)
return migrated_data
async def finalize(self):
"""Finalize storage resources
Persist cache data to disk before exiting
"""
if self.namespace.endswith("cache"):
if self.namespace.endswith("_cache"):
await self.index_done_callback()
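To make the migration above concrete, here is a hedged illustration of the two layouts handled by _migrate_legacy_cache_structure (the mode, cache_type, and hash values are made up):

# Legacy llm_response_cache layout: entries nested under mode and prompt hash.
legacy = {
    "default": {
        "abc123": {"return": "...cached LLM answer...", "cache_type": "extract"},
    },
}

# Flattened layout after migration: one "{mode}:{cache_type}:{hash}" key per entry.
flattened = {
    "default:extract:abc123": {"return": "...cached LLM answer...", "cache_type": "extract"},
}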

View file

@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
import configparser
from pymilvus import MilvusClient # type: ignore
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
@final
@dataclass
class MilvusVectorDBStorage(BaseVectorStorage):
@staticmethod
def create_collection_if_not_exist(
client: MilvusClient, collection_name: str, **kwargs
):
if client.has_collection(collection_name):
return
client.create_collection(
collection_name, max_length=64, id_type="string", **kwargs
def _create_schema_for_namespace(self) -> CollectionSchema:
"""Create schema based on the current instance's namespace"""
# Get vector dimension from embedding_func
dimension = self.embedding_func.embedding_dim
# Base fields (common to all collections)
base_fields = [
FieldSchema(
name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="created_at", dtype=DataType.INT64),
]
# Determine specific fields based on namespace
if "entities" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="entity_name",
dtype=DataType.VARCHAR,
max_length=256,
nullable=True,
),
FieldSchema(
name="entity_type",
dtype=DataType.VARCHAR,
max_length=64,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG entities vector storage"
elif "relationships" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
),
FieldSchema(
name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
),
FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG relationships vector storage"
elif "chunks" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="full_doc_id",
dtype=DataType.VARCHAR,
max_length=64,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG chunks vector storage"
else:
# Default generic schema (backward compatibility)
specific_fields = [
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG generic vector storage"
# Merge all fields
all_fields = base_fields + specific_fields
return CollectionSchema(
fields=all_fields,
description=description,
enable_dynamic_field=True, # Support dynamic fields
)
def _get_index_params(self):
"""Get IndexParams in a version-compatible way"""
try:
# Try to use client's prepare_index_params method (most common)
if hasattr(self._client, "prepare_index_params"):
return self._client.prepare_index_params()
except Exception:
pass
try:
# Try to import IndexParams from different possible locations
from pymilvus.client.prepare import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus.client.types import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus import IndexParams
return IndexParams()
except ImportError:
pass
# If all else fails, return None to use fallback method
return None
def _create_vector_index_fallback(self):
"""Fallback method to create vector index using direct API"""
try:
self._client.create_index(
collection_name=self.namespace,
field_name="vector",
index_params={
"index_type": "HNSW",
"metric_type": "COSINE",
"params": {"M": 16, "efConstruction": 256},
},
)
logger.debug("Created vector index using fallback method")
except Exception as e:
logger.warning(f"Failed to create vector index using fallback method: {e}")
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
"""Fallback method to create scalar index using direct API"""
# Skip unsupported index types
if index_type == "SORTED":
logger.info(
f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
)
return
try:
self._client.create_index(
collection_name=self.namespace,
field_name=field_name,
index_params={"index_type": index_type},
)
logger.debug(f"Created {field_name} index using fallback method")
except Exception as e:
logger.info(
f"Could not create {field_name} index using fallback method: {e}"
)
def _create_indexes_after_collection(self):
"""Create indexes after collection is created"""
try:
# Try to get IndexParams in a version-compatible way
IndexParamsClass = self._get_index_params()
if IndexParamsClass is not None:
# Use IndexParams approach if available
try:
# Create vector index first (required for most operations)
vector_index = IndexParamsClass
vector_index.add_index(
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 256},
)
self._client.create_index(
collection_name=self.namespace, index_params=vector_index
)
logger.debug("Created vector index using IndexParams")
except Exception as e:
logger.debug(f"IndexParams method failed for vector index: {e}")
self._create_vector_index_fallback()
# Create scalar indexes based on namespace
if "entities" in self.namespace.lower():
# Create indexes for entity fields
try:
entity_name_index = self._get_index_params()
entity_name_index.add_index(
field_name="entity_name", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace,
index_params=entity_name_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for entity_name: {e}")
self._create_scalar_index_fallback("entity_name", "INVERTED")
try:
entity_type_index = self._get_index_params()
entity_type_index.add_index(
field_name="entity_type", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace,
index_params=entity_type_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for entity_type: {e}")
self._create_scalar_index_fallback("entity_type", "INVERTED")
elif "relationships" in self.namespace.lower():
# Create indexes for relationship fields
try:
src_id_index = self._get_index_params()
src_id_index.add_index(
field_name="src_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=src_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for src_id: {e}")
self._create_scalar_index_fallback("src_id", "INVERTED")
try:
tgt_id_index = self._get_index_params()
tgt_id_index.add_index(
field_name="tgt_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=tgt_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for tgt_id: {e}")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
# Create indexes for chunk fields
try:
doc_id_index = self._get_index_params()
doc_id_index.add_index(
field_name="full_doc_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=doc_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for full_doc_id: {e}")
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
# No common indexes needed
else:
# Fallback to direct API calls if IndexParams is not available
logger.info(
f"IndexParams not available, using fallback methods for {self.namespace}"
)
# Create vector index using fallback
self._create_vector_index_fallback()
# Create scalar indexes using fallback
if "entities" in self.namespace.lower():
self._create_scalar_index_fallback("entity_name", "INVERTED")
self._create_scalar_index_fallback("entity_type", "INVERTED")
elif "relationships" in self.namespace.lower():
self._create_scalar_index_fallback("src_id", "INVERTED")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
logger.info(f"Created indexes for collection: {self.namespace}")
except Exception as e:
logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
def _get_required_fields_for_namespace(self) -> dict:
"""Get required core field definitions for current namespace"""
# Base fields (common to all types)
base_fields = {
"id": {"type": "VarChar", "is_primary": True},
"vector": {"type": "FloatVector"},
"created_at": {"type": "Int64"},
}
# Add specific fields based on namespace
if "entities" in self.namespace.lower():
specific_fields = {
"entity_name": {"type": "VarChar"},
"entity_type": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
elif "relationships" in self.namespace.lower():
specific_fields = {
"src_id": {"type": "VarChar"},
"tgt_id": {"type": "VarChar"},
"weight": {"type": "Double"},
"file_path": {"type": "VarChar"},
}
elif "chunks" in self.namespace.lower():
specific_fields = {
"full_doc_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
else:
specific_fields = {
"file_path": {"type": "VarChar"},
}
return {**base_fields, **specific_fields}
def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
"""Check compatibility of a single field"""
field_name = existing_field.get("name", "unknown")
existing_type = existing_field.get("type")
expected_type = expected_config.get("type")
logger.debug(
f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
)
# Convert DataType enum values to string names if needed
original_existing_type = existing_type
if hasattr(existing_type, "name"):
existing_type = existing_type.name
logger.debug(
f"Converted enum to name: {original_existing_type} -> {existing_type}"
)
elif isinstance(existing_type, int):
# Map common Milvus internal type codes to type names for backward compatibility
type_mapping = {
21: "VarChar",
101: "FloatVector",
5: "Int64",
9: "Double",
}
mapped_type = type_mapping.get(existing_type, str(existing_type))
logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
existing_type = mapped_type
# Normalize type names for comparison
type_aliases = {
"VARCHAR": "VarChar",
"String": "VarChar",
"FLOAT_VECTOR": "FloatVector",
"INT64": "Int64",
"BigInt": "Int64",
"DOUBLE": "Double",
"Float": "Double",
}
original_existing = existing_type
original_expected = expected_type
existing_type = type_aliases.get(existing_type, existing_type)
expected_type = type_aliases.get(expected_type, expected_type)
if original_existing != existing_type or original_expected != expected_type:
logger.debug(
f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
)
# Basic type compatibility check
type_compatible = existing_type == expected_type
logger.debug(
f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
)
if not type_compatible:
logger.warning(
f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
)
return False
# Primary key check - be more flexible about primary key detection
if expected_config.get("is_primary"):
# Check multiple possible field names for primary key status
is_primary = (
existing_field.get("is_primary_key", False)
or existing_field.get("is_primary", False)
or existing_field.get("primary_key", False)
)
logger.debug(
f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
)
logger.debug(f"Raw field data for '{field_name}': {existing_field}")
# For ID field, be more lenient - if it's the ID field, assume it should be primary
if field_name == "id" and not is_primary:
logger.info(
f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
)
# Don't fail for ID field primary key mismatch
elif not is_primary:
logger.warning(
f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
)
return False
logger.debug(f"Field '{field_name}' is compatible")
return True
def _check_vector_dimension(self, collection_info: dict):
"""Check vector dimension compatibility"""
current_dimension = self.embedding_func.embedding_dim
# Find vector field dimension
for field in collection_info.get("fields", []):
if field.get("name") == "vector":
field_type = field.get("type")
if field_type in ["FloatVector", "FLOAT_VECTOR"]:
existing_dimension = field.get("params", {}).get("dim")
if existing_dimension != current_dimension:
raise ValueError(
f"Vector dimension mismatch for collection '{self.namespace}': "
f"existing={existing_dimension}, current={current_dimension}"
)
logger.debug(f"Vector dimension check passed: {current_dimension}")
return
# If no vector field found, this might be an old collection created with simple schema
logger.warning(
f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
)
logger.warning("Consider recreating the collection for optimal performance.")
return
def _check_schema_compatibility(self, collection_info: dict):
"""Check schema field compatibility"""
existing_fields = {
field["name"]: field for field in collection_info.get("fields", [])
}
# Check if this is an old collection created with simple schema
has_vector_field = any(
field.get("name") == "vector" for field in collection_info.get("fields", [])
)
if not has_vector_field:
logger.warning(
f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
)
logger.warning(
"This collection will work but may have suboptimal performance"
)
logger.warning("Consider recreating the collection for optimal performance")
return
# For collections with vector field, check basic compatibility
# Only check for critical incompatibilities, not missing optional fields
critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
incompatible_fields = []
for field_name, expected_config in critical_fields.items():
if field_name in existing_fields:
existing_field = existing_fields[field_name]
if not self._is_field_compatible(existing_field, expected_config):
incompatible_fields.append(
f"{field_name}: expected {expected_config['type']}, "
f"got {existing_field.get('type')}"
)
if incompatible_fields:
raise ValueError(
f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
)
# Get all expected fields for informational purposes
expected_fields = self._get_required_fields_for_namespace()
missing_fields = [
field for field in expected_fields if field not in existing_fields
]
if missing_fields:
logger.info(
f"Collection {self.namespace} missing optional fields: {missing_fields}"
)
logger.info(
"These fields would be available in a newly created collection for better performance"
)
logger.debug(f"Schema compatibility check passed for {self.namespace}")
def _validate_collection_compatibility(self):
"""Validate existing collection's dimension and schema compatibility"""
try:
collection_info = self._client.describe_collection(self.namespace)
# 1. Check vector dimension
self._check_vector_dimension(collection_info)
# 2. Check schema compatibility
self._check_schema_compatibility(collection_info)
logger.info(f"Collection {self.namespace} compatibility validation passed")
except Exception as e:
logger.error(
f"Collection compatibility validation failed for {self.namespace}: {e}"
)
raise
def _create_collection_if_not_exist(self):
"""Create collection if not exists and check existing collection compatibility"""
try:
# First, list all collections to see what actually exists
try:
all_collections = self._client.list_collections()
logger.debug(f"All collections in database: {all_collections}")
except Exception as list_error:
logger.warning(f"Could not list collections: {list_error}")
all_collections = []
# Check if our specific collection exists
collection_exists = self._client.has_collection(self.namespace)
logger.info(
f"Collection '{self.namespace}' exists check: {collection_exists}"
)
if collection_exists:
# Double-check by trying to describe the collection
try:
self._client.describe_collection(self.namespace)
logger.info(
f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
)
self._validate_collection_compatibility()
return
except Exception as describe_error:
logger.warning(
f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
)
logger.info(
"Treating as if collection doesn't exist and creating new one..."
)
# Fall through to creation logic
# Collection doesn't exist, create new collection
logger.info(f"Creating new collection: {self.namespace}")
schema = self._create_schema_for_namespace()
# Create collection with schema only first
self._client.create_collection(
collection_name=self.namespace, schema=schema
)
# Then create indexes
self._create_indexes_after_collection()
logger.info(f"Successfully created Milvus collection: {self.namespace}")
except Exception as e:
logger.error(
f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
)
# If there's any error, try to force create the collection
logger.info(f"Attempting to force create collection {self.namespace}...")
try:
# Try to drop the collection first if it exists in a bad state
try:
if self._client.has_collection(self.namespace):
logger.info(
f"Dropping potentially corrupted collection {self.namespace}"
)
self._client.drop_collection(self.namespace)
except Exception as drop_error:
logger.warning(
f"Could not drop collection {self.namespace}: {drop_error}"
)
# Create fresh collection
schema = self._create_schema_for_namespace()
self._client.create_collection(
collection_name=self.namespace, schema=schema
)
self._create_indexes_after_collection()
logger.info(f"Successfully force-created collection {self.namespace}")
except Exception as create_error:
logger.error(
f"Failed to force-create collection {self.namespace}: {create_error}"
)
raise
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
# Ensure created_at is in meta_fields
if "created_at" not in self.meta_fields:
self.meta_fields.add("created_at")
self._client = MilvusClient(
uri=os.environ.get(
"MILVUS_URI",
@ -68,14 +661,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
),
)
self._max_batch_size = self.global_config["embedding_batch_num"]
MilvusVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
dimension=self.embedding_func.embedding_dim,
)
# Create collection and check compatibility
self._create_collection_if_not_exist()
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields)
results = self._client.search(
collection_name=self.namespace,
data=embedding,
limit=top_k,
output_fields=list(self.meta_fields) + ["created_at"],
output_fields=output_fields,
search_params={
"metric_type": "COSINE",
"params": {"radius": self.cosine_better_than_threshold},
},
)
print(results)
return [
{
**dp["entity"],
"id": dp["id"],
"distance": dp["distance"],
# created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
"created_at": dp.get("created_at"),
}
for dp in results[0]
@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
The vector data if found, or None if not found
"""
try:
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Query Milvus for a specific ID
result = self._client.query(
collection_name=self.namespace,
filter=f'id == "{id}"',
output_fields=list(self.meta_fields) + ["id", "created_at"],
output_fields=output_fields,
)
if not result or len(result) == 0:
return None
# Ensure the result contains created_at field
if "created_at" not in result[0]:
result[0]["created_at"] = None
return result[0]
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return []
try:
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Prepare the ID filter expression
id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]'
@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
result = self._client.query(
collection_name=self.namespace,
filter=filter_expr,
output_fields=list(self.meta_fields) + ["id", "created_at"],
output_fields=output_fields,
)
# Ensure each result contains created_at field
for item in result:
if "created_at" not in item:
item["created_at"] = None
return result or []
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._client.drop_collection(self.namespace)
# Recreate the collection
MilvusVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
dimension=self.embedding_func.embedding_dim,
)
self._create_collection_if_not_exist()
logger.info(
f"Process {os.getpid()} drop Milvus collection {self.namespace}"

View file

@ -1,4 +1,5 @@
import os
import time
from dataclasses import dataclass, field
import numpy as np
import configparser
@ -14,7 +15,6 @@ from ..base import (
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
@ -35,6 +35,7 @@ config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
class ClientManager:
@ -96,11 +97,22 @@ class MongoKVStorage(BaseKVStorage):
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id})
# Unified handling for flattened keys
doc = await self._data.find_one({"_id": id})
if doc:
# Ensure time fields are present, provide default values for old data
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
return doc
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list()
docs = await cursor.to_list()
# Ensure time fields are present for all documents
for doc in docs:
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
return docs
async def filter_keys(self, keys: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@ -117,47 +129,53 @@ class MongoKVStorage(BaseKVStorage):
result = {}
async for doc in cursor:
doc_id = doc.pop("_id")
# Ensure time fields are present for all documents
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
result[doc_id] = doc
return result
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
update_tasks: list[Any] = []
for mode, items in data.items():
for k, v in items.items():
key = f"{mode}_{k}"
data[mode][k]["_id"] = f"{mode}_{k}"
update_tasks.append(
self._data.update_one(
{"_id": key}, {"$setOnInsert": v}, upsert=True
)
)
await asyncio.gather(*update_tasks)
else:
update_tasks = []
for k, v in data.items():
data[k]["_id"] = k
update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
)
await asyncio.gather(*update_tasks)
# Unified handling for all namespaces with flattened keys
# Use bulk_write for better performance
from pymongo import UpdateOne
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {}
v = await self._data.find_one({"_id": mode + "_" + id})
if v:
res[id] = v
logger.debug(f"llm_response_cache find one by:{id}")
return res
else:
return None
else:
return None
operations = []
current_time = int(time.time()) # Get current Unix timestamp
for k, v in data.items():
# For text_chunks namespace, ensure llm_cache_list field exists
if self.namespace.endswith("text_chunks"):
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
# Create a copy of v for $set operation, excluding create_time to avoid conflicts
v_for_set = v.copy()
v_for_set["_id"] = k # Use flattened key as _id
v_for_set["update_time"] = current_time # Always update update_time
# Remove create_time from $set to avoid conflict with $setOnInsert
v_for_set.pop("create_time", None)
operations.append(
UpdateOne(
{"_id": k},
{
"$set": v_for_set, # Update all fields except create_time
"$setOnInsert": {
"create_time": current_time
}, # Set create_time only on insert
},
upsert=True,
)
)
if operations:
await self._data.bulk_write(operations)
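# A minimal sketch (illustrative key and values) of the write pattern above:
# create_time is set only when the document is first inserted, while
# update_time is refreshed on every upsert.
import time
from pymongo import UpdateOne

now = int(time.time())
op = UpdateOne(
    {"_id": "default:extract:abc123"},  # flattened key: mode:cache_type:hash
    {
        "$set": {"return": "...", "update_time": now},  # always updated
        "$setOnInsert": {"create_time": now},  # set only on first insert
    },
    upsert=True,
)
# A collection.bulk_write([op]) call then applies it in a single round trip.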
async def index_done_callback(self) -> None:
# Mongo handles persistence automatically
@ -197,8 +215,8 @@ class MongoKVStorage(BaseKVStorage):
return False
try:
# Build regex pattern to match documents with the specified modes
pattern = f"^({'|'.join(modes)})_"
# Build regex pattern to match flattened key format: mode:cache_type:hash
pattern = f"^({'|'.join(modes)}):"
result = await self._data.delete_many({"_id": {"$regex": pattern}})
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
return True
@ -262,11 +280,14 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
update_tasks: list[Any] = []
for k, v in data.items():
# Ensure chunks_list field exists and is an array
if "chunks_list" not in v:
v["chunks_list"] = []
data[k]["_id"] = k
update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
@ -299,6 +320,7 @@ class MongoDocStatusStorage(DocStatusStorage):
updated_at=doc.get("updated_at"),
chunks_count=doc.get("chunks_count", -1),
file_path=doc.get("file_path", doc["_id"]),
chunks_list=doc.get("chunks_list", []),
)
for doc in result
}
@ -417,11 +439,21 @@ class MongoGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if there's a direct single-hop edge from source_node_id to target_node_id.
Check if there's a direct single-hop edge between source_node_id and target_node_id.
"""
# Direct check if the target_node appears among the edges array.
doc = await self.edge_collection.find_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id},
{
"$or": [
{
"source_node_id": source_node_id,
"target_node_id": target_node_id,
},
{
"source_node_id": target_node_id,
"target_node_id": source_node_id,
},
]
},
{"_id": 1},
)
return doc is not None
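# A minimal sketch (hypothetical node ids) of why the check is symmetric: both
# argument orders build a filter that matches the same stored edge document.
def _symmetric_edge_filter(a: str, b: str) -> dict:
    return {
        "$or": [
            {"source_node_id": a, "target_node_id": b},
            {"source_node_id": b, "target_node_id": a},
        ]
    }

# _symmetric_edge_filter("Alice", "Bob") and _symmetric_edge_filter("Bob", "Alice")
# match the same documents, so has_edge is order-independent.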
@ -651,7 +683,7 @@ class MongoGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
Upsert an edge between source_node_id and target_node_id with optional 'relation'.
If an edge with the same target exists, we remove it and re-insert with updated data.
"""
# Ensure source node exists
@ -663,8 +695,22 @@ class MongoGraphStorage(BaseGraphStorage):
GRAPH_FIELD_SEP
)
edge_data["source_node_id"] = source_node_id
edge_data["target_node_id"] = target_node_id
await self.edge_collection.update_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id},
{
"$or": [
{
"source_node_id": source_node_id,
"target_node_id": target_node_id,
},
{
"source_node_id": target_node_id,
"target_node_id": source_node_id,
},
]
},
update_doc,
upsert=True,
)
@ -678,7 +724,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None:
"""
1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
2) Remove inbound & outbound edges from any doc that references node_id.
"""
# Remove all edges
await self.edge_collection.delete_many(
@ -709,141 +755,369 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"])
return labels
def _construct_graph_node(
self, node_id, node_data: dict[str, str]
) -> KnowledgeGraphNode:
return KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in node_data.items()
if k
not in [
"_id",
"connected_edges",
"source_ids",
"edge_count",
]
},
)
def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]):
return KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relationship", ""),
source=edge["source_node_id"],
target=edge["target_node_id"],
properties={
k: v
for k, v in edge.items()
if k
not in [
"_id",
"source_node_id",
"target_node_id",
"relationship",
"source_ids",
]
},
)
async def get_knowledge_graph_all_by_degree(
self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
"""
It is possible that a node with one or more relationships is retrieved
while its neighbor is not; such a node may then appear disconnected in the UI.
"""
total_node_count = await self.collection.count_documents({})
result = KnowledgeGraph()
seen_edges = set()
result.is_truncated = total_node_count > max_nodes
if result.is_truncated:
# Get all node_ids ranked by degree if max_nodes exceeds total node count
pipeline = [
{"$project": {"source_node_id": 1, "_id": 0}},
{"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
{
"$unionWith": {
"coll": self._edge_collection_name,
"pipeline": [
{"$project": {"target_node_id": 1, "_id": 0}},
{
"$group": {
"_id": "$target_node_id",
"degree": {"$sum": 1},
}
},
],
}
},
{"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
{"$sort": {"degree": -1}},
{"$limit": max_nodes},
]
cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
node_ids = []
async for doc in cursor:
node_id = str(doc["_id"])
node_ids.append(node_id)
cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
async for doc in cursor:
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
# As node count reaches the limit, only need to fetch the edges that directly connect to these nodes
edge_cursor = self.edge_collection.find(
{
"$and": [
{"source_node_id": {"$in": node_ids}},
{"target_node_id": {"$in": node_ids}},
]
}
)
else:
# All nodes and edges are needed
cursor = self.collection.find({}, {"source_ids": 0})
async for doc in cursor:
node_id = str(doc["_id"])
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
edge_cursor = self.edge_collection.find({})
async for edge in edge_cursor:
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
seen_edges.add(edge_id)
result.edges.append(self._construct_graph_edge(edge_id, edge))
return result
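# A client-side sketch (sample edge documents) equivalent to the degree ranking
# done by the $unionWith/$group pipeline above: count every appearance of a node
# id as source or target, then keep the highest-degree nodes.
from collections import Counter

sample_edges = [
    {"source_node_id": "A", "target_node_id": "B"},
    {"source_node_id": "C", "target_node_id": "B"},
    {"source_node_id": "B", "target_node_id": "D"},
]
degree = Counter()
for e in sample_edges:
    degree[e["source_node_id"]] += 1
    degree[e["target_node_id"]] += 1
top_nodes = [node for node, _ in degree.most_common(2)]  # e.g. max_nodes = 2
assert top_nodes[0] == "B"  # B has degree 3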
async def _bidirectional_bfs_nodes(
self,
node_labels: list[str],
seen_nodes: set[str],
result: KnowledgeGraph,
depth: int = 0,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
if depth > max_depth or len(result.nodes) > max_nodes:
return result
cursor = self.collection.find({"_id": {"$in": node_labels}})
async for node in cursor:
node_id = node["_id"]
if node_id not in seen_nodes:
seen_nodes.add(node_id)
result.nodes.append(self._construct_graph_node(node_id, node))
if len(result.nodes) > max_nodes:
return result
# Collect neighbors
# Get both inbound and outbound one hop nodes
cursor = self.edge_collection.find(
{
"$or": [
{"source_node_id": {"$in": node_labels}},
{"target_node_id": {"$in": node_labels}},
]
}
)
neighbor_nodes = []
async for edge in cursor:
if edge["source_node_id"] not in seen_nodes:
neighbor_nodes.append(edge["source_node_id"])
if edge["target_node_id"] not in seen_nodes:
neighbor_nodes.append(edge["target_node_id"])
if neighbor_nodes:
result = await self._bidirectional_bfs_nodes(
neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
)
return result
async def get_knowledge_subgraph_bidirectional_bfs(
self,
node_label: str,
depth=0,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
seen_nodes = set()
seen_edges = set()
result = KnowledgeGraph()
result = await self._bidirectional_bfs_nodes(
[node_label], seen_nodes, result, depth, max_depth, max_nodes
)
# Get all edges from seen_nodes
all_node_ids = list(seen_nodes)
cursor = self.edge_collection.find(
{
"$and": [
{"source_node_id": {"$in": all_node_ids}},
{"target_node_id": {"$in": all_node_ids}},
]
}
)
async for edge in cursor:
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(self._construct_graph_edge(edge_id, edge))
seen_edges.add(edge_id)
return result
async def get_knowledge_subgraph_in_out_bound_bfs(
self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
seen_nodes = set()
seen_edges = set()
result = KnowledgeGraph()
project_doc = {
"source_ids": 0,
"created_at": 0,
"entity_type": 0,
"file_path": 0,
}
# Verify if starting node exists
start_node = await self.collection.find_one({"_id": node_label})
if not start_node:
logger.warning(f"Starting node with label {node_label} does not exist!")
return result
seen_nodes.add(node_label)
result.nodes.append(self._construct_graph_node(node_label, start_node))
if max_depth == 0:
return result
# In MongoDB, depth = 0 means one-hop
max_depth = max_depth - 1
pipeline = [
{"$match": {"_id": node_label}},
{"$project": project_doc},
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
{
"$unionWith": {
"coll": self._collection_name,
"pipeline": [
{"$match": {"_id": node_label}},
{"$project": project_doc},
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "source_node_id",
"connectToField": "target_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
}
},
],
}
},
]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
node_edges = []
# Two records for node_label are returned capturing outbound and inbound connected_edges
async for doc in cursor:
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
# Sort the connected edges by depth ascending and weight descending,
# then store each source_node_id and target_node_id in sequence to retrieve the neighbouring nodes
node_edges = sorted(
node_edges,
key=lambda x: (x["depth"], -x["weight"]),
)
# As order matters, use a separate list to store the node ids
# and keep only the first max_nodes of them
node_ids = []
for edge in node_edges:
if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes:
node_ids.append(edge["source_node_id"])
seen_nodes.add(edge["source_node_id"])
if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes:
node_ids.append(edge["target_node_id"])
seen_nodes.add(edge["target_node_id"])
# Nodes whose id equals node_label were already filtered out, so no extra existence check is needed in this step
cursor = self.collection.find({"_id": {"$in": node_ids}})
async for doc in cursor:
result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))
for edge in node_edges:
if (
edge["source_node_id"] not in seen_nodes
or edge["target_node_id"] not in seen_nodes
):
continue
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(self._construct_graph_edge(edge_id, edge))
seen_edges.add(edge_id)
return result
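# A sketch (sample values) of the ordering applied to connected_edges above:
# closer hops come first, and within the same depth heavier edges come first.
sample_edges = [
    {"source_node_id": "B", "target_node_id": "C", "depth": 1, "weight": 2.0},
    {"source_node_id": "A", "target_node_id": "B", "depth": 0, "weight": 1.0},
    {"source_node_id": "B", "target_node_id": "D", "depth": 0, "weight": 5.0},
]
ordered = sorted(sample_edges, key=lambda x: (x["depth"], -x["weight"]))
assert [e["target_node_id"] for e in ordered] == ["D", "B", "C"]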
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 5,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args:
node_label: Label of the nodes to start from
max_depth: Maximum depth of traversal (default: 5)
node_label: Label of the starting node; * means all nodes
max_depth: Maximum depth of the subgraph, defaults to 3
max_nodes: Maximum number of nodes to return, defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges of the subgraph
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
If a graph is like this and starting from B:
A -> B <- C <- F,  B -> E,  C -> D
Outbound BFS:
B -> E
Inbound BFS:
A -> B
C -> B
F -> C
Bidirectional BFS:
A -> B
B -> E
F -> C
C -> B
C -> D
"""
label = node_label
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
node_edges = []
start = time.perf_counter()
try:
# Optimize pipeline to avoid memory issues with large datasets
if label == "*":
# For getting all nodes, use a simpler pipeline to avoid memory issues
pipeline = [
{"$limit": max_nodes}, # Limit early to reduce memory usage
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
]
# Check if we need to set truncation flag
all_node_count = await self.collection.count_documents({})
result.is_truncated = all_node_count > max_nodes
else:
# Verify if starting node exists
start_node = await self.collection.find_one({"_id": label})
if not start_node:
logger.warning(f"Starting node with label {label} does not exist!")
return result
# For specific node queries, use the original pipeline but optimized
pipeline = [
{"$match": {"_id": label}},
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
{"$sort": {"edge_count": -1}},
{"$limit": max_nodes},
]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
nodes_processed = 0
async for doc in cursor:
# Add the start node
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"connected_edges",
"edge_count",
]
},
)
if node_label == "*":
result = await self.get_knowledge_graph_all_by_degree(
max_depth, max_nodes
)
elif GRAPH_BFS_MODE == "in_out_bound":
result = await self.get_knowledge_subgraph_in_out_bound_bfs(
node_label, max_depth, max_nodes
)
else:
result = await self.get_knowledge_subgraph_bidirectional_bfs(
node_label, 0, max_depth, max_nodes
)
seen_nodes.add(node_id)
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
nodes_processed += 1
# Additional safety check to prevent memory issues
if nodes_processed >= max_nodes:
result.is_truncated = True
break
for edge in node_edges:
if (
edge["source_node_id"] not in seen_nodes
or edge["target_node_id"] not in seen_nodes
):
continue
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relationship", ""),
source=edge["source_node_id"],
target=edge["target_node_id"],
properties={
k: v
for k, v in edge.items()
if k
not in [
"_id",
"source_node_id",
"target_node_id",
"relationship",
]
},
)
)
seen_edges.add(edge_id)
duration = time.perf_counter() - start
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
)
except PyMongoError as e:
@ -856,13 +1130,8 @@ class MongoGraphStorage(BaseGraphStorage):
try:
simple_cursor = self.collection.find({}).limit(max_nodes)
async for doc in simple_cursor:
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={k: v for k, v in doc.items() if k != "_id"},
)
self._construct_graph_node(str(doc["_id"]), doc)
)
result.is_truncated = True
logger.info(
@ -1023,13 +1292,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
# Add current time as Unix timestamp
import time
current_time = int(time.time())
list_data = [
@ -1114,7 +1381,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
Args:
ids: List of vector IDs to be deleted
"""
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
if not ids:
return

View file

@ -106,7 +106,9 @@ class NetworkXStorage(BaseGraphStorage):
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
graph = await self._get_graph()
return graph.degree(src_id) + graph.degree(tgt_id)
src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0
tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0
return src_degree + tgt_degree
async def get_edge(
self, source_node_id: str, target_node_id: str

View file

@ -136,6 +136,52 @@ class PostgreSQLDB:
except Exception as e:
logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}")
async def _migrate_llm_cache_add_cache_type(self):
"""Add cache_type column to LIGHTRAG_LLM_CACHE table if it doesn't exist"""
try:
# Check if cache_type column exists
check_column_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'lightrag_llm_cache'
AND column_name = 'cache_type'
"""
column_info = await self.query(check_column_sql)
if not column_info:
logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table")
add_column_sql = """
ALTER TABLE LIGHTRAG_LLM_CACHE
ADD COLUMN cache_type VARCHAR(32) NULL
"""
await self.execute(add_column_sql)
logger.info(
"Successfully added cache_type column to LIGHTRAG_LLM_CACHE table"
)
# Migrate existing data: extract cache_type from flattened keys
logger.info(
"Migrating existing LLM cache data to populate cache_type field"
)
update_sql = """
UPDATE LIGHTRAG_LLM_CACHE
SET cache_type = CASE
WHEN id LIKE '%:%:%' THEN split_part(id, ':', 2)
ELSE 'extract'
END
WHERE cache_type IS NULL
"""
await self.execute(update_sql)
logger.info("Successfully migrated existing LLM cache data")
else:
logger.info(
"cache_type column already exists in LIGHTRAG_LLM_CACHE table"
)
except Exception as e:
logger.warning(
f"Failed to add cache_type column to LIGHTRAG_LLM_CACHE: {e}"
)
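# A small sketch (hypothetical ids) of what the split_part() expression above
# extracts from a flattened cache key of the form mode:cache_type:hash.
def cache_type_from_id(cache_id: str) -> str:
    parts = cache_id.split(":")
    return parts[1] if len(parts) >= 3 else "extract"

assert cache_type_from_id("default:extract:0f3a9b") == "extract"
assert cache_type_from_id("legacy_hash_only") == "extract"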
async def _migrate_timestamp_columns(self):
"""Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time"""
# Tables and columns that need migration
@ -189,6 +235,239 @@ class PostgreSQLDB:
# Log error but don't interrupt the process
logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}")
async def _migrate_doc_chunks_to_vdb_chunks(self):
"""
Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met.
This migration is intended for users who are upgrading and have an older table structure
where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column.
"""
try:
# 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty
vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS"
vdb_chunks_count_result = await self.query(vdb_chunks_count_sql)
if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0:
logger.info(
"Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data."
)
return
# 2. Check if `content_vector` column exists in the old table
check_column_sql = """
SELECT 1 FROM information_schema.columns
WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector'
"""
column_exists = await self.query(check_column_sql)
if not column_exists:
logger.info(
"Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS"
)
return
# 3. Check if the old table LIGHTRAG_DOC_CHUNKS has data
doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS"
doc_chunks_count_result = await self.query(doc_chunks_count_sql)
if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0:
logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.")
return
# 4. Perform the migration
logger.info(
"Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..."
)
migration_sql = """
INSERT INTO LIGHTRAG_VDB_CHUNKS (
id, workspace, full_doc_id, chunk_order_index, tokens, content,
content_vector, file_path, create_time, update_time
)
SELECT
id, workspace, full_doc_id, chunk_order_index, tokens, content,
content_vector, file_path, create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS
ON CONFLICT (workspace, id) DO NOTHING;
"""
await self.execute(migration_sql)
logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.")
except Exception as e:
logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
# Do not re-raise, to allow the application to start
async def _check_llm_cache_needs_migration(self):
"""Check if LLM cache data needs migration by examining the first record"""
try:
# Only query the first record to determine format
check_sql = """
SELECT id FROM LIGHTRAG_LLM_CACHE
ORDER BY create_time ASC
LIMIT 1
"""
result = await self.query(check_sql)
if result and result.get("id"):
# If id doesn't contain colon, it's old format
return ":" not in result["id"]
return False # No data or already new format
except Exception as e:
logger.warning(f"Failed to check LLM cache migration status: {e}")
return False
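# A minimal sketch (hypothetical ids) of the format check above: legacy cache
# rows are keyed by a bare hash with no colon, while migrated keys follow
# mode:cache_type:hash.
def needs_flattening(cache_id: str) -> bool:
    return ":" not in cache_id

assert needs_flattening("a1b2c3d4")                       # old format
assert not needs_flattening("default:extract:a1b2c3d4")   # already migrated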
async def _migrate_llm_cache_to_flattened_keys(self):
"""Migrate LLM cache to flattened key format, recalculating hash values"""
try:
# Get all old format data
old_data_sql = """
SELECT id, mode, original_prompt, return_value, chunk_id,
create_time, update_time
FROM LIGHTRAG_LLM_CACHE
WHERE id NOT LIKE '%:%'
"""
old_records = await self.query(old_data_sql, multirows=True)
if not old_records:
logger.info("No old format LLM cache data found, skipping migration")
return
logger.info(
f"Found {len(old_records)} old format cache records, starting migration..."
)
# Import hash calculation function
from ..utils import compute_args_hash
migrated_count = 0
# Migrate data in batches
for record in old_records:
try:
# Recalculate hash using correct method
new_hash = compute_args_hash(
record["mode"], record["original_prompt"]
)
# Determine cache_type based on mode
cache_type = "extract" if record["mode"] == "default" else "unknown"
# Generate new flattened key
new_key = f"{record['mode']}:{cache_type}:{new_hash}"
# Insert new format data with cache_type field
insert_sql = """
INSERT INTO LIGHTRAG_LLM_CACHE
(workspace, id, mode, original_prompt, return_value, chunk_id, cache_type, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (workspace, mode, id) DO NOTHING
"""
await self.execute(
insert_sql,
{
"workspace": self.workspace,
"id": new_key,
"mode": record["mode"],
"original_prompt": record["original_prompt"],
"return_value": record["return_value"],
"chunk_id": record["chunk_id"],
"cache_type": cache_type, # Add cache_type field
"create_time": record["create_time"],
"update_time": record["update_time"],
},
)
# Delete old data
delete_sql = """
DELETE FROM LIGHTRAG_LLM_CACHE
WHERE workspace=$1 AND mode=$2 AND id=$3
"""
await self.execute(
delete_sql,
{
"workspace": self.workspace,
"mode": record["mode"],
"id": record["id"], # Old id
},
)
migrated_count += 1
except Exception as e:
logger.warning(
f"Failed to migrate cache record {record['id']}: {e}"
)
continue
logger.info(
f"Successfully migrated {migrated_count} cache records to flattened format"
)
except Exception as e:
logger.error(f"LLM cache migration failed: {e}")
# Don't raise exception, allow system to continue startup
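# A sketch (illustrative values only) of the per-record key rewrite above: the
# new primary key embeds mode and cache_type alongside the recomputed hash.
record = {"mode": "default", "original_prompt": "What is LightRAG?"}
new_hash = "5e884898da28"  # stands in for compute_args_hash(mode, original_prompt)
cache_type = "extract" if record["mode"] == "default" else "unknown"
new_key = f"{record['mode']}:{cache_type}:{new_hash}"
assert new_key == "default:extract:5e884898da28"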
async def _migrate_doc_status_add_chunks_list(self):
"""Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist"""
try:
# Check if chunks_list column exists
check_column_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'lightrag_doc_status'
AND column_name = 'chunks_list'
"""
column_info = await self.query(check_column_sql)
if not column_info:
logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table")
add_column_sql = """
ALTER TABLE LIGHTRAG_DOC_STATUS
ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb
"""
await self.execute(add_column_sql)
logger.info(
"Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table"
)
else:
logger.info(
"chunks_list column already exists in LIGHTRAG_DOC_STATUS table"
)
except Exception as e:
logger.warning(
f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}"
)
async def _migrate_text_chunks_add_llm_cache_list(self):
"""Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist"""
try:
# Check if llm_cache_list column exists
check_column_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'lightrag_doc_chunks'
AND column_name = 'llm_cache_list'
"""
column_info = await self.query(check_column_sql)
if not column_info:
logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table")
add_column_sql = """
ALTER TABLE LIGHTRAG_DOC_CHUNKS
ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb
"""
await self.execute(add_column_sql)
logger.info(
"Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table"
)
else:
logger.info(
"llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table"
)
except Exception as e:
logger.warning(
f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}"
)
async def check_tables(self):
# First create all tables
for k, v in TABLES.items():
@ -240,6 +519,44 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}")
# Don't throw an exception, allow the initialization process to continue
# Migrate LLM cache table to add cache_type field if needed
try:
await self._migrate_llm_cache_add_cache_type()
except Exception as e:
logger.error(
f"PostgreSQL, Failed to migrate LLM cache cache_type field: {e}"
)
# Don't throw an exception, allow the initialization process to continue
# Finally, attempt to migrate old doc chunks data if needed
try:
await self._migrate_doc_chunks_to_vdb_chunks()
except Exception as e:
logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
# Check and migrate LLM cache to flattened keys if needed
try:
if await self._check_llm_cache_needs_migration():
await self._migrate_llm_cache_to_flattened_keys()
except Exception as e:
logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
# Migrate doc status to add chunks_list field if needed
try:
await self._migrate_doc_status_add_chunks_list()
except Exception as e:
logger.error(
f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}"
)
# Migrate text chunks to add llm_cache_list field if needed
try:
await self._migrate_text_chunks_add_llm_cache_list()
except Exception as e:
logger.error(
f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}"
)
async def query(
self,
sql: str,
@ -423,74 +740,139 @@ class PGKVStorage(BaseKVStorage):
try:
results = await self.db.query(sql, params, multirows=True)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
result_dict = {}
processed_results = {}
for row in results:
mode = row["mode"]
if mode not in result_dict:
result_dict[mode] = {}
result_dict[mode][row["id"]] = row
return result_dict
else:
return {row["id"]: row for row in results}
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
# Map field names and add cache_type for compatibility
processed_row = {
**row,
"return": row.get("return_value", ""),
"cache_type": row.get("original_prompt", "unknow"),
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
processed_results[row["id"]] = processed_row
return processed_results
# For text_chunks namespace, parse llm_cache_list JSON string back to list
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
processed_results = {}
for row in results:
llm_cache_list = row.get("llm_cache_list", [])
if isinstance(llm_cache_list, str):
try:
llm_cache_list = json.loads(llm_cache_list)
except json.JSONDecodeError:
llm_cache_list = []
row["llm_cache_list"] = llm_cache_list
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
row["create_time"] = create_time
row["update_time"] = (
create_time if update_time == 0 else update_time
)
processed_results[row["id"]] = row
return processed_results
# For other namespaces, return as-is
return {row["id"]: row for row in results}
except Exception as e:
logger.error(f"Error retrieving all data from {self.namespace}: {e}")
return {}
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data by id."""
"""Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res if res else None
else:
response = await self.db.query(sql, params)
return response if response else None
response = await self.db.query(sql, params)
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
"""Specifically for llm_response_cache."""
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, "mode": mode, "id": id}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list
llm_cache_list = response.get("llm_cache_list", [])
if isinstance(llm_cache_list, str):
try:
llm_cache_list = json.loads(llm_cache_list)
except json.JSONDecodeError:
llm_cache_list = []
response["llm_cache_list"] = llm_cache_list
create_time = response.get("create_time", 0)
update_time = response.get("update_time", 0)
response["create_time"] = create_time
response["update_time"] = create_time if update_time == 0 else update_time
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if response and is_namespace(
self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
):
create_time = response.get("create_time", 0)
update_time = response.get("update_time", 0)
# Map field names and add cache_type for compatibility
response = {
**response,
"return": response.get("return_value", ""),
"cache_type": response.get("cache_type"),
"original_prompt": response.get("original_prompt", ""),
"chunk_id": response.get("chunk_id"),
"mode": response.get("mode", "default"),
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
return response if response else None
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id"""
"""Get data by ids"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
modes = set()
dict_res: dict[str, dict] = {}
for row in array_res:
modes.add(row["mode"])
for mode in modes:
if mode not in dict_res:
dict_res[mode] = {}
for row in array_res:
dict_res[row["mode"]][row["id"]] = row
return [{k: v} for k, v in dict_res.items()]
else:
return await self.db.query(sql, params, multirows=True)
results = await self.db.query(sql, params, multirows=True)
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list for each result
for result in results:
llm_cache_list = result.get("llm_cache_list", [])
if isinstance(llm_cache_list, str):
try:
llm_cache_list = json.loads(llm_cache_list)
except json.JSONDecodeError:
llm_cache_list = []
result["llm_cache_list"] = llm_cache_list
create_time = result.get("create_time", 0)
update_time = result.get("update_time", 0)
result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if results and is_namespace(
self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
):
processed_results = []
for row in results:
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
# Map field names and add cache_type for compatibility
processed_row = {
**row,
"return": row.get("return_value", ""),
"cache_type": row.get("cache_type"),
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
processed_results.append(processed_row)
return processed_results
return results if results else []
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
@ -520,7 +902,22 @@ class PGKVStorage(BaseKVStorage):
return
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass
current_time = datetime.datetime.now(timezone.utc)
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_text_chunk"]
_data = {
"workspace": self.db.workspace,
"id": k,
"tokens": v["tokens"],
"chunk_order_index": v["chunk_order_index"],
"full_doc_id": v["full_doc_id"],
"content": v["content"],
"file_path": v["file_path"],
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
"create_time": current_time,
"update_time": current_time,
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
@ -531,19 +928,21 @@ class PGKVStorage(BaseKVStorage):
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k,
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"mode": mode,
"chunk_id": v.get("chunk_id"),
}
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k, # Use flattened key as id
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"mode": v.get("mode", "default"), # Get mode from data
"chunk_id": v.get("chunk_id"),
"cache_type": v.get(
"cache_type", "extract"
), # Get cache_type from data
}
await self.db.execute(upsert_sql, _data)
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self) -> None:
# PG handles persistence automatically
@ -949,8 +1348,8 @@ class PGDocStatusStorage(DocStatusStorage):
else:
exist_keys = []
new_keys = set([s for s in keys if s not in exist_keys])
print(f"keys: {keys}")
print(f"new_keys: {new_keys}")
# print(f"keys: {keys}")
# print(f"new_keys: {new_keys}")
return new_keys
except Exception as e:
logger.error(
@ -965,6 +1364,14 @@ class PGDocStatusStorage(DocStatusStorage):
if result is None or result == []:
return None
else:
# Parse chunks_list JSON string back to list
chunks_list = result[0].get("chunks_list", [])
if isinstance(chunks_list, str):
try:
chunks_list = json.loads(chunks_list)
except json.JSONDecodeError:
chunks_list = []
return dict(
content=result[0]["content"],
content_length=result[0]["content_length"],
@ -974,6 +1381,7 @@ class PGDocStatusStorage(DocStatusStorage):
created_at=result[0]["created_at"],
updated_at=result[0]["updated_at"],
file_path=result[0]["file_path"],
chunks_list=chunks_list,
)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -988,19 +1396,32 @@ class PGDocStatusStorage(DocStatusStorage):
if not results:
return []
return [
{
"content": row["content"],
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"file_path": row["file_path"],
}
for row in results
]
processed_results = []
for row in results:
# Parse chunks_list JSON string back to list
chunks_list = row.get("chunks_list", [])
if isinstance(chunks_list, str):
try:
chunks_list = json.loads(chunks_list)
except json.JSONDecodeError:
chunks_list = []
processed_results.append(
{
"content": row["content"],
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"file_path": row["file_path"],
"chunks_list": chunks_list,
}
)
return processed_results
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
@ -1021,8 +1442,18 @@ class PGDocStatusStorage(DocStatusStorage):
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status.value}
result = await self.db.query(sql, params, True)
docs_by_status = {
element["id"]: DocProcessingStatus(
docs_by_status = {}
for element in result:
# Parse chunks_list JSON string back to list
chunks_list = element.get("chunks_list", [])
if isinstance(chunks_list, str):
try:
chunks_list = json.loads(chunks_list)
except json.JSONDecodeError:
chunks_list = []
docs_by_status[element["id"]] = DocProcessingStatus(
content=element["content"],
content_summary=element["content_summary"],
content_length=element["content_length"],
@ -1031,9 +1462,9 @@ class PGDocStatusStorage(DocStatusStorage):
updated_at=element["updated_at"],
chunks_count=element["chunks_count"],
file_path=element["file_path"],
chunks_list=chunks_list,
)
for element in result
}
return docs_by_status
async def index_done_callback(self) -> None:
@ -1097,10 +1528,10 @@ class PGDocStatusStorage(DocStatusStorage):
logger.warning(f"Unable to parse datetime string: {dt_str}")
return None
# Modified SQL to include created_at and updated_at in both INSERT and UPDATE operations
# Both fields are updated from the input data in both INSERT and UPDATE cases
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at)
values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
# Modified SQL to include created_at, updated_at, and chunks_list in both INSERT and UPDATE operations
# All fields are updated from the input data in both INSERT and UPDATE cases
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,chunks_list,created_at,updated_at)
values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
on conflict(id,workspace) do update set
content = EXCLUDED.content,
content_summary = EXCLUDED.content_summary,
@ -1108,6 +1539,7 @@ class PGDocStatusStorage(DocStatusStorage):
chunks_count = EXCLUDED.chunks_count,
status = EXCLUDED.status,
file_path = EXCLUDED.file_path,
chunks_list = EXCLUDED.chunks_list,
created_at = EXCLUDED.created_at,
updated_at = EXCLUDED.updated_at"""
for k, v in data.items():
@ -1115,7 +1547,7 @@ class PGDocStatusStorage(DocStatusStorage):
created_at = parse_datetime(v.get("created_at"))
updated_at = parse_datetime(v.get("updated_at"))
# chunks_count is optional
# chunks_count and chunks_list are optional
await self.db.execute(
sql,
{
@ -1127,6 +1559,7 @@ class PGDocStatusStorage(DocStatusStorage):
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
"status": v["status"],
"file_path": v["file_path"],
"chunks_list": json.dumps(v.get("chunks_list", [])),
"created_at": created_at, # Use the converted datetime object
"updated_at": updated_at, # Use the converted datetime object
},
@ -2409,7 +2842,7 @@ class PGGraphStorage(BaseGraphStorage):
NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
@ -2438,6 +2871,21 @@ TABLES = {
},
"LIGHTRAG_DOC_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
id VARCHAR(255),
workspace VARCHAR(255),
full_doc_id VARCHAR(256),
chunk_order_index INTEGER,
tokens INTEGER,
content TEXT,
file_path VARCHAR(256),
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
create_time TIMESTAMP(0) WITH TIME ZONE,
update_time TIMESTAMP(0) WITH TIME ZONE,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
)"""
},
"LIGHTRAG_VDB_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
id VARCHAR(255),
workspace VARCHAR(255),
full_doc_id VARCHAR(256),
@ -2448,7 +2896,7 @@ TABLES = {
file_path VARCHAR(256),
create_time TIMESTAMP(0) WITH TIME ZONE,
update_time TIMESTAMP(0) WITH TIME ZONE,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
)"""
},
"LIGHTRAG_VDB_ENTITY": {
@ -2503,6 +2951,7 @@ TABLES = {
chunks_count int4 NULL,
status varchar(64) NULL,
file_path TEXT NULL,
chunks_list JSONB NULL DEFAULT '[]'::jsonb,
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
@ -2517,24 +2966,30 @@ SQL_TEMPLATES = {
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
""",
"get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
create_time, update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
""",
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids})
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
create_time, update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
@ -2542,16 +2997,31 @@ SQL_TEMPLATES = {
ON CONFLICT (workspace,id) DO UPDATE
SET content = $2, update_time = CURRENT_TIMESTAMP
""",
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id)
VALUES ($1, $2, $3, $4, $5, $6)
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (workspace,mode,id) DO UPDATE
SET original_prompt = EXCLUDED.original_prompt,
return_value=EXCLUDED.return_value,
mode=EXCLUDED.mode,
chunk_id=EXCLUDED.chunk_id,
cache_type=EXCLUDED.cache_type,
update_time = CURRENT_TIMESTAMP
""",
"upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (workspace,id) DO UPDATE
SET tokens=EXCLUDED.tokens,
chunk_order_index=EXCLUDED.chunk_order_index,
full_doc_id=EXCLUDED.full_doc_id,
content = EXCLUDED.content,
file_path=EXCLUDED.file_path,
llm_cache_list=EXCLUDED.llm_cache_list,
update_time = EXCLUDED.update_time
""",
# SQL for VectorStorage
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector, file_path,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@ -2564,7 +3034,6 @@ SQL_TEMPLATES = {
file_path=EXCLUDED.file_path,
update_time = EXCLUDED.update_time
""",
# SQL for VectorStorage
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
@ -2591,7 +3060,7 @@ SQL_TEMPLATES = {
"relationships": """
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_DOC_CHUNKS
FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
)
SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
@ -2608,7 +3077,7 @@ SQL_TEMPLATES = {
"entities": """
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_DOC_CHUNKS
FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
)
SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
@ -2625,13 +3094,13 @@ SQL_TEMPLATES = {
"chunks": """
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_DOC_CHUNKS
FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
)
SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
(
SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
FROM LIGHTRAG_DOC_CHUNKS
FROM LIGHTRAG_VDB_CHUNKS
WHERE workspace=$1
AND id IN (SELECT chunk_id FROM relevant_chunks)
) as chunk_distances

View file

@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

View file

@ -1,9 +1,10 @@
import os
from typing import Any, final
from typing import Any, final, Union
from dataclasses import dataclass
import pipmaster as pm
import configparser
from contextlib import asynccontextmanager
import threading
if not pm.is_installed("redis"):
pm.install("redis")
@ -13,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
from redis.exceptions import RedisError, ConnectionError # type: ignore
from lightrag.utils import logger
from lightrag.base import BaseKVStorage
from lightrag.base import (
BaseKVStorage,
DocStatusStorage,
DocStatus,
DocProcessingStatus,
)
import json
@ -26,6 +32,41 @@ SOCKET_TIMEOUT = 5.0
SOCKET_CONNECT_TIMEOUT = 3.0
class RedisConnectionManager:
"""Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
_pools = {}
_lock = threading.Lock()
@classmethod
def get_pool(cls, redis_url: str) -> ConnectionPool:
"""Get or create a connection pool for the given Redis URL"""
if redis_url not in cls._pools:
with cls._lock:
if redis_url not in cls._pools:
cls._pools[redis_url] = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
logger.info(f"Created shared Redis connection pool for {redis_url}")
return cls._pools[redis_url]
@classmethod
def close_all_pools(cls):
"""Close all connection pools (for cleanup)"""
with cls._lock:
for url, pool in cls._pools.items():
try:
pool.disconnect()
logger.info(f"Closed Redis connection pool for {url}")
except Exception as e:
logger.error(f"Error closing Redis pool for {url}: {e}")
cls._pools.clear()
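# Usage sketch (hypothetical URI): repeated calls with the same URI return the
# same pool object, so every storage instance shares one connection pool.
pool_a = RedisConnectionManager.get_pool("redis://localhost:6379")
pool_b = RedisConnectionManager.get_pool("redis://localhost:6379")
assert pool_a is pool_b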
@final
@dataclass
class RedisKVStorage(BaseKVStorage):
@ -33,19 +74,28 @@ class RedisKVStorage(BaseKVStorage):
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Create a connection pool with limits
self._pool = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
# Use shared connection pool
self._pool = RedisConnectionManager.get_pool(redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
)
async def initialize(self):
"""Initialize Redis connection and migrate legacy cache structure if needed"""
# Test connection
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(f"Connected to Redis for namespace {self.namespace}")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
raise
# Migrate legacy cache structure if this is a cache namespace
if self.namespace.endswith("_cache"):
await self._migrate_legacy_cache_structure()
@asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
@ -82,7 +132,13 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
if data:
result = json.loads(data)
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
return result
return None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
return None
@ -94,35 +150,113 @@ class RedisKVStorage(BaseKVStorage):
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()
return [json.loads(result) if result else None for result in results]
processed_results = []
for result in results:
if result:
data = json.loads(result)
# Ensure time fields are present for all documents
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
processed_results.append(data)
else:
processed_results.append(None)
return processed_results
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in batch get: {e}")
return [None] * len(ids)
async def get_all(self) -> dict[str, Any]:
"""Get all data from storage
Returns:
Dictionary containing all stored data
"""
async with self._get_redis_connection() as redis:
try:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
if not keys:
return {}
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Build result dictionary
result = {}
for key, value in zip(keys, values):
if value:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
try:
data = json.loads(value)
# Ensure time fields are present for all documents
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
result[key_id] = data
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for key {key}: {e}")
continue
return result
except Exception as e:
logger.error(f"Error getting all data from Redis: {e}")
return {}
async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for key in keys:
keys_list = list(keys) # Convert set to list for indexing
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()
existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data:
return
logger.info(f"Inserting {len(data)} items to {self.namespace}")
import time
current_time = int(time.time()) # Get current Unix timestamp
async with self._get_redis_connection() as redis:
try:
# Check which keys already exist to determine create vs update
pipe = redis.pipeline()
for k in data.keys():
pipe.exists(f"{self.namespace}:{k}")
exists_results = await pipe.execute()
# Add timestamps to data
for i, (k, v) in enumerate(data.items()):
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
# Add timestamps based on whether key exists
if exists_results[i]: # Key exists, only update update_time
v["update_time"] = current_time
else: # New key, set both create_time and update_time
v["create_time"] = current_time
v["update_time"] = current_time
v["_id"] = k
# Store the data
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
for k in data:
data[k]["_id"] = k
except (TypeError, ValueError) as e:
logger.error(f"JSON encode error during upsert: {e}")
raise
@ -148,13 +282,13 @@ class RedisKVStorage(BaseKVStorage):
)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
"""Delete specific records from storage by cache mode
Important notes for Redis storage:
1. This will immediately delete the specified cache modes from Redis
Args:
modes (list[str]): List of cache mode to be drop from storage
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache was dropped successfully
@ -164,9 +298,47 @@ class RedisKVStorage(BaseKVStorage):
return False
try:
await self.delete(modes)
async with self._get_redis_connection() as redis:
keys_to_delete = []
# Find matching keys for each mode using SCAN
for mode in modes:
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
pattern = f"{self.namespace}:{mode}:*"
cursor = 0
mode_keys = []
while True:
cursor, keys = await redis.scan(
cursor, match=pattern, count=1000
)
if keys:
mode_keys.extend(keys)
if cursor == 0:
break
keys_to_delete.extend(mode_keys)
logger.info(
f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
)
if keys_to_delete:
# Batch delete
pipe = redis.pipeline()
for key in keys_to_delete:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Dropped {deleted_count} cache entries for modes: {modes}"
)
else:
logger.warning(f"No cache entries found for modes: {modes}")
return True
except Exception:
except Exception as e:
logger.error(f"Error dropping cache by modes in Redis: {e}")
return False
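# A minimal usage sketch (hypothetical caller; `kv` is an initialized RedisKVStorage
# bound to the LLM response cache namespace). Because entries are stored flat as
# {namespace}:{mode}:{cache_type}:{hash}, dropping the "local" and "global" modes
# removes query caches while leaving the "default" extraction cache untouched:
#
#     dropped = await kv.drop_cache_by_modes(["local", "global"])
#     if not dropped:
#         logger.warning("Query cache cleanup failed")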
async def drop(self) -> dict[str, str]:
@ -177,24 +349,370 @@ class RedisKVStorage(BaseKVStorage):
"""
async with self._get_redis_connection() as redis:
try:
keys = await redis.keys(f"{self.namespace}:*")
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0
if keys:
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
# Delete keys in batches
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}
else:
logger.info(f"No keys found to drop in {self.namespace}")
return {"status": "success", "message": "no keys to drop"}
if cursor == 0:
break
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}
except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self):
"""Migrate legacy nested cache structure to flattened structure for Redis
Redis already stores data in a flattened way, but we need to check for
legacy keys that might contain nested JSON structures and migrate them.
Early exit if any flattened key is found (indicating migration already done).
"""
from lightrag.utils import generate_cache_key
async with self._get_redis_connection() as redis:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
if not keys:
return
# Check if we have any flattened keys already - if so, skip migration
has_flattened_keys = False
keys_to_migrate = []
for key in keys:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
# Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
if ":" in key_id and len(key_id.split(":")) == 3:
has_flattened_keys = True
break # Early exit - migration already done
# Get the data to check if it's a legacy nested structure
data = await redis.get(key)
if data:
try:
parsed_data = json.loads(data)
# Check if this looks like a legacy cache mode with nested structure
if isinstance(parsed_data, dict) and all(
isinstance(v, dict) and "return" in v
for v in parsed_data.values()
):
keys_to_migrate.append((key, key_id, parsed_data))
except json.JSONDecodeError:
continue
# If we found any flattened keys, assume migration is already done
if has_flattened_keys:
logger.debug(
f"Found flattened cache keys in {self.namespace}, skipping migration"
)
return
if not keys_to_migrate:
return
# Perform migration
pipe = redis.pipeline()
migration_count = 0
for old_key, mode, nested_data in keys_to_migrate:
# Delete the old key
pipe.delete(old_key)
# Create new flattened keys
for cache_hash, cache_entry in nested_data.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
full_key = f"{self.namespace}:{flattened_key}"
pipe.set(full_key, json.dumps(cache_entry))
migration_count += 1
await pipe.execute()
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
)
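# Migration sketch with illustrative data: a legacy entry keyed only by mode holds a
# nested {hash: cache_entry} dict; after migration each cache_entry lives under its
# own flattened key built by generate_cache_key.
#
#     legacy key:   "llm_response_cache:local"
#     legacy value: {"a41b0d": {"return": "...", "cache_type": "query"}}
#     new key:      "llm_response_cache:" + generate_cache_key("local", "query", "a41b0d")
#                   -> "llm_response_cache:local:query:a41b0d"
#     new value:    {"return": "...", "cache_type": "query"}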
@final
@dataclass
class RedisDocStatusStorage(DocStatusStorage):
"""Redis implementation of document status storage"""
def __post_init__(self):
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Use shared connection pool
self._pool = RedisConnectionManager.get_pool(redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
)
async def initialize(self):
"""Initialize Redis connection"""
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(
f"Connected to Redis for doc status namespace {self.namespace}"
)
except Exception as e:
logger.error(f"Failed to connect to Redis for doc status: {e}")
raise
@asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
try:
yield self._redis
except ConnectionError as e:
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
raise
except RedisError as e:
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
raise
except Exception as e:
logger.error(
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
)
raise
async def close(self):
"""Close the Redis connection."""
if hasattr(self, "_redis") and self._redis:
await self._redis.close()
logger.debug(f"Closed Redis connection for doc status {self.namespace}")
async def __aenter__(self):
"""Support for async context manager."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Ensure Redis resources are cleaned up when exiting context."""
await self.close()
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
keys_list = list(keys)
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()
for result_data in results:
if result_data:
try:
result.append(json.loads(result_data))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in get_by_ids: {e}")
continue
except Exception as e:
logger.error(f"Error in get_by_ids: {e}")
return result
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Count statuses
for value in values:
if value:
try:
doc_data = json.loads(value)
status = doc_data.get("status")
if status in counts:
counts[status] += 1
except json.JSONDecodeError:
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting status counts: {e}")
return counts
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
result = {}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Filter by status and create DocProcessingStatus objects
for key, value in zip(keys, values):
if value:
try:
doc_data = json.loads(value)
if doc_data.get("status") == status.value:
# Extract document ID from key
doc_id = key.split(":", 1)[1]
# Make a copy of the data to avoid modifying the original
data = doc_data.copy()
# If content is missing, use content_summary as content
if (
"content" not in data
and "content_summary" in data
):
data["content"] = data["content_summary"]
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e:
logger.error(
f"Error processing document {key}: {e}"
)
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting docs by status: {e}")
return result
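# Usage sketch (hypothetical caller; `doc_status` is an initialized
# RedisDocStatusStorage and DocStatus is the enum imported above):
#
#     failed_docs = await doc_status.get_docs_by_status(DocStatus.FAILED)
#     for doc_id, status in failed_docs.items():
#         logger.info(f"{doc_id}: {status.content_summary}")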
async def index_done_callback(self) -> None:
"""Redis handles persistence automatically"""
pass
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Insert or update document status data"""
if not data:
return
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._get_redis_connection() as redis:
try:
# Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items():
if "chunks_list" not in doc_data:
doc_data["chunks_list"] = []
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
except (TypeError, ValueError) as e:  # json.dumps raises these on non-serializable data
logger.error(f"JSON encode error during upsert: {e}")
raise
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
return None
async def delete(self, doc_ids: list[str]) -> None:
"""Delete specific records from storage by their IDs"""
if not doc_ids:
return
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for doc_id in doc_ids:
pipe.delete(f"{self.namespace}:{doc_id}")
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
)
async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources"""
try:
async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
# Delete keys in batches
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0:
break
logger.info(
f"Dropped {deleted_count} doc status keys from {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping doc status {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View file

@ -22,6 +22,7 @@ from typing import (
Dict,
)
from lightrag.constants import (
DEFAULT_MAX_GLEANING,
DEFAULT_MAX_TOKEN_SUMMARY,
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
)
@ -124,7 +125,9 @@ class LightRAG:
# Entity extraction
# ---
entity_extract_max_gleaning: int = field(default=1)
entity_extract_max_gleaning: int = field(
default=get_env_value("MAX_GLEANING", DEFAULT_MAX_GLEANING, int)
)
"""Maximum number of entity extraction attempts for ambiguous content."""
summary_to_max_tokens: int = field(
@ -346,6 +349,7 @@ class LightRAG:
# Fix global_config now
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@ -394,13 +398,13 @@ class LightRAG:
embedding_func=self.embedding_func,
)
# TODO: deprecating, text_chunks is redundant with chunks_vdb
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
@ -949,6 +953,7 @@ class LightRAG:
**dp,
"full_doc_id": doc_id,
"file_path": file_path, # Add file path to each chunk
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
}
for dp in self.chunking_func(
self.tokenizer,
@ -960,14 +965,17 @@ class LightRAG:
)
}
# Process document (text chunks and full docs) in parallel
# Create tasks with references for potential cancellation
# Process document in two stages
# Stage 1: Process text chunks and docs (parallel execution)
doc_status_task = asyncio.create_task(
self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSING,
"chunks_count": len(chunks),
"chunks_list": list(
chunks.keys()
), # Save chunks list
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
@ -983,11 +991,6 @@ class LightRAG:
chunks_vdb_task = asyncio.create_task(
self.chunks_vdb.upsert(chunks)
)
entity_relation_task = asyncio.create_task(
self._process_entity_relation_graph(
chunks, pipeline_status, pipeline_status_lock
)
)
full_docs_task = asyncio.create_task(
self.full_docs.upsert(
{doc_id: {"content": status_doc.content}}
@ -996,14 +999,26 @@ class LightRAG:
text_chunks_task = asyncio.create_task(
self.text_chunks.upsert(chunks)
)
tasks = [
# First stage tasks (parallel execution)
first_stage_tasks = [
doc_status_task,
chunks_vdb_task,
entity_relation_task,
full_docs_task,
text_chunks_task,
]
await asyncio.gather(*tasks)
entity_relation_task = None
# Execute first stage tasks
await asyncio.gather(*first_stage_tasks)
# Stage 2: Process entity relation graph (after text_chunks are saved)
entity_relation_task = asyncio.create_task(
self._process_entity_relation_graph(
chunks, pipeline_status, pipeline_status_lock
)
)
await entity_relation_task
file_extraction_stage_ok = True
except Exception as e:
@ -1018,14 +1033,14 @@ class LightRAG:
)
pipeline_status["history_messages"].append(error_msg)
# Cancel other tasks as they are no longer meaningful
for task in [
chunks_vdb_task,
entity_relation_task,
full_docs_task,
text_chunks_task,
]:
if not task.done():
# Cancel tasks that are not yet completed
all_tasks = first_stage_tasks + (
[entity_relation_task]
if entity_relation_task
else []
)
for task in all_tasks:
if task and not task.done():
task.cancel()
# Persistent llm cache
@ -1075,6 +1090,9 @@ class LightRAG:
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"chunks_list": list(
chunks.keys()
), # Save chunks list
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
@ -1193,6 +1211,7 @@ class LightRAG:
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
text_chunks_storage=self.text_chunks,
)
return chunk_results
except Exception as e:
@ -1723,28 +1742,10 @@ class LightRAG:
file_path="",
)
# 2. Get all chunks related to this document
try:
all_chunks = await self.text_chunks.get_all()
related_chunks = {
chunk_id: chunk_data
for chunk_id, chunk_data in all_chunks.items()
if isinstance(chunk_data, dict)
and chunk_data.get("full_doc_id") == doc_id
}
# 2. Get chunk IDs from document status
chunk_ids = set(doc_status_data.get("chunks_list", []))
# Update pipeline status after getting chunks count
async with pipeline_status_lock:
log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e:
logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}")
raise Exception(f"Failed to retrieve document chunks: {e}") from e
if not related_chunks:
if not chunk_ids:
logger.warning(f"No chunks found for document {doc_id}")
# Mark that deletion operations have started
deletion_operations_started = True
@ -1775,7 +1776,6 @@ class LightRAG:
file_path=file_path,
)
chunk_ids = set(related_chunks.keys())
# Mark that deletion operations have started
deletion_operations_started = True
@ -1799,26 +1799,12 @@ class LightRAG:
)
)
# Update pipeline status after getting affected_nodes
async with pipeline_status_lock:
log_message = f"Found {len(affected_nodes)} affected entities"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
affected_edges = (
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
list(chunk_ids)
)
)
# Update pipeline status after getting affected_edges
async with pipeline_status_lock:
log_message = f"Found {len(affected_edges)} affected relations"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e:
logger.error(f"Failed to analyze affected graph elements: {e}")
raise Exception(f"Failed to analyze graph dependencies: {e}") from e
@ -1836,6 +1822,14 @@ class LightRAG:
elif remaining_sources != sources:
entities_to_rebuild[node_label] = remaining_sources
async with pipeline_status_lock:
log_message = (
f"Found {len(entities_to_rebuild)} affected entities"
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Process relationships
for edge_data in affected_edges:
src = edge_data.get("source")
@ -1857,6 +1851,14 @@ class LightRAG:
elif remaining_sources != sources:
relationships_to_rebuild[edge_tuple] = remaining_sources
async with pipeline_status_lock:
log_message = (
f"Found {len(relationships_to_rebuild)} affected relations"
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e:
logger.error(f"Failed to process graph analysis results: {e}")
raise Exception(f"Failed to process graph dependencies: {e}") from e
@ -1940,17 +1942,13 @@ class LightRAG:
knowledge_graph_inst=self.chunk_entity_relation_graph,
entities_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
text_chunks=self.text_chunks,
text_chunks_storage=self.text_chunks,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
)
async with pipeline_status_lock:
log_message = f"Successfully rebuilt {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relations"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e:
logger.error(f"Failed to rebuild knowledge from chunks: {e}")
raise Exception(

View file

@ -25,6 +25,7 @@ from .utils import (
CacheData,
get_conversation_turns,
use_llm_func_with_cache,
update_chunk_cache_list,
)
from .base import (
BaseGraphStorage,
@ -103,8 +104,6 @@ async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
global_config: dict,
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
) -> str:
"""Handle entity relation summary
@ -247,9 +246,11 @@ async def _rebuild_knowledge_from_chunks(
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks: BaseKVStorage,
text_chunks_storage: BaseKVStorage,
llm_response_cache: BaseKVStorage,
global_config: dict[str, str],
pipeline_status: dict | None = None,
pipeline_status_lock=None,
) -> None:
"""Rebuild entity and relationship descriptions from cached extraction results
@ -259,9 +260,12 @@ async def _rebuild_knowledge_from_chunks(
Args:
entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
text_chunks_storage: Text chunks storage used to load chunk data and cached extraction references
"""
if not entities_to_rebuild and not relationships_to_rebuild:
return
rebuilt_entities_count = 0
rebuilt_relationships_count = 0
# Get all referenced chunk IDs
all_referenced_chunk_ids = set()
@ -270,36 +274,74 @@ async def _rebuild_knowledge_from_chunks(
for chunk_ids in relationships_to_rebuild.values():
all_referenced_chunk_ids.update(chunk_ids)
logger.debug(
f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
)
status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
# Get cached extraction results for these chunks
# Get cached extraction results for these chunks using storage
# cached_results chunk_id -> [list of extraction result from LLM cache sorted by created_at]
cached_results = await _get_cached_extraction_results(
llm_response_cache, all_referenced_chunk_ids
llm_response_cache,
all_referenced_chunk_ids,
text_chunks_storage=text_chunks_storage,
)
if not cached_results:
logger.warning("No cached extraction results found, cannot rebuild")
status_message = "No cached extraction results found, cannot rebuild"
logger.warning(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
return
# Process cached results to get entities and relationships for each chunk
chunk_entities = {} # chunk_id -> {entity_name: [entity_data]}
chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]}
for chunk_id, extraction_result in cached_results.items():
for chunk_id, extraction_results in cached_results.items():
try:
entities, relationships = await _parse_extraction_result(
text_chunks=text_chunks,
extraction_result=extraction_result,
chunk_id=chunk_id,
)
chunk_entities[chunk_id] = entities
chunk_relationships[chunk_id] = relationships
# Handle multiple extraction results per chunk
chunk_entities[chunk_id] = defaultdict(list)
chunk_relationships[chunk_id] = defaultdict(list)
# process multiple LLM extraction results for a single chunk_id
for extraction_result in extraction_results:
entities, relationships = await _parse_extraction_result(
text_chunks_storage=text_chunks_storage,
extraction_result=extraction_result,
chunk_id=chunk_id,
)
# Merge entities and relationships from this extraction result
# Only keep the first occurrence of each entity_name in the same chunk_id
for entity_name, entity_list in entities.items():
if (
entity_name not in chunk_entities[chunk_id]
or len(chunk_entities[chunk_id][entity_name]) == 0
):
chunk_entities[chunk_id][entity_name].extend(entity_list)
# Only keep the first occurrence of each rel_key in the same chunk_id
for rel_key, rel_list in relationships.items():
if (
rel_key not in chunk_relationships[chunk_id]
or len(chunk_relationships[chunk_id][rel_key]) == 0
):
chunk_relationships[chunk_id][rel_key].extend(rel_list)
except Exception as e:
logger.error(
status_message = (
f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
)
logger.info(status_message)  # Non-fatal: log at info level and continue
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
continue
# Rebuild entities
@ -314,11 +356,22 @@ async def _rebuild_knowledge_from_chunks(
llm_response_cache=llm_response_cache,
global_config=global_config,
)
logger.debug(
f"Rebuilt entity {entity_name} from {len(chunk_ids)} cached extractions"
rebuilt_entities_count += 1
status_message = (
f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks"
)
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
except Exception as e:
logger.error(f"Failed to rebuild entity {entity_name}: {e}")
status_message = f"Failed to rebuild entity {entity_name}: {e}"
logger.info(status_message)  # Non-fatal: log at info level and continue
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
# Rebuild relationships
for (src, tgt), chunk_ids in relationships_to_rebuild.items():
@ -333,53 +386,112 @@ async def _rebuild_knowledge_from_chunks(
llm_response_cache=llm_response_cache,
global_config=global_config,
)
logger.debug(
f"Rebuilt relationship {src}-{tgt} from {len(chunk_ids)} cached extractions"
rebuilt_relationships_count += 1
status_message = (
f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks"
)
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
except Exception as e:
logger.error(f"Failed to rebuild relationship {src}-{tgt}: {e}")
status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
logger.debug("Completed rebuilding knowledge from cached extractions")
status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships."
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
async def _get_cached_extraction_results(
llm_response_cache: BaseKVStorage, chunk_ids: set[str]
) -> dict[str, str]:
llm_response_cache: BaseKVStorage,
chunk_ids: set[str],
text_chunks_storage: BaseKVStorage,
) -> dict[str, list[str]]:
"""Get cached extraction results for specific chunk IDs
Args:
llm_response_cache: LLM response cache storage
chunk_ids: Set of chunk IDs to get cached results for
text_chunks_storage: Text chunks storage used to read each chunk's llm_cache_list
Returns:
Dict mapping chunk_id -> extraction_result_text
Dict mapping chunk_id -> list of extraction_result_text
"""
cached_results = {}
# Get all cached data for "default" mode (entity extraction cache)
default_cache = await llm_response_cache.get_by_id("default") or {}
# Collect all LLM cache IDs from chunks
all_cache_ids = set()
for cache_key, cache_entry in default_cache.items():
# Read from storage
chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
if chunk_data and isinstance(chunk_data, dict):
llm_cache_list = chunk_data.get("llm_cache_list", [])
if llm_cache_list:
all_cache_ids.update(llm_cache_list)
else:
logger.warning(
f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
)
if not all_cache_ids:
logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
return cached_results
# Batch get LLM cache entries
cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids))
# Process cache entries and group by chunk_id
valid_entries = 0
for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
if (
isinstance(cache_entry, dict)
cache_entry is not None
and isinstance(cache_entry, dict)
and cache_entry.get("cache_type") == "extract"
and cache_entry.get("chunk_id") in chunk_ids
):
chunk_id = cache_entry["chunk_id"]
extraction_result = cache_entry["return"]
cached_results[chunk_id] = extraction_result
create_time = cache_entry.get(
"create_time", 0
) # Get creation time, default to 0
valid_entries += 1
logger.debug(
f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs"
# Support multiple LLM caches per chunk
if chunk_id not in cached_results:
cached_results[chunk_id] = []
# Store tuple with extraction result and creation time for sorting
cached_results[chunk_id].append((extraction_result, create_time))
# Sort extraction results by create_time for each chunk
for chunk_id in cached_results:
# Sort by create_time (x[1]), then extract only extraction_result (x[0])
cached_results[chunk_id].sort(key=lambda x: x[1])
cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
logger.info(
f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
)
return cached_results
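# Illustrative return shape (hypothetical values, not part of the pipeline): each
# chunk maps to its extraction outputs, ordered by cache create_time so gleaning
# passes follow the initial extraction.
example_cached_results: dict[str, list[str]] = {
    "chunk-abc123": ["<initial extraction>", "<gleaning pass 1>"],
    "chunk-def456": ["<initial extraction>"],
}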
async def _parse_extraction_result(
text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str
text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str
) -> tuple[dict, dict]:
"""Parse cached extraction result using the same logic as extract_entities
Args:
text_chunks_storage: Text chunks storage to get chunk data
extraction_result: The cached LLM extraction result
chunk_id: The chunk ID for source tracking
@ -387,8 +499,8 @@ async def _parse_extraction_result(
Tuple of (entities_dict, relationships_dict)
"""
# Get chunk data for file_path
chunk_data = await text_chunks.get_by_id(chunk_id)
# Get chunk data for file_path from storage
chunk_data = await text_chunks_storage.get_by_id(chunk_id)
file_path = (
chunk_data.get("file_path", "unknown_source")
if chunk_data
@ -761,8 +873,6 @@ async def _merge_nodes_then_upsert(
entity_name,
description,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
else:
@ -925,8 +1035,6 @@ async def _merge_edges_then_upsert(
f"({src_id}, {tgt_id})",
description,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
else:
@ -1102,6 +1210,7 @@ async def extract_entities(
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
text_chunks_storage: BaseKVStorage | None = None,
) -> list:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@ -1208,6 +1317,9 @@ async def extract_entities(
# Get file path from chunk data or use default
file_path = chunk_dp.get("file_path", "unknown_source")
# Create cache keys collector for batch processing
cache_keys_collector = []
# Get initial extraction
hint_prompt = entity_extract_prompt.format(
**{**context_base, "input_text": content}
@ -1219,7 +1331,10 @@ async def extract_entities(
llm_response_cache=llm_response_cache,
cache_type="extract",
chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
)
# Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
# Process initial extraction with file path
@ -1236,6 +1351,7 @@ async def extract_entities(
history_messages=history,
cache_type="extract",
chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@ -1266,11 +1382,21 @@ async def extract_entities(
llm_response_cache=llm_response_cache,
history_messages=history,
cache_type="extract",
cache_keys_collector=cache_keys_collector,
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
# Batch update chunk's llm_cache_list with all collected cache keys
if cache_keys_collector and text_chunks_storage:
await update_chunk_cache_list(
chunk_key,
text_chunks_storage,
cache_keys_collector,
"entity_extraction",
)
processed_chunks += 1
entities_count = len(maybe_nodes)
relations_count = len(maybe_edges)
@ -1343,7 +1469,7 @@ async def kg_query(
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
@ -1390,7 +1516,7 @@ async def kg_query(
)
if query_param.only_need_context:
return context
return context if context is not None else PROMPTS["fail_response"]
if context is None:
return PROMPTS["fail_response"]
@ -1502,7 +1628,7 @@ async def extract_keywords_only(
"""
# 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
args_hash = compute_args_hash(param.mode, text)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
)
@ -1647,7 +1773,7 @@ async def _get_vector_context(
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
)
logger.info(
f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
)
if not maybe_trun_chunks:
@ -1871,7 +1997,7 @@ async def _get_node_data(
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
)
# build prompt
@ -2180,7 +2306,7 @@ async def _get_edge_data(
),
)
logger.info(
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
)
relations_context = []
@ -2369,7 +2495,7 @@ async def naive_query(
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
@ -2485,7 +2611,7 @@ async def kg_query_with_keywords(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)

View file

@ -14,7 +14,6 @@ from functools import wraps
from hashlib import md5
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import numpy as np
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES,
@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
def compute_args_hash(*args: Any) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""
@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
# Convert all arguments to strings and join them
args_str = "".join([str(arg) for arg in args])
if cache_type:
args_str = f"{cache_type}:{args_str}"
# Compute MD5 hash
return hashlib.md5(args_str.encode()).hexdigest()
def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
"""Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
Args:
mode: Cache mode (e.g., 'default', 'local', 'global')
cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
hash_value: Hash value from compute_args_hash
Returns:
str: Flattened cache key
"""
return f"{mode}:{cache_type}:{hash_value}"
def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
"""Parse a flattened cache key back into its components
Args:
cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
Returns:
tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
"""
parts = cache_key.split(":", 2)
if len(parts) == 3:
return parts[0], parts[1], parts[2]
return None
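# A small self-check sketch for the two helpers above (hypothetical function; the
# hash value is shortened for readability): round-tripping a key recovers its
# components, and malformed keys parse to None.
def _cache_key_roundtrip_example() -> None:
    key = generate_cache_key("default", "extract", "9f2c1a")
    assert key == "default:extract:9f2c1a"
    assert parse_cache_key(key) == ("default", "extract", "9f2c1a")
    assert parse_cache_key("not-a-cache-key") is None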
def compute_mdhash_id(content: str, prefix: str = "") -> str:
"""
Compute a unique ID for a given content string.
@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
return combined_data
async def get_best_cached_response(
hashing_kv,
current_embedding,
similarity_threshold=0.95,
mode="default",
use_llm_check=False,
llm_func=None,
original_prompt=None,
cache_type=None,
) -> str | None:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None
best_similarity = -1
best_response = None
best_prompt = None
best_cache_id = None
# Only iterate through cache entries for this mode
for cache_id, cache_data in mode_cache.items():
# Skip if cache_type doesn't match
if cache_type and cache_data.get("cache_type") != cache_type:
continue
# Check if cache data is valid
if cache_data["embedding"] is None:
continue
try:
# Safely convert cached embedding
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
# Ensure min_val and max_val are valid float values
embedding_min = cache_data.get("embedding_min")
embedding_max = cache_data.get("embedding_max")
if (
embedding_min is None
or embedding_max is None
or embedding_min >= embedding_max
):
logger.warning(
f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
)
continue
cached_embedding = dequantize_embedding(
cached_quantized,
embedding_min,
embedding_max,
)
except Exception as e:
logger.warning(f"Error processing cached embedding: {str(e)}")
continue
similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_response = cache_data["return"]
best_prompt = cache_data["original_prompt"]
best_cache_id = cache_id
if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided
if (
use_llm_check
and llm_func
and original_prompt
and best_prompt
and best_response is not None
):
compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt
)
try:
llm_result = await llm_func(compare_prompt)
llm_result = llm_result.strip()
llm_similarity = float(llm_result)
# Replace vector similarity with LLM similarity score
best_similarity = llm_similarity
if best_similarity < similarity_threshold:
log_data = {
"event": "cache_rejected_by_llm",
"type": cache_type,
"mode": mode,
"original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100
else original_prompt,
"cached_question": best_prompt[:100] + "..."
if len(best_prompt) > 100
else best_prompt,
"similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
return None
except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}")
return None # Return None directly when LLM check fails
prompt_display = (
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
)
log_data = {
"event": "cache_hit",
"type": cache_type,
"mode": mode,
"similarity": round(best_similarity, 4),
"cache_id": best_cache_id,
"original_prompt": prompt_display,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
return best_response
return None
def cosine_similarity(v1, v2):
"""Calculate cosine similarity between two vectors"""
dot_product = np.dot(v1, v2)
@ -957,7 +857,7 @@ async def handle_cache(
mode="default",
cache_type=None,
):
"""Generic cache handling function"""
"""Generic cache handling function with flattened cache keys"""
if hashing_kv is None:
return None, None, None, None
@ -968,15 +868,14 @@ async def handle_cache(
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
return mode_cache[args_hash]["return"], None, None, None
# Use flattened cache key format: {mode}:{cache_type}:{hash}
flattened_key = generate_cache_key(mode, cache_type, args_hash)
cache_entry = await hashing_kv.get_by_id(flattened_key)
if cache_entry:
logger.debug(f"Flattened cache hit(key:{flattened_key})")
return cache_entry["return"], None, None, None
logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
return None, None, None, None
@ -994,7 +893,7 @@ class CacheData:
async def save_to_cache(hashing_kv, cache_data: CacheData):
"""Save data to cache, with improved handling for streaming responses and duplicate content.
"""Save data to cache using flattened key structure.
Args:
hashing_kv: The key-value storage for caching
@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
logger.debug("Streaming response detected, skipping cache")
return
# Get existing cache data
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = (
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
or {}
)
else:
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
# Use flattened cache key format: {mode}:{cache_type}:{hash}
flattened_key = generate_cache_key(
cache_data.mode, cache_data.cache_type, cache_data.args_hash
)
# Check if we already have identical content cached
if cache_data.args_hash in mode_cache:
existing_content = mode_cache[cache_data.args_hash].get("return")
existing_cache = await hashing_kv.get_by_id(flattened_key)
if existing_cache:
existing_content = existing_cache.get("return")
if existing_content == cache_data.content:
logger.info(
f"Cache content unchanged for {cache_data.args_hash}, skipping update"
)
logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
return
# Update cache with new content
mode_cache[cache_data.args_hash] = {
# Create cache entry with flattened structure
cache_entry = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
"original_prompt": cache_data.prompt,
}
logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")
logger.info(f" == LLM cache == saving: {flattened_key}")
# Only upsert if there's actual new content
await hashing_kv.upsert({cache_data.mode: mode_cache})
# Save using flattened key
await hashing_kv.upsert({flattened_key: cache_entry})
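# Round-trip sketch under illustrative names: save_to_cache stores the entry under
# the flattened key "<mode>:<cache_type>:<hash>", and handle_cache later rebuilds the
# same key from (mode, cache_type, args_hash) and returns the cached "return" field
# on a hit. `hashing_kv` is assumed to be an initialized BaseKVStorage instance.
async def _query_cache_roundtrip_example(hashing_kv, query: str) -> str | None:
    args_hash = compute_args_hash("local", query)
    cached, _, _, _ = await handle_cache(
        hashing_kv, args_hash, query, "local", cache_type="query"
    )
    return cached  # None on a miss; the cached LLM response on a hit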
def safe_unicode_decode(content):
@ -1529,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
return import_class
async def update_chunk_cache_list(
chunk_id: str,
text_chunks_storage: "BaseKVStorage",
cache_keys: list[str],
cache_scenario: str = "batch_update",
) -> None:
"""Update chunk's llm_cache_list with the given cache keys
Args:
chunk_id: Chunk identifier
text_chunks_storage: Text chunks storage instance
cache_keys: List of cache keys to add to the list
cache_scenario: Description of the cache scenario for logging
"""
if not cache_keys:
return
try:
chunk_data = await text_chunks_storage.get_by_id(chunk_id)
if chunk_data:
# Ensure llm_cache_list exists
if "llm_cache_list" not in chunk_data:
chunk_data["llm_cache_list"] = []
# Add cache keys to the list if not already present
existing_keys = set(chunk_data["llm_cache_list"])
new_keys = [key for key in cache_keys if key not in existing_keys]
if new_keys:
chunk_data["llm_cache_list"].extend(new_keys)
# Update the chunk in storage
await text_chunks_storage.upsert({chunk_id: chunk_data})
logger.debug(
f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})"
)
except Exception as e:
logger.warning(
f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}"
)
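# Usage sketch mirroring the extraction flow (hypothetical caller; `llm_cache` and
# `text_chunks` are initialized KV storages and `llm_func` is the model callable):
# cache keys produced by use_llm_func_with_cache (defined below) are collected per
# chunk and written back to the chunk's llm_cache_list in a single batch.
async def _extract_chunk_with_cache_tracking(
    chunk_key: str, content: str, llm_func, llm_cache, text_chunks
) -> str:
    cache_keys_collector: list[str] = []
    result = await use_llm_func_with_cache(
        content,
        llm_func,
        llm_response_cache=llm_cache,
        cache_type="extract",
        chunk_id=chunk_key,
        cache_keys_collector=cache_keys_collector,
    )
    if cache_keys_collector:
        await update_chunk_cache_list(
            chunk_key, text_chunks, cache_keys_collector, "entity_extraction"
        )
    return result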
async def use_llm_func_with_cache(
input_text: str,
use_llm_func: callable,
@ -1537,6 +1473,7 @@ async def use_llm_func_with_cache(
history_messages: list[dict[str, str]] = None,
cache_type: str = "extract",
chunk_id: str | None = None,
cache_keys_collector: list = None,
) -> str:
"""Call LLM function with cache support
@ -1551,6 +1488,8 @@ async def use_llm_func_with_cache(
history_messages: History messages list
cache_type: Type of cache
chunk_id: Chunk identifier to store in cache
cache_keys_collector: Optional list to collect cache keys for batch processing
Returns:
LLM response text
@ -1563,6 +1502,9 @@ async def use_llm_func_with_cache(
_prompt = input_text
arg_hash = compute_args_hash(_prompt)
# Generate cache key for this LLM call
cache_key = generate_cache_key("default", cache_type, arg_hash)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache,
arg_hash,
@ -1573,6 +1515,11 @@ async def use_llm_func_with_cache(
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
# Add cache key to collector if provided
if cache_keys_collector is not None:
cache_keys_collector.append(cache_key)
return cached_return
statistic_data["llm_call"] += 1
@ -1597,6 +1544,10 @@ async def use_llm_func_with_cache(
),
)
# Add cache key to collector if provided
if cache_keys_collector is not None:
cache_keys_collector.append(cache_key)
return res
# When cache is disabled, directly call LLM

View file

@ -6,7 +6,7 @@ from typing import Any, cast
from .base import DeletionResult
from .kg.shared_storage import get_graph_db_lock
from .prompt import GRAPH_FIELD_SEP
from .constants import GRAPH_FIELD_SEP
from .utils import compute_mdhash_id, logger
from .base import StorageNameSpace

View file

@ -57,6 +57,10 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation here]"
}},
"Diversity": {{
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation here]"
}},
"Empowerment": {{
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation here]"

View file

@ -8,6 +8,7 @@
Supported graph storage types include:
- NetworkXStorage
- Neo4JStorage
- MongoDBStorage
- PGGraphStorage
- MemgraphStorage
"""