Merge branch 'main' into add-Memgraph-graph-db
This commit is contained in:
commit
a69194c079
28 changed files with 3042 additions and 793 deletions
18
env.example
18
env.example
|
|
@ -58,6 +58,8 @@ SUMMARY_LANGUAGE=English
|
||||||
# FORCE_LLM_SUMMARY_ON_MERGE=6
|
# FORCE_LLM_SUMMARY_ON_MERGE=6
|
||||||
### Max tokens for entity/relations description after merge
|
### Max tokens for entity/relations description after merge
|
||||||
# MAX_TOKEN_SUMMARY=500
|
# MAX_TOKEN_SUMMARY=500
|
||||||
|
### Maximum number of entity extraction attempts for ambiguous content
|
||||||
|
# MAX_GLEANING=1
|
||||||
|
|
||||||
### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended)
|
### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended)
|
||||||
# MAX_PARALLEL_INSERT=2
|
# MAX_PARALLEL_INSERT=2
|
||||||
|
|
@ -112,15 +114,6 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||||
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
|
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
|
||||||
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
|
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
|
||||||
|
|
||||||
### TiDB Configuration (Deprecated)
|
|
||||||
# TIDB_HOST=localhost
|
|
||||||
# TIDB_PORT=4000
|
|
||||||
# TIDB_USER=your_username
|
|
||||||
# TIDB_PASSWORD='your_password'
|
|
||||||
# TIDB_DATABASE=your_database
|
|
||||||
### separating all data from difference Lightrag instances(deprecating)
|
|
||||||
# TIDB_WORKSPACE=default
|
|
||||||
|
|
||||||
### PostgreSQL Configuration
|
### PostgreSQL Configuration
|
||||||
POSTGRES_HOST=localhost
|
POSTGRES_HOST=localhost
|
||||||
POSTGRES_PORT=5432
|
POSTGRES_PORT=5432
|
||||||
|
|
@ -128,7 +121,7 @@ POSTGRES_USER=your_username
|
||||||
POSTGRES_PASSWORD='your_password'
|
POSTGRES_PASSWORD='your_password'
|
||||||
POSTGRES_DATABASE=your_database
|
POSTGRES_DATABASE=your_database
|
||||||
POSTGRES_MAX_CONNECTIONS=12
|
POSTGRES_MAX_CONNECTIONS=12
|
||||||
### separating all data from difference Lightrag instances(deprecating)
|
### separating all data from difference Lightrag instances
|
||||||
# POSTGRES_WORKSPACE=default
|
# POSTGRES_WORKSPACE=default
|
||||||
|
|
||||||
### Neo4j Configuration
|
### Neo4j Configuration
|
||||||
|
|
@ -144,14 +137,15 @@ NEO4J_PASSWORD='your_password'
|
||||||
# AGE_POSTGRES_PORT=8529
|
# AGE_POSTGRES_PORT=8529
|
||||||
|
|
||||||
# AGE Graph Name(apply to PostgreSQL and independent AGM)
|
# AGE Graph Name(apply to PostgreSQL and independent AGM)
|
||||||
### AGE_GRAPH_NAME is precated
|
### AGE_GRAPH_NAME is deprecated
|
||||||
# AGE_GRAPH_NAME=lightrag
|
# AGE_GRAPH_NAME=lightrag
|
||||||
|
|
||||||
### MongoDB Configuration
|
### MongoDB Configuration
|
||||||
MONGO_URI=mongodb://root:root@localhost:27017/
|
MONGO_URI=mongodb://root:root@localhost:27017/
|
||||||
MONGO_DATABASE=LightRAG
|
MONGO_DATABASE=LightRAG
|
||||||
### separating all data from difference Lightrag instances(deprecating)
|
### separating all data from difference Lightrag instances(deprecating)
|
||||||
# MONGODB_GRAPH=false
|
### separating all data from difference Lightrag instances
|
||||||
|
# MONGODB_WORKSPACE=default
|
||||||
|
|
||||||
### Milvus Configuration
|
### Milvus Configuration
|
||||||
MILVUS_URI=http://localhost:19530
|
MILVUS_URI=http://localhost:19530
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,74 @@ This example shows how to:
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root directory to Python path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
||||||
from lightrag.utils import EmbeddingFunc
|
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
|
||||||
from raganything.raganything import RAGAnything
|
from raganything import RAGAnything, RAGAnythingConfig
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging():
|
||||||
|
"""Configure logging for the application"""
|
||||||
|
# Get log directory path from environment variable or use current directory
|
||||||
|
log_dir = os.getenv("LOG_DIR", os.getcwd())
|
||||||
|
log_file_path = os.path.abspath(os.path.join(log_dir, "raganything_example.log"))
|
||||||
|
|
||||||
|
print(f"\nRAGAnything example log file: {log_file_path}\n")
|
||||||
|
os.makedirs(os.path.dirname(log_dir), exist_ok=True)
|
||||||
|
|
||||||
|
# Get log file max size and backup count from environment variables
|
||||||
|
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
|
||||||
|
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
|
||||||
|
|
||||||
|
logging.config.dictConfig(
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"disable_existing_loggers": False,
|
||||||
|
"formatters": {
|
||||||
|
"default": {
|
||||||
|
"format": "%(levelname)s: %(message)s",
|
||||||
|
},
|
||||||
|
"detailed": {
|
||||||
|
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"handlers": {
|
||||||
|
"console": {
|
||||||
|
"formatter": "default",
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"stream": "ext://sys.stderr",
|
||||||
|
},
|
||||||
|
"file": {
|
||||||
|
"formatter": "detailed",
|
||||||
|
"class": "logging.handlers.RotatingFileHandler",
|
||||||
|
"filename": log_file_path,
|
||||||
|
"maxBytes": log_max_bytes,
|
||||||
|
"backupCount": log_backup_count,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"loggers": {
|
||||||
|
"lightrag": {
|
||||||
|
"handlers": ["console", "file"],
|
||||||
|
"level": "INFO",
|
||||||
|
"propagate": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the logger level to INFO
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
# Enable verbose debug if needed
|
||||||
|
set_verbose_debug(os.getenv("VERBOSE", "false").lower() == "true")
|
||||||
|
|
||||||
|
|
||||||
async def process_with_rag(
|
async def process_with_rag(
|
||||||
|
|
@ -31,15 +96,21 @@ async def process_with_rag(
|
||||||
output_dir: Output directory for RAG results
|
output_dir: Output directory for RAG results
|
||||||
api_key: OpenAI API key
|
api_key: OpenAI API key
|
||||||
base_url: Optional base URL for API
|
base_url: Optional base URL for API
|
||||||
|
working_dir: Working directory for RAG storage
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Initialize RAGAnything
|
# Create RAGAnything configuration
|
||||||
rag = RAGAnything(
|
config = RAGAnythingConfig(
|
||||||
working_dir=working_dir,
|
working_dir=working_dir or "./rag_storage",
|
||||||
llm_model_func=lambda prompt,
|
mineru_parse_method="auto",
|
||||||
system_prompt=None,
|
enable_image_processing=True,
|
||||||
history_messages=[],
|
enable_table_processing=True,
|
||||||
**kwargs: openai_complete_if_cache(
|
enable_equation_processing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define LLM model function
|
||||||
|
def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
|
||||||
|
return openai_complete_if_cache(
|
||||||
"gpt-4o-mini",
|
"gpt-4o-mini",
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
|
@ -47,81 +118,123 @@ async def process_with_rag(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
),
|
|
||||||
vision_model_func=lambda prompt,
|
|
||||||
system_prompt=None,
|
|
||||||
history_messages=[],
|
|
||||||
image_data=None,
|
|
||||||
**kwargs: openai_complete_if_cache(
|
|
||||||
"gpt-4o",
|
|
||||||
"",
|
|
||||||
system_prompt=None,
|
|
||||||
history_messages=[],
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": system_prompt}
|
|
||||||
if system_prompt
|
|
||||||
else None,
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": prompt},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_data}"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
if image_data
|
|
||||||
else {"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
if image_data
|
|
||||||
else openai_complete_if_cache(
|
# Define vision model function for image processing
|
||||||
"gpt-4o-mini",
|
def vision_model_func(
|
||||||
prompt,
|
prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs
|
||||||
system_prompt=system_prompt,
|
):
|
||||||
history_messages=history_messages,
|
if image_data:
|
||||||
api_key=api_key,
|
return openai_complete_if_cache(
|
||||||
base_url=base_url,
|
"gpt-4o",
|
||||||
**kwargs,
|
"",
|
||||||
),
|
system_prompt=None,
|
||||||
embedding_func=EmbeddingFunc(
|
history_messages=[],
|
||||||
embedding_dim=3072,
|
messages=[
|
||||||
max_token_size=8192,
|
{"role": "system", "content": system_prompt}
|
||||||
func=lambda texts: openai_embed(
|
if system_prompt
|
||||||
texts,
|
else None,
|
||||||
model="text-embedding-3-large",
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_data}"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
if image_data
|
||||||
|
else {"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
),
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return llm_model_func(prompt, system_prompt, history_messages, **kwargs)
|
||||||
|
|
||||||
|
# Define embedding function
|
||||||
|
embedding_func = EmbeddingFunc(
|
||||||
|
embedding_dim=3072,
|
||||||
|
max_token_size=8192,
|
||||||
|
func=lambda texts: openai_embed(
|
||||||
|
texts,
|
||||||
|
model="text-embedding-3-large",
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize RAGAnything with new dataclass structure
|
||||||
|
rag = RAGAnything(
|
||||||
|
config=config,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
vision_model_func=vision_model_func,
|
||||||
|
embedding_func=embedding_func,
|
||||||
|
)
|
||||||
|
|
||||||
# Process document
|
# Process document
|
||||||
await rag.process_document_complete(
|
await rag.process_document_complete(
|
||||||
file_path=file_path, output_dir=output_dir, parse_method="auto"
|
file_path=file_path, output_dir=output_dir, parse_method="auto"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example queries
|
# Example queries - demonstrating different query approaches
|
||||||
queries = [
|
logger.info("\nQuerying processed document:")
|
||||||
|
|
||||||
|
# 1. Pure text queries using aquery()
|
||||||
|
text_queries = [
|
||||||
"What is the main content of the document?",
|
"What is the main content of the document?",
|
||||||
"Describe the images and figures in the document",
|
"What are the key topics discussed?",
|
||||||
"Tell me about the experimental results and data tables",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
print("\nQuerying processed document:")
|
for query in text_queries:
|
||||||
for query in queries:
|
logger.info(f"\n[Text Query]: {query}")
|
||||||
print(f"\nQuery: {query}")
|
result = await rag.aquery(query, mode="hybrid")
|
||||||
result = await rag.query_with_multimodal(query, mode="hybrid")
|
logger.info(f"Answer: {result}")
|
||||||
print(f"Answer: {result}")
|
|
||||||
|
# 2. Multimodal query with specific multimodal content using aquery_with_multimodal()
|
||||||
|
logger.info(
|
||||||
|
"\n[Multimodal Query]: Analyzing performance data in context of document"
|
||||||
|
)
|
||||||
|
multimodal_result = await rag.aquery_with_multimodal(
|
||||||
|
"Compare this performance data with any similar results mentioned in the document",
|
||||||
|
multimodal_content=[
|
||||||
|
{
|
||||||
|
"type": "table",
|
||||||
|
"table_data": """Method,Accuracy,Processing_Time
|
||||||
|
RAGAnything,95.2%,120ms
|
||||||
|
Traditional_RAG,87.3%,180ms
|
||||||
|
Baseline,82.1%,200ms""",
|
||||||
|
"table_caption": "Performance comparison results",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
mode="hybrid",
|
||||||
|
)
|
||||||
|
logger.info(f"Answer: {multimodal_result}")
|
||||||
|
|
||||||
|
# 3. Another multimodal query with equation content
|
||||||
|
logger.info("\n[Multimodal Query]: Mathematical formula analysis")
|
||||||
|
equation_result = await rag.aquery_with_multimodal(
|
||||||
|
"Explain this formula and relate it to any mathematical concepts in the document",
|
||||||
|
multimodal_content=[
|
||||||
|
{
|
||||||
|
"type": "equation",
|
||||||
|
"latex": "F1 = 2 \\cdot \\frac{precision \\cdot recall}{precision + recall}",
|
||||||
|
"equation_caption": "F1-score calculation formula",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
mode="hybrid",
|
||||||
|
)
|
||||||
|
logger.info(f"Answer: {equation_result}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing with RAG: {str(e)}")
|
logger.error(f"Error processing with RAG: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -135,12 +248,20 @@ def main():
|
||||||
"--output", "-o", default="./output", help="Output directory path"
|
"--output", "-o", default="./output", help="Output directory path"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--api-key", required=True, help="OpenAI API key for RAG processing"
|
"--api-key",
|
||||||
|
default=os.getenv("OPENAI_API_KEY"),
|
||||||
|
help="OpenAI API key (defaults to OPENAI_API_KEY env var)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--base-url", help="Optional base URL for API")
|
parser.add_argument("--base-url", help="Optional base URL for API")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Check if API key is provided
|
||||||
|
if not args.api_key:
|
||||||
|
logger.error("Error: OpenAI API key is required")
|
||||||
|
logger.error("Set OPENAI_API_KEY environment variable or use --api-key option")
|
||||||
|
return
|
||||||
|
|
||||||
# Create output directory if specified
|
# Create output directory if specified
|
||||||
if args.output:
|
if args.output:
|
||||||
os.makedirs(args.output, exist_ok=True)
|
os.makedirs(args.output, exist_ok=True)
|
||||||
|
|
@ -154,4 +275,12 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Configure logging first
|
||||||
|
configure_logging()
|
||||||
|
|
||||||
|
print("RAGAnything Example")
|
||||||
|
print("=" * 30)
|
||||||
|
print("Processing document with multimodal RAG pipeline")
|
||||||
|
print("=" * 30)
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
|
||||||
embedding_func=None,
|
embedding_func=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get all cache data using the new flattened structure
|
||||||
|
all_data = await from_llm_response_cache.get_all()
|
||||||
|
|
||||||
|
# Convert flattened data to hierarchical structure for JsonKVStorage
|
||||||
kv = {}
|
kv = {}
|
||||||
for c_id in await from_llm_response_cache.all_keys():
|
for flattened_key, cache_entry in all_data.items():
|
||||||
print(f"Copying {c_id}")
|
# Parse flattened key: {mode}:{cache_type}:{hash}
|
||||||
workspace = c_id["workspace"]
|
parts = flattened_key.split(":", 2)
|
||||||
mode = c_id["mode"]
|
if len(parts) == 3:
|
||||||
_id = c_id["id"]
|
mode, cache_type, hash_value = parts
|
||||||
postgres_db.workspace = workspace
|
if mode not in kv:
|
||||||
obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
|
kv[mode] = {}
|
||||||
if mode not in kv:
|
kv[mode][hash_value] = cache_entry
|
||||||
kv[mode] = {}
|
print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
|
||||||
kv[mode][_id] = obj[_id]
|
else:
|
||||||
print(f"Object {obj}")
|
print(f"Skipping invalid key format: {flattened_key}")
|
||||||
|
|
||||||
await to_llm_response_cache.upsert(kv)
|
await to_llm_response_cache.upsert(kv)
|
||||||
await to_llm_response_cache.index_done_callback()
|
await to_llm_response_cache.index_done_callback()
|
||||||
print("Mission accomplished!")
|
print("Mission accomplished!")
|
||||||
|
|
@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
|
||||||
db=postgres_db,
|
db=postgres_db,
|
||||||
)
|
)
|
||||||
|
|
||||||
for mode in await from_llm_response_cache.all_keys():
|
# Get all cache data from JsonKVStorage (hierarchical structure)
|
||||||
print(f"Copying {mode}")
|
all_data = await from_llm_response_cache.get_all()
|
||||||
caches = await from_llm_response_cache.get_by_id(mode)
|
|
||||||
for k, v in caches.items():
|
# Convert hierarchical data to flattened structure for PGKVStorage
|
||||||
item = {mode: {k: v}}
|
flattened_data = {}
|
||||||
print(f"\tCopying {item}")
|
for mode, mode_data in all_data.items():
|
||||||
await to_llm_response_cache.upsert(item)
|
print(f"Processing mode: {mode}")
|
||||||
|
for hash_value, cache_entry in mode_data.items():
|
||||||
|
# Determine cache_type from cache entry or use default
|
||||||
|
cache_type = cache_entry.get("cache_type", "extract")
|
||||||
|
# Create flattened key: {mode}:{cache_type}:{hash}
|
||||||
|
flattened_key = f"{mode}:{cache_type}:{hash_value}"
|
||||||
|
flattened_data[flattened_key] = cache_entry
|
||||||
|
print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
|
||||||
|
|
||||||
|
# Upsert the flattened data
|
||||||
|
await to_llm_response_cache.upsert(flattened_data)
|
||||||
|
print("Mission accomplished!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
__api_version__ = "0176"
|
__api_version__ = "0178"
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,51 @@ router = APIRouter(
|
||||||
temp_prefix = "__tmp__"
|
temp_prefix = "__tmp__"
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_filename(filename: str, input_dir: Path) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize uploaded filename to prevent Path Traversal attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: The original filename from the upload
|
||||||
|
input_dir: The target input directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Sanitized filename that is safe to use
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the filename is unsafe or invalid
|
||||||
|
"""
|
||||||
|
# Basic validation
|
||||||
|
if not filename or not filename.strip():
|
||||||
|
raise HTTPException(status_code=400, detail="Filename cannot be empty")
|
||||||
|
|
||||||
|
# Remove path separators and traversal sequences
|
||||||
|
clean_name = filename.replace("/", "").replace("\\", "")
|
||||||
|
clean_name = clean_name.replace("..", "")
|
||||||
|
|
||||||
|
# Remove control characters and null bytes
|
||||||
|
clean_name = "".join(c for c in clean_name if ord(c) >= 32 and c != "\x7f")
|
||||||
|
|
||||||
|
# Remove leading/trailing whitespace and dots
|
||||||
|
clean_name = clean_name.strip().strip(".")
|
||||||
|
|
||||||
|
# Check if anything is left after sanitization
|
||||||
|
if not clean_name:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Invalid filename after sanitization"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the final path stays within the input directory
|
||||||
|
try:
|
||||||
|
final_path = (input_dir / clean_name).resolve()
|
||||||
|
if not final_path.is_relative_to(input_dir.resolve()):
|
||||||
|
raise HTTPException(status_code=400, detail="Unsafe filename detected")
|
||||||
|
except (OSError, ValueError):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||||
|
|
||||||
|
return clean_name
|
||||||
|
|
||||||
|
|
||||||
class ScanResponse(BaseModel):
|
class ScanResponse(BaseModel):
|
||||||
"""Response model for document scanning operation
|
"""Response model for document scanning operation
|
||||||
|
|
||||||
|
|
@ -783,7 +828,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
||||||
try:
|
try:
|
||||||
new_files = doc_manager.scan_directory_for_new_files()
|
new_files = doc_manager.scan_directory_for_new_files()
|
||||||
total_files = len(new_files)
|
total_files = len(new_files)
|
||||||
logger.info(f"Found {total_files} new files to index.")
|
logger.info(f"Found {total_files} files to index.")
|
||||||
|
|
||||||
if not new_files:
|
if not new_files:
|
||||||
return
|
return
|
||||||
|
|
@ -816,8 +861,13 @@ async def background_delete_documents(
|
||||||
successful_deletions = []
|
successful_deletions = []
|
||||||
failed_deletions = []
|
failed_deletions = []
|
||||||
|
|
||||||
# Set pipeline status to busy for deletion
|
# Double-check pipeline status before proceeding
|
||||||
async with pipeline_status_lock:
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("busy", False):
|
||||||
|
logger.warning("Error: Unexpected pipeline busy state, aborting deletion.")
|
||||||
|
return # Abort deletion operation
|
||||||
|
|
||||||
|
# Set pipeline status to busy for deletion
|
||||||
pipeline_status.update(
|
pipeline_status.update(
|
||||||
{
|
{
|
||||||
"busy": True,
|
"busy": True,
|
||||||
|
|
@ -926,13 +976,26 @@ async def background_delete_documents(
|
||||||
async with pipeline_status_lock:
|
async with pipeline_status_lock:
|
||||||
pipeline_status["history_messages"].append(error_msg)
|
pipeline_status["history_messages"].append(error_msg)
|
||||||
finally:
|
finally:
|
||||||
# Final summary
|
# Final summary and check for pending requests
|
||||||
async with pipeline_status_lock:
|
async with pipeline_status_lock:
|
||||||
pipeline_status["busy"] = False
|
pipeline_status["busy"] = False
|
||||||
completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed"
|
completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed"
|
||||||
pipeline_status["latest_message"] = completion_msg
|
pipeline_status["latest_message"] = completion_msg
|
||||||
pipeline_status["history_messages"].append(completion_msg)
|
pipeline_status["history_messages"].append(completion_msg)
|
||||||
|
|
||||||
|
# Check if there are pending document indexing requests
|
||||||
|
has_pending_request = pipeline_status.get("request_pending", False)
|
||||||
|
|
||||||
|
# If there are pending requests, start document processing pipeline
|
||||||
|
if has_pending_request:
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
"Processing pending document indexing requests after deletion"
|
||||||
|
)
|
||||||
|
await rag.apipeline_process_enqueue_documents()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing pending documents after deletion: {e}")
|
||||||
|
|
||||||
|
|
||||||
def create_document_routes(
|
def create_document_routes(
|
||||||
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
||||||
|
|
@ -986,18 +1049,21 @@ def create_document_routes(
|
||||||
HTTPException: If the file type is not supported (400) or other errors occur (500).
|
HTTPException: If the file type is not supported (400) or other errors occur (500).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not doc_manager.is_supported_file(file.filename):
|
# Sanitize filename to prevent Path Traversal attacks
|
||||||
|
safe_filename = sanitize_filename(file.filename, doc_manager.input_dir)
|
||||||
|
|
||||||
|
if not doc_manager.is_supported_file(safe_filename):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
|
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path = doc_manager.input_dir / file.filename
|
file_path = doc_manager.input_dir / safe_filename
|
||||||
# Check if file already exists
|
# Check if file already exists
|
||||||
if file_path.exists():
|
if file_path.exists():
|
||||||
return InsertResponse(
|
return InsertResponse(
|
||||||
status="duplicated",
|
status="duplicated",
|
||||||
message=f"File '{file.filename}' already exists in the input directory.",
|
message=f"File '{safe_filename}' already exists in the input directory.",
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(file_path, "wb") as buffer:
|
with open(file_path, "wb") as buffer:
|
||||||
|
|
@ -1008,7 +1074,7 @@ def create_document_routes(
|
||||||
|
|
||||||
return InsertResponse(
|
return InsertResponse(
|
||||||
status="success",
|
status="success",
|
||||||
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
|
message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
|
logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
|
||||||
|
|
|
||||||
|
|
@ -234,7 +234,7 @@ class OllamaAPI:
|
||||||
@self.router.get("/version", dependencies=[Depends(combined_auth)])
|
@self.router.get("/version", dependencies=[Depends(combined_auth)])
|
||||||
async def get_version():
|
async def get_version():
|
||||||
"""Get Ollama version information"""
|
"""Get Ollama version information"""
|
||||||
return OllamaVersionResponse(version="0.5.4")
|
return OllamaVersionResponse(version="0.9.3")
|
||||||
|
|
||||||
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
|
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
|
||||||
async def get_tags():
|
async def get_tags():
|
||||||
|
|
@ -244,9 +244,9 @@ class OllamaAPI:
|
||||||
{
|
{
|
||||||
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
|
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
|
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
|
||||||
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
|
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
|
||||||
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"details": {
|
"details": {
|
||||||
"parent_model": "",
|
"parent_model": "",
|
||||||
"format": "gguf",
|
"format": "gguf",
|
||||||
|
|
@ -337,7 +337,10 @@ class OllamaAPI:
|
||||||
data = {
|
data = {
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"response": "",
|
||||||
"done": True,
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
"context": [],
|
||||||
"total_duration": total_time,
|
"total_duration": total_time,
|
||||||
"load_duration": 0,
|
"load_duration": 0,
|
||||||
"prompt_eval_count": prompt_tokens,
|
"prompt_eval_count": prompt_tokens,
|
||||||
|
|
@ -377,6 +380,7 @@ class OllamaAPI:
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"response": f"\n\nError: {error_msg}",
|
"response": f"\n\nError: {error_msg}",
|
||||||
|
"error": f"\n\nError: {error_msg}",
|
||||||
"done": False,
|
"done": False,
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||||
|
|
@ -385,6 +389,7 @@ class OllamaAPI:
|
||||||
final_data = {
|
final_data = {
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"response": "",
|
||||||
"done": True,
|
"done": True,
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||||
|
|
@ -399,7 +404,10 @@ class OllamaAPI:
|
||||||
data = {
|
data = {
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"response": "",
|
||||||
"done": True,
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
"context": [],
|
||||||
"total_duration": total_time,
|
"total_duration": total_time,
|
||||||
"load_duration": 0,
|
"load_duration": 0,
|
||||||
"prompt_eval_count": prompt_tokens,
|
"prompt_eval_count": prompt_tokens,
|
||||||
|
|
@ -444,6 +452,8 @@ class OllamaAPI:
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"response": str(response_text),
|
"response": str(response_text),
|
||||||
"done": True,
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
"context": [],
|
||||||
"total_duration": total_time,
|
"total_duration": total_time,
|
||||||
"load_duration": 0,
|
"load_duration": 0,
|
||||||
"prompt_eval_count": prompt_tokens,
|
"prompt_eval_count": prompt_tokens,
|
||||||
|
|
@ -557,6 +567,12 @@ class OllamaAPI:
|
||||||
data = {
|
data = {
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"images": None,
|
||||||
|
},
|
||||||
|
"done_reason": "stop",
|
||||||
"done": True,
|
"done": True,
|
||||||
"total_duration": total_time,
|
"total_duration": total_time,
|
||||||
"load_duration": 0,
|
"load_duration": 0,
|
||||||
|
|
@ -605,6 +621,7 @@ class OllamaAPI:
|
||||||
"content": f"\n\nError: {error_msg}",
|
"content": f"\n\nError: {error_msg}",
|
||||||
"images": None,
|
"images": None,
|
||||||
},
|
},
|
||||||
|
"error": f"\n\nError: {error_msg}",
|
||||||
"done": False,
|
"done": False,
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||||
|
|
@ -613,6 +630,11 @@ class OllamaAPI:
|
||||||
final_data = {
|
final_data = {
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"images": None,
|
||||||
|
},
|
||||||
"done": True,
|
"done": True,
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||||
|
|
@ -633,6 +655,7 @@ class OllamaAPI:
|
||||||
"content": "",
|
"content": "",
|
||||||
"images": None,
|
"images": None,
|
||||||
},
|
},
|
||||||
|
"done_reason": "stop",
|
||||||
"done": True,
|
"done": True,
|
||||||
"total_duration": total_time,
|
"total_duration": total_time,
|
||||||
"load_duration": 0,
|
"load_duration": 0,
|
||||||
|
|
@ -697,6 +720,7 @@ class OllamaAPI:
|
||||||
"content": str(response_text),
|
"content": str(response_text),
|
||||||
"images": None,
|
"images": None,
|
||||||
},
|
},
|
||||||
|
"done_reason": "stop",
|
||||||
"done": True,
|
"done": True,
|
||||||
"total_duration": total_time,
|
"total_duration": total_time,
|
||||||
"load_duration": 0,
|
"load_duration": 0,
|
||||||
|
|
|
||||||
|
|
@ -183,6 +183,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
if isinstance(response, str):
|
if isinstance(response, str):
|
||||||
# If it's a string, send it all at once
|
# If it's a string, send it all at once
|
||||||
yield f"{json.dumps({'response': response})}\n"
|
yield f"{json.dumps({'response': response})}\n"
|
||||||
|
elif response is None:
|
||||||
|
# Handle None response (e.g., when only_need_context=True but no context found)
|
||||||
|
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
|
||||||
else:
|
else:
|
||||||
# If it's an async generator, send chunks one by one
|
# If it's an async generator, send chunks one by one
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseGraphStorage(StorageNameSpace, ABC):
|
class BaseGraphStorage(StorageNameSpace, ABC):
|
||||||
|
"""All operations related to edges in graph should be undirected."""
|
||||||
|
|
||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -468,17 +470,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||||
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
||||||
An empty list if no matching nodes are found.
|
An empty list if no matching nodes are found.
|
||||||
"""
|
"""
|
||||||
# Default implementation iterates through all nodes, which is inefficient.
|
|
||||||
# This method should be overridden by subclasses for better performance.
|
|
||||||
all_nodes = []
|
|
||||||
all_labels = await self.get_all_labels()
|
|
||||||
for label in all_labels:
|
|
||||||
node = await self.get_node(label)
|
|
||||||
if node and "source_id" in node:
|
|
||||||
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
|
|
||||||
if not source_ids.isdisjoint(chunk_ids):
|
|
||||||
all_nodes.append(node)
|
|
||||||
return all_nodes
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||||
|
|
@ -643,6 +634,8 @@ class DocProcessingStatus:
|
||||||
"""ISO format timestamp when document was last updated"""
|
"""ISO format timestamp when document was last updated"""
|
||||||
chunks_count: int | None = None
|
chunks_count: int | None = None
|
||||||
"""Number of chunks after splitting, used for processing"""
|
"""Number of chunks after splitting, used for processing"""
|
||||||
|
chunks_list: list[str] | None = field(default_factory=list)
|
||||||
|
"""List of chunk IDs associated with this document, used for deletion"""
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
"""Error message if failed"""
|
"""Error message if failed"""
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ consistency and makes maintenance easier.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Default values for environment variables
|
# Default values for environment variables
|
||||||
|
DEFAULT_MAX_GLEANING = 1
|
||||||
DEFAULT_MAX_TOKEN_SUMMARY = 500
|
DEFAULT_MAX_TOKEN_SUMMARY = 500
|
||||||
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
|
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
|
||||||
DEFAULT_WOKERS = 2
|
DEFAULT_WOKERS = 2
|
||||||
|
|
|
||||||
|
|
@ -26,11 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
|
||||||
"implementations": [
|
"implementations": [
|
||||||
"NanoVectorDBStorage",
|
"NanoVectorDBStorage",
|
||||||
"MilvusVectorDBStorage",
|
"MilvusVectorDBStorage",
|
||||||
"ChromaVectorDBStorage",
|
|
||||||
"PGVectorStorage",
|
"PGVectorStorage",
|
||||||
"FaissVectorDBStorage",
|
"FaissVectorDBStorage",
|
||||||
"QdrantVectorDBStorage",
|
"QdrantVectorDBStorage",
|
||||||
"MongoVectorDBStorage",
|
"MongoVectorDBStorage",
|
||||||
|
# "ChromaVectorDBStorage",
|
||||||
# "TiDBVectorDBStorage",
|
# "TiDBVectorDBStorage",
|
||||||
],
|
],
|
||||||
"required_methods": ["query", "upsert"],
|
"required_methods": ["query", "upsert"],
|
||||||
|
|
@ -38,6 +38,7 @@ STORAGE_IMPLEMENTATIONS = {
|
||||||
"DOC_STATUS_STORAGE": {
|
"DOC_STATUS_STORAGE": {
|
||||||
"implementations": [
|
"implementations": [
|
||||||
"JsonDocStatusStorage",
|
"JsonDocStatusStorage",
|
||||||
|
"RedisDocStatusStorage",
|
||||||
"PGDocStatusStorage",
|
"PGDocStatusStorage",
|
||||||
"MongoDocStatusStorage",
|
"MongoDocStatusStorage",
|
||||||
],
|
],
|
||||||
|
|
@ -81,6 +82,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
||||||
"MongoVectorDBStorage": [],
|
"MongoVectorDBStorage": [],
|
||||||
# Document Status Storage Implementations
|
# Document Status Storage Implementations
|
||||||
"JsonDocStatusStorage": [],
|
"JsonDocStatusStorage": [],
|
||||||
|
"RedisDocStatusStorage": ["REDIS_URI"],
|
||||||
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||||
"MongoDocStatusStorage": [],
|
"MongoDocStatusStorage": [],
|
||||||
}
|
}
|
||||||
|
|
@ -98,6 +100,7 @@ STORAGES = {
|
||||||
"MongoGraphStorage": ".kg.mongo_impl",
|
"MongoGraphStorage": ".kg.mongo_impl",
|
||||||
"MongoVectorDBStorage": ".kg.mongo_impl",
|
"MongoVectorDBStorage": ".kg.mongo_impl",
|
||||||
"RedisKVStorage": ".kg.redis_impl",
|
"RedisKVStorage": ".kg.redis_impl",
|
||||||
|
"RedisDocStatusStorage": ".kg.redis_impl",
|
||||||
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
||||||
# "TiDBKVStorage": ".kg.tidb_impl",
|
# "TiDBKVStorage": ".kg.tidb_impl",
|
||||||
# "TiDBVectorDBStorage": ".kg.tidb_impl",
|
# "TiDBVectorDBStorage": ".kg.tidb_impl",
|
||||||
|
|
|
||||||
|
|
@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||||
ids: List of vector IDs to be deleted
|
ids: List of vector IDs to be deleted
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
|
||||||
self._collection.delete(ids=ids)
|
self._collection.delete(ids=ids)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
||||||
|
|
@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
################ INSERT full_doc AND chunks ################
|
################ INSERT full_doc AND chunks ################
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||||
|
|
@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
|
|
||||||
###### INSERT entities And relationships ######
|
###### INSERT entities And relationships ######
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
|
||||||
|
|
||||||
# Get current time as UNIX timestamp
|
# Get current time as UNIX timestamp
|
||||||
import time
|
import time
|
||||||
|
|
@ -522,11 +520,6 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
}
|
}
|
||||||
await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param)
|
await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param)
|
||||||
|
|
||||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
|
||||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
|
||||||
params = {"workspace": self.db.workspace, "status": status}
|
|
||||||
return await self.db.query(SQL, params, multirows=True)
|
|
||||||
|
|
||||||
async def delete(self, ids: list[str]) -> None:
|
async def delete(self, ids: list[str]) -> None:
|
||||||
"""Delete vectors with specified IDs from the storage.
|
"""Delete vectors with specified IDs from the storage.
|
||||||
|
|
||||||
|
|
@ -17,14 +17,13 @@ from .shared_storage import (
|
||||||
set_all_update_flags,
|
set_all_update_flags,
|
||||||
)
|
)
|
||||||
|
|
||||||
import faiss # type: ignore
|
|
||||||
|
|
||||||
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
|
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
|
||||||
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
|
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
|
||||||
|
|
||||||
if not pm.is_installed(FAISS_PACKAGE):
|
if not pm.is_installed(FAISS_PACKAGE):
|
||||||
pm.install(FAISS_PACKAGE)
|
pm.install(FAISS_PACKAGE)
|
||||||
|
|
||||||
|
import faiss # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||||
return
|
return
|
||||||
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
# Ensure chunks_list field exists for new documents
|
||||||
|
for doc_id, doc_data in data.items():
|
||||||
|
if "chunks_list" not in doc_data:
|
||||||
|
doc_data["chunks_list"] = []
|
||||||
self._data.update(data)
|
self._data.update(data)
|
||||||
await set_all_update_flags(self.namespace)
|
await set_all_update_flags(self.namespace)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
if need_init:
|
if need_init:
|
||||||
loaded_data = load_json(self._file_name) or {}
|
loaded_data = load_json(self._file_name) or {}
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
self._data.update(loaded_data)
|
# Migrate legacy cache structure if needed
|
||||||
|
if self.namespace.endswith("_cache"):
|
||||||
# Calculate data count based on namespace
|
loaded_data = await self._migrate_legacy_cache_structure(
|
||||||
if self.namespace.endswith("cache"):
|
loaded_data
|
||||||
# For cache namespaces, sum the cache entries across all cache types
|
|
||||||
data_count = sum(
|
|
||||||
len(first_level_dict)
|
|
||||||
for first_level_dict in loaded_data.values()
|
|
||||||
if isinstance(first_level_dict, dict)
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# For non-cache namespaces, use the original count method
|
self._data.update(loaded_data)
|
||||||
data_count = len(loaded_data)
|
data_count = len(loaded_data)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
|
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
|
||||||
|
|
@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate data count based on namespace
|
# Calculate data count - all data is now flattened
|
||||||
if self.namespace.endswith("cache"):
|
data_count = len(data_dict)
|
||||||
# # For cache namespaces, sum the cache entries across all cache types
|
|
||||||
data_count = sum(
|
|
||||||
len(first_level_dict)
|
|
||||||
for first_level_dict in data_dict.values()
|
|
||||||
if isinstance(first_level_dict, dict)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# For non-cache namespaces, use the original count method
|
|
||||||
data_count = len(data_dict)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
|
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
|
||||||
|
|
@ -92,22 +78,49 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
Dictionary containing all stored data
|
Dictionary containing all stored data
|
||||||
"""
|
"""
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
return dict(self._data)
|
result = {}
|
||||||
|
for key, value in self._data.items():
|
||||||
|
if value:
|
||||||
|
# Create a copy to avoid modifying the original data
|
||||||
|
data = dict(value)
|
||||||
|
# Ensure time fields are present, provide default values for old data
|
||||||
|
data.setdefault("create_time", 0)
|
||||||
|
data.setdefault("update_time", 0)
|
||||||
|
result[key] = data
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
return self._data.get(id)
|
result = self._data.get(id)
|
||||||
|
if result:
|
||||||
|
# Create a copy to avoid modifying the original data
|
||||||
|
result = dict(result)
|
||||||
|
# Ensure time fields are present, provide default values for old data
|
||||||
|
result.setdefault("create_time", 0)
|
||||||
|
result.setdefault("update_time", 0)
|
||||||
|
# Ensure _id field contains the clean ID
|
||||||
|
result["_id"] = id
|
||||||
|
return result
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
return [
|
results = []
|
||||||
(
|
for id in ids:
|
||||||
{k: v for k, v in self._data[id].items()}
|
data = self._data.get(id, None)
|
||||||
if self._data.get(id, None)
|
if data:
|
||||||
else None
|
# Create a copy to avoid modifying the original data
|
||||||
)
|
result = {k: v for k, v in data.items()}
|
||||||
for id in ids
|
# Ensure time fields are present, provide default values for old data
|
||||||
]
|
result.setdefault("create_time", 0)
|
||||||
|
result.setdefault("update_time", 0)
|
||||||
|
# Ensure _id field contains the clean ID
|
||||||
|
result["_id"] = id
|
||||||
|
results.append(result)
|
||||||
|
else:
|
||||||
|
results.append(None)
|
||||||
|
return results
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
|
@ -121,8 +134,29 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
"""
|
"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
current_time = int(time.time()) # Get current Unix timestamp
|
||||||
|
|
||||||
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
# Add timestamps to data based on whether key exists
|
||||||
|
for k, v in data.items():
|
||||||
|
# For text_chunks namespace, ensure llm_cache_list field exists
|
||||||
|
if "text_chunks" in self.namespace:
|
||||||
|
if "llm_cache_list" not in v:
|
||||||
|
v["llm_cache_list"] = []
|
||||||
|
|
||||||
|
# Add timestamps based on whether key exists
|
||||||
|
if k in self._data: # Key exists, only update update_time
|
||||||
|
v["update_time"] = current_time
|
||||||
|
else: # New key, set both create_time and update_time
|
||||||
|
v["create_time"] = current_time
|
||||||
|
v["update_time"] = current_time
|
||||||
|
|
||||||
|
v["_id"] = k
|
||||||
|
|
||||||
self._data.update(data)
|
self._data.update(data)
|
||||||
await set_all_update_flags(self.namespace)
|
await set_all_update_flags(self.namespace)
|
||||||
|
|
||||||
|
|
@ -150,14 +184,14 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
await set_all_update_flags(self.namespace)
|
await set_all_update_flags(self.namespace)
|
||||||
|
|
||||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||||
"""Delete specific records from storage by by cache mode
|
"""Delete specific records from storage by cache mode
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
1. Changes will be persisted to disk during the next index_done_callback
|
1. Changes will be persisted to disk during the next index_done_callback
|
||||||
2. update flags to notify other processes that data persistence is needed
|
2. update flags to notify other processes that data persistence is needed
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ids (list[str]): List of cache mode to be drop from storage
|
modes (list[str]): List of cache modes to be dropped from storage
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True: if the cache drop successfully
|
True: if the cache drop successfully
|
||||||
|
|
@ -167,9 +201,29 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.delete(modes)
|
async with self._storage_lock:
|
||||||
|
keys_to_delete = []
|
||||||
|
modes_set = set(modes) # Convert to set for efficient lookup
|
||||||
|
|
||||||
|
for key in list(self._data.keys()):
|
||||||
|
# Parse flattened cache key: mode:cache_type:hash
|
||||||
|
parts = key.split(":", 2)
|
||||||
|
if len(parts) == 3 and parts[0] in modes_set:
|
||||||
|
keys_to_delete.append(key)
|
||||||
|
|
||||||
|
# Batch delete
|
||||||
|
for key in keys_to_delete:
|
||||||
|
self._data.pop(key, None)
|
||||||
|
|
||||||
|
if keys_to_delete:
|
||||||
|
await set_all_update_flags(self.namespace)
|
||||||
|
logger.info(
|
||||||
|
f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(f"Error dropping cache by modes: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
|
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
|
||||||
|
|
@ -245,9 +299,58 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
logger.error(f"Error dropping {self.namespace}: {e}")
|
logger.error(f"Error dropping {self.namespace}: {e}")
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
|
||||||
|
"""Migrate legacy nested cache structure to flattened structure
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Original data dictionary that may contain legacy structure
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Migrated data dictionary with flattened cache keys
|
||||||
|
"""
|
||||||
|
from lightrag.utils import generate_cache_key
|
||||||
|
|
||||||
|
# Early return if data is empty
|
||||||
|
if not data:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Check first entry to see if it's already in new format
|
||||||
|
first_key = next(iter(data.keys()))
|
||||||
|
if ":" in first_key and len(first_key.split(":")) == 3:
|
||||||
|
# Already in flattened format, return as-is
|
||||||
|
return data
|
||||||
|
|
||||||
|
migrated_data = {}
|
||||||
|
migration_count = 0
|
||||||
|
|
||||||
|
for key, value in data.items():
|
||||||
|
# Check if this is a legacy nested cache structure
|
||||||
|
if isinstance(value, dict) and all(
|
||||||
|
isinstance(v, dict) and "return" in v for v in value.values()
|
||||||
|
):
|
||||||
|
# This looks like a legacy cache mode with nested structure
|
||||||
|
mode = key
|
||||||
|
for cache_hash, cache_entry in value.items():
|
||||||
|
cache_type = cache_entry.get("cache_type", "extract")
|
||||||
|
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
|
||||||
|
migrated_data[flattened_key] = cache_entry
|
||||||
|
migration_count += 1
|
||||||
|
else:
|
||||||
|
# Keep non-cache data or already flattened cache data as-is
|
||||||
|
migrated_data[key] = value
|
||||||
|
|
||||||
|
if migration_count > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Migrated {migration_count} legacy cache entries to flattened structure"
|
||||||
|
)
|
||||||
|
# Persist migrated data immediately
|
||||||
|
write_json(migrated_data, self._file_name)
|
||||||
|
|
||||||
|
return migrated_data
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
"""Finalize storage resources
|
"""Finalize storage resources
|
||||||
Persistence cache data to disk before exiting
|
Persistence cache data to disk before exiting
|
||||||
"""
|
"""
|
||||||
if self.namespace.endswith("cache"):
|
if self.namespace.endswith("_cache"):
|
||||||
await self.index_done_callback()
|
await self.index_done_callback()
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
|
||||||
pm.install("pymilvus")
|
pm.install("pymilvus")
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
from pymilvus import MilvusClient # type: ignore
|
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MilvusVectorDBStorage(BaseVectorStorage):
|
class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
@staticmethod
|
def _create_schema_for_namespace(self) -> CollectionSchema:
|
||||||
def create_collection_if_not_exist(
|
"""Create schema based on the current instance's namespace"""
|
||||||
client: MilvusClient, collection_name: str, **kwargs
|
|
||||||
):
|
# Get vector dimension from embedding_func
|
||||||
if client.has_collection(collection_name):
|
dimension = self.embedding_func.embedding_dim
|
||||||
return
|
|
||||||
client.create_collection(
|
# Base fields (common to all collections)
|
||||||
collection_name, max_length=64, id_type="string", **kwargs
|
base_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
|
||||||
|
),
|
||||||
|
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
||||||
|
FieldSchema(name="created_at", dtype=DataType.INT64),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Determine specific fields based on namespace
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="entity_name",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=256,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="entity_type",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=64,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG entities vector storage"
|
||||||
|
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
|
||||||
|
),
|
||||||
|
FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG relationships vector storage"
|
||||||
|
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="full_doc_id",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=64,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG chunks vector storage"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Default generic schema (backward compatibility)
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG generic vector storage"
|
||||||
|
|
||||||
|
# Merge all fields
|
||||||
|
all_fields = base_fields + specific_fields
|
||||||
|
|
||||||
|
return CollectionSchema(
|
||||||
|
fields=all_fields,
|
||||||
|
description=description,
|
||||||
|
enable_dynamic_field=True, # Support dynamic fields
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_index_params(self):
|
||||||
|
"""Get IndexParams in a version-compatible way"""
|
||||||
|
try:
|
||||||
|
# Try to use client's prepare_index_params method (most common)
|
||||||
|
if hasattr(self._client, "prepare_index_params"):
|
||||||
|
return self._client.prepare_index_params()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to import IndexParams from different possible locations
|
||||||
|
from pymilvus.client.prepare import IndexParams
|
||||||
|
|
||||||
|
return IndexParams()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pymilvus.client.types import IndexParams
|
||||||
|
|
||||||
|
return IndexParams()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pymilvus import IndexParams
|
||||||
|
|
||||||
|
return IndexParams()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If all else fails, return None to use fallback method
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_vector_index_fallback(self):
|
||||||
|
"""Fallback method to create vector index using direct API"""
|
||||||
|
try:
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
field_name="vector",
|
||||||
|
index_params={
|
||||||
|
"index_type": "HNSW",
|
||||||
|
"metric_type": "COSINE",
|
||||||
|
"params": {"M": 16, "efConstruction": 256},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logger.debug("Created vector index using fallback method")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to create vector index using fallback method: {e}")
|
||||||
|
|
||||||
|
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
|
||||||
|
"""Fallback method to create scalar index using direct API"""
|
||||||
|
# Skip unsupported index types
|
||||||
|
if index_type == "SORTED":
|
||||||
|
logger.info(
|
||||||
|
f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
field_name=field_name,
|
||||||
|
index_params={"index_type": index_type},
|
||||||
|
)
|
||||||
|
logger.debug(f"Created {field_name} index using fallback method")
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(
|
||||||
|
f"Could not create {field_name} index using fallback method: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_indexes_after_collection(self):
|
||||||
|
"""Create indexes after collection is created"""
|
||||||
|
try:
|
||||||
|
# Try to get IndexParams in a version-compatible way
|
||||||
|
IndexParamsClass = self._get_index_params()
|
||||||
|
|
||||||
|
if IndexParamsClass is not None:
|
||||||
|
# Use IndexParams approach if available
|
||||||
|
try:
|
||||||
|
# Create vector index first (required for most operations)
|
||||||
|
vector_index = IndexParamsClass
|
||||||
|
vector_index.add_index(
|
||||||
|
field_name="vector",
|
||||||
|
index_type="HNSW",
|
||||||
|
metric_type="COSINE",
|
||||||
|
params={"M": 16, "efConstruction": 256},
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=vector_index
|
||||||
|
)
|
||||||
|
logger.debug("Created vector index using IndexParams")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for vector index: {e}")
|
||||||
|
self._create_vector_index_fallback()
|
||||||
|
|
||||||
|
# Create scalar indexes based on namespace
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
# Create indexes for entity fields
|
||||||
|
try:
|
||||||
|
entity_name_index = self._get_index_params()
|
||||||
|
entity_name_index.add_index(
|
||||||
|
field_name="entity_name", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
index_params=entity_name_index,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for entity_name: {e}")
|
||||||
|
self._create_scalar_index_fallback("entity_name", "INVERTED")
|
||||||
|
|
||||||
|
try:
|
||||||
|
entity_type_index = self._get_index_params()
|
||||||
|
entity_type_index.add_index(
|
||||||
|
field_name="entity_type", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
index_params=entity_type_index,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for entity_type: {e}")
|
||||||
|
self._create_scalar_index_fallback("entity_type", "INVERTED")
|
||||||
|
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
# Create indexes for relationship fields
|
||||||
|
try:
|
||||||
|
src_id_index = self._get_index_params()
|
||||||
|
src_id_index.add_index(
|
||||||
|
field_name="src_id", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=src_id_index
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for src_id: {e}")
|
||||||
|
self._create_scalar_index_fallback("src_id", "INVERTED")
|
||||||
|
|
||||||
|
try:
|
||||||
|
tgt_id_index = self._get_index_params()
|
||||||
|
tgt_id_index.add_index(
|
||||||
|
field_name="tgt_id", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=tgt_id_index
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for tgt_id: {e}")
|
||||||
|
self._create_scalar_index_fallback("tgt_id", "INVERTED")
|
||||||
|
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
# Create indexes for chunk fields
|
||||||
|
try:
|
||||||
|
doc_id_index = self._get_index_params()
|
||||||
|
doc_id_index.add_index(
|
||||||
|
field_name="full_doc_id", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=doc_id_index
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for full_doc_id: {e}")
|
||||||
|
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
|
||||||
|
|
||||||
|
# No common indexes needed
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Fallback to direct API calls if IndexParams is not available
|
||||||
|
logger.info(
|
||||||
|
f"IndexParams not available, using fallback methods for {self.namespace}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create vector index using fallback
|
||||||
|
self._create_vector_index_fallback()
|
||||||
|
|
||||||
|
# Create scalar indexes using fallback
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
self._create_scalar_index_fallback("entity_name", "INVERTED")
|
||||||
|
self._create_scalar_index_fallback("entity_type", "INVERTED")
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
self._create_scalar_index_fallback("src_id", "INVERTED")
|
||||||
|
self._create_scalar_index_fallback("tgt_id", "INVERTED")
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
|
||||||
|
|
||||||
|
logger.info(f"Created indexes for collection: {self.namespace}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
|
||||||
|
|
||||||
|
def _get_required_fields_for_namespace(self) -> dict:
|
||||||
|
"""Get required core field definitions for current namespace"""
|
||||||
|
|
||||||
|
# Base fields (common to all types)
|
||||||
|
base_fields = {
|
||||||
|
"id": {"type": "VarChar", "is_primary": True},
|
||||||
|
"vector": {"type": "FloatVector"},
|
||||||
|
"created_at": {"type": "Int64"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add specific fields based on namespace
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
specific_fields = {
|
||||||
|
"entity_name": {"type": "VarChar"},
|
||||||
|
"entity_type": {"type": "VarChar"},
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
specific_fields = {
|
||||||
|
"src_id": {"type": "VarChar"},
|
||||||
|
"tgt_id": {"type": "VarChar"},
|
||||||
|
"weight": {"type": "Double"},
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
specific_fields = {
|
||||||
|
"full_doc_id": {"type": "VarChar"},
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
specific_fields = {
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
return {**base_fields, **specific_fields}
|
||||||
|
|
||||||
|
def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
|
||||||
|
"""Check compatibility of a single field"""
|
||||||
|
field_name = existing_field.get("name", "unknown")
|
||||||
|
existing_type = existing_field.get("type")
|
||||||
|
expected_type = expected_config.get("type")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert DataType enum values to string names if needed
|
||||||
|
original_existing_type = existing_type
|
||||||
|
if hasattr(existing_type, "name"):
|
||||||
|
existing_type = existing_type.name
|
||||||
|
logger.debug(
|
||||||
|
f"Converted enum to name: {original_existing_type} -> {existing_type}"
|
||||||
|
)
|
||||||
|
elif isinstance(existing_type, int):
|
||||||
|
# Map common Milvus internal type codes to type names for backward compatibility
|
||||||
|
type_mapping = {
|
||||||
|
21: "VarChar",
|
||||||
|
101: "FloatVector",
|
||||||
|
5: "Int64",
|
||||||
|
9: "Double",
|
||||||
|
}
|
||||||
|
mapped_type = type_mapping.get(existing_type, str(existing_type))
|
||||||
|
logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
|
||||||
|
existing_type = mapped_type
|
||||||
|
|
||||||
|
# Normalize type names for comparison
|
||||||
|
type_aliases = {
|
||||||
|
"VARCHAR": "VarChar",
|
||||||
|
"String": "VarChar",
|
||||||
|
"FLOAT_VECTOR": "FloatVector",
|
||||||
|
"INT64": "Int64",
|
||||||
|
"BigInt": "Int64",
|
||||||
|
"DOUBLE": "Double",
|
||||||
|
"Float": "Double",
|
||||||
|
}
|
||||||
|
|
||||||
|
original_existing = existing_type
|
||||||
|
original_expected = expected_type
|
||||||
|
existing_type = type_aliases.get(existing_type, existing_type)
|
||||||
|
expected_type = type_aliases.get(expected_type, expected_type)
|
||||||
|
|
||||||
|
if original_existing != existing_type or original_expected != expected_type:
|
||||||
|
logger.debug(
|
||||||
|
f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Basic type compatibility check
|
||||||
|
type_compatible = existing_type == expected_type
|
||||||
|
logger.debug(
|
||||||
|
f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not type_compatible:
|
||||||
|
logger.warning(
|
||||||
|
f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Primary key check - be more flexible about primary key detection
|
||||||
|
if expected_config.get("is_primary"):
|
||||||
|
# Check multiple possible field names for primary key status
|
||||||
|
is_primary = (
|
||||||
|
existing_field.get("is_primary_key", False)
|
||||||
|
or existing_field.get("is_primary", False)
|
||||||
|
or existing_field.get("primary_key", False)
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
|
||||||
|
)
|
||||||
|
logger.debug(f"Raw field data for '{field_name}': {existing_field}")
|
||||||
|
|
||||||
|
# For ID field, be more lenient - if it's the ID field, assume it should be primary
|
||||||
|
if field_name == "id" and not is_primary:
|
||||||
|
logger.info(
|
||||||
|
f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
|
||||||
|
)
|
||||||
|
# Don't fail for ID field primary key mismatch
|
||||||
|
elif not is_primary:
|
||||||
|
logger.warning(
|
||||||
|
f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.debug(f"Field '{field_name}' is compatible")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _check_vector_dimension(self, collection_info: dict):
|
||||||
|
"""Check vector dimension compatibility"""
|
||||||
|
current_dimension = self.embedding_func.embedding_dim
|
||||||
|
|
||||||
|
# Find vector field dimension
|
||||||
|
for field in collection_info.get("fields", []):
|
||||||
|
if field.get("name") == "vector":
|
||||||
|
field_type = field.get("type")
|
||||||
|
if field_type in ["FloatVector", "FLOAT_VECTOR"]:
|
||||||
|
existing_dimension = field.get("params", {}).get("dim")
|
||||||
|
|
||||||
|
if existing_dimension != current_dimension:
|
||||||
|
raise ValueError(
|
||||||
|
f"Vector dimension mismatch for collection '{self.namespace}': "
|
||||||
|
f"existing={existing_dimension}, current={current_dimension}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Vector dimension check passed: {current_dimension}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# If no vector field found, this might be an old collection created with simple schema
|
||||||
|
logger.warning(
|
||||||
|
f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
|
||||||
|
)
|
||||||
|
logger.warning("Consider recreating the collection for optimal performance.")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _check_schema_compatibility(self, collection_info: dict):
|
||||||
|
"""Check schema field compatibility"""
|
||||||
|
existing_fields = {
|
||||||
|
field["name"]: field for field in collection_info.get("fields", [])
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if this is an old collection created with simple schema
|
||||||
|
has_vector_field = any(
|
||||||
|
field.get("name") == "vector" for field in collection_info.get("fields", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_vector_field:
|
||||||
|
logger.warning(
|
||||||
|
f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"This collection will work but may have suboptimal performance"
|
||||||
|
)
|
||||||
|
logger.warning("Consider recreating the collection for optimal performance")
|
||||||
|
return
|
||||||
|
|
||||||
|
# For collections with vector field, check basic compatibility
|
||||||
|
# Only check for critical incompatibilities, not missing optional fields
|
||||||
|
critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
|
||||||
|
|
||||||
|
incompatible_fields = []
|
||||||
|
|
||||||
|
for field_name, expected_config in critical_fields.items():
|
||||||
|
if field_name in existing_fields:
|
||||||
|
existing_field = existing_fields[field_name]
|
||||||
|
if not self._is_field_compatible(existing_field, expected_config):
|
||||||
|
incompatible_fields.append(
|
||||||
|
f"{field_name}: expected {expected_config['type']}, "
|
||||||
|
f"got {existing_field.get('type')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if incompatible_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all expected fields for informational purposes
|
||||||
|
expected_fields = self._get_required_fields_for_namespace()
|
||||||
|
missing_fields = [
|
||||||
|
field for field in expected_fields if field not in existing_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
if missing_fields:
|
||||||
|
logger.info(
|
||||||
|
f"Collection {self.namespace} missing optional fields: {missing_fields}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"These fields would be available in a newly created collection for better performance"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Schema compatibility check passed for {self.namespace}")
|
||||||
|
|
||||||
|
def _validate_collection_compatibility(self):
|
||||||
|
"""Validate existing collection's dimension and schema compatibility"""
|
||||||
|
try:
|
||||||
|
collection_info = self._client.describe_collection(self.namespace)
|
||||||
|
|
||||||
|
# 1. Check vector dimension
|
||||||
|
self._check_vector_dimension(collection_info)
|
||||||
|
|
||||||
|
# 2. Check schema compatibility
|
||||||
|
self._check_schema_compatibility(collection_info)
|
||||||
|
|
||||||
|
logger.info(f"Collection {self.namespace} compatibility validation passed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Collection compatibility validation failed for {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _create_collection_if_not_exist(self):
|
||||||
|
"""Create collection if not exists and check existing collection compatibility"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# First, list all collections to see what actually exists
|
||||||
|
try:
|
||||||
|
all_collections = self._client.list_collections()
|
||||||
|
logger.debug(f"All collections in database: {all_collections}")
|
||||||
|
except Exception as list_error:
|
||||||
|
logger.warning(f"Could not list collections: {list_error}")
|
||||||
|
all_collections = []
|
||||||
|
|
||||||
|
# Check if our specific collection exists
|
||||||
|
collection_exists = self._client.has_collection(self.namespace)
|
||||||
|
logger.info(
|
||||||
|
f"Collection '{self.namespace}' exists check: {collection_exists}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if collection_exists:
|
||||||
|
# Double-check by trying to describe the collection
|
||||||
|
try:
|
||||||
|
self._client.describe_collection(self.namespace)
|
||||||
|
logger.info(
|
||||||
|
f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
|
||||||
|
)
|
||||||
|
self._validate_collection_compatibility()
|
||||||
|
return
|
||||||
|
except Exception as describe_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Treating as if collection doesn't exist and creating new one..."
|
||||||
|
)
|
||||||
|
# Fall through to creation logic
|
||||||
|
|
||||||
|
# Collection doesn't exist, create new collection
|
||||||
|
logger.info(f"Creating new collection: {self.namespace}")
|
||||||
|
schema = self._create_schema_for_namespace()
|
||||||
|
|
||||||
|
# Create collection with schema only first
|
||||||
|
self._client.create_collection(
|
||||||
|
collection_name=self.namespace, schema=schema
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then create indexes
|
||||||
|
self._create_indexes_after_collection()
|
||||||
|
|
||||||
|
logger.info(f"Successfully created Milvus collection: {self.namespace}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there's any error, try to force create the collection
|
||||||
|
logger.info(f"Attempting to force create collection {self.namespace}...")
|
||||||
|
try:
|
||||||
|
# Try to drop the collection first if it exists in a bad state
|
||||||
|
try:
|
||||||
|
if self._client.has_collection(self.namespace):
|
||||||
|
logger.info(
|
||||||
|
f"Dropping potentially corrupted collection {self.namespace}"
|
||||||
|
)
|
||||||
|
self._client.drop_collection(self.namespace)
|
||||||
|
except Exception as drop_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not drop collection {self.namespace}: {drop_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fresh collection
|
||||||
|
schema = self._create_schema_for_namespace()
|
||||||
|
self._client.create_collection(
|
||||||
|
collection_name=self.namespace, schema=schema
|
||||||
|
)
|
||||||
|
self._create_indexes_after_collection()
|
||||||
|
logger.info(f"Successfully force-created collection {self.namespace}")
|
||||||
|
|
||||||
|
except Exception as create_error:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to force-create collection {self.namespace}: {create_error}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||||
|
|
@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
|
# Ensure created_at is in meta_fields
|
||||||
|
if "created_at" not in self.meta_fields:
|
||||||
|
self.meta_fields.add("created_at")
|
||||||
|
|
||||||
self._client = MilvusClient(
|
self._client = MilvusClient(
|
||||||
uri=os.environ.get(
|
uri=os.environ.get(
|
||||||
"MILVUS_URI",
|
"MILVUS_URI",
|
||||||
|
|
@ -68,14 +661,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
|
||||||
self._client,
|
# Create collection and check compatibility
|
||||||
self.namespace,
|
self._create_collection_if_not_exist()
|
||||||
dimension=self.embedding_func.embedding_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
embedding = await self.embedding_func(
|
embedding = await self.embedding_func(
|
||||||
[query], _priority=5
|
[query], _priority=5
|
||||||
) # higher priority for query
|
) # higher priority for query
|
||||||
|
|
||||||
|
# Include all meta_fields (created_at is now always included)
|
||||||
|
output_fields = list(self.meta_fields)
|
||||||
|
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
data=embedding,
|
data=embedding,
|
||||||
limit=top_k,
|
limit=top_k,
|
||||||
output_fields=list(self.meta_fields) + ["created_at"],
|
output_fields=output_fields,
|
||||||
search_params={
|
search_params={
|
||||||
"metric_type": "COSINE",
|
"metric_type": "COSINE",
|
||||||
"params": {"radius": self.cosine_better_than_threshold},
|
"params": {"radius": self.cosine_better_than_threshold},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(results)
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
**dp["entity"],
|
**dp["entity"],
|
||||||
"id": dp["id"],
|
"id": dp["id"],
|
||||||
"distance": dp["distance"],
|
"distance": dp["distance"],
|
||||||
# created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
|
|
||||||
"created_at": dp.get("created_at"),
|
"created_at": dp.get("created_at"),
|
||||||
}
|
}
|
||||||
for dp in results[0]
|
for dp in results[0]
|
||||||
|
|
@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
The vector data if found, or None if not found
|
The vector data if found, or None if not found
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Include all meta_fields (created_at is now always included) plus id
|
||||||
|
output_fields = list(self.meta_fields) + ["id"]
|
||||||
|
|
||||||
# Query Milvus for a specific ID
|
# Query Milvus for a specific ID
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
filter=f'id == "{id}"',
|
filter=f'id == "{id}"',
|
||||||
output_fields=list(self.meta_fields) + ["id", "created_at"],
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result or len(result) == 0:
|
if not result or len(result) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Ensure the result contains created_at field
|
|
||||||
if "created_at" not in result[0]:
|
|
||||||
result[0]["created_at"] = None
|
|
||||||
|
|
||||||
return result[0]
|
return result[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||||
|
|
@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Include all meta_fields (created_at is now always included) plus id
|
||||||
|
output_fields = list(self.meta_fields) + ["id"]
|
||||||
|
|
||||||
# Prepare the ID filter expression
|
# Prepare the ID filter expression
|
||||||
id_list = '", "'.join(ids)
|
id_list = '", "'.join(ids)
|
||||||
filter_expr = f'id in ["{id_list}"]'
|
filter_expr = f'id in ["{id_list}"]'
|
||||||
|
|
@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
filter=filter_expr,
|
filter=filter_expr,
|
||||||
output_fields=list(self.meta_fields) + ["id", "created_at"],
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure each result contains created_at field
|
|
||||||
for item in result:
|
|
||||||
if "created_at" not in item:
|
|
||||||
item["created_at"] = None
|
|
||||||
|
|
||||||
return result or []
|
return result or []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||||
|
|
@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
self._client.drop_collection(self.namespace)
|
self._client.drop_collection(self.namespace)
|
||||||
|
|
||||||
# Recreate the collection
|
# Recreate the collection
|
||||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
self._create_collection_if_not_exist()
|
||||||
self._client,
|
|
||||||
self.namespace,
|
|
||||||
dimension=self.embedding_func.embedding_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
|
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
|
|
@ -14,7 +15,6 @@ from ..base import (
|
||||||
DocStatus,
|
DocStatus,
|
||||||
DocStatusStorage,
|
DocStatusStorage,
|
||||||
)
|
)
|
||||||
from ..namespace import NameSpace, is_namespace
|
|
||||||
from ..utils import logger, compute_mdhash_id
|
from ..utils import logger, compute_mdhash_id
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
from ..constants import GRAPH_FIELD_SEP
|
||||||
|
|
@ -35,6 +35,7 @@ config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
# Get maximum number of graph nodes from environment variable, default is 1000
|
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||||
|
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
|
||||||
|
|
||||||
|
|
||||||
class ClientManager:
|
class ClientManager:
|
||||||
|
|
@ -96,11 +97,22 @@ class MongoKVStorage(BaseKVStorage):
|
||||||
self._data = None
|
self._data = None
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
return await self._data.find_one({"_id": id})
|
# Unified handling for flattened keys
|
||||||
|
doc = await self._data.find_one({"_id": id})
|
||||||
|
if doc:
|
||||||
|
# Ensure time fields are present, provide default values for old data
|
||||||
|
doc.setdefault("create_time", 0)
|
||||||
|
doc.setdefault("update_time", 0)
|
||||||
|
return doc
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
cursor = self._data.find({"_id": {"$in": ids}})
|
cursor = self._data.find({"_id": {"$in": ids}})
|
||||||
return await cursor.to_list()
|
docs = await cursor.to_list()
|
||||||
|
# Ensure time fields are present for all documents
|
||||||
|
for doc in docs:
|
||||||
|
doc.setdefault("create_time", 0)
|
||||||
|
doc.setdefault("update_time", 0)
|
||||||
|
return docs
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
|
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
|
||||||
|
|
@ -117,47 +129,53 @@ class MongoKVStorage(BaseKVStorage):
|
||||||
result = {}
|
result = {}
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
doc_id = doc.pop("_id")
|
doc_id = doc.pop("_id")
|
||||||
|
# Ensure time fields are present for all documents
|
||||||
|
doc.setdefault("create_time", 0)
|
||||||
|
doc.setdefault("update_time", 0)
|
||||||
result[doc_id] = doc
|
result[doc_id] = doc
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
# Unified handling for all namespaces with flattened keys
|
||||||
update_tasks: list[Any] = []
|
# Use bulk_write for better performance
|
||||||
for mode, items in data.items():
|
from pymongo import UpdateOne
|
||||||
for k, v in items.items():
|
|
||||||
key = f"{mode}_{k}"
|
|
||||||
data[mode][k]["_id"] = f"{mode}_{k}"
|
|
||||||
update_tasks.append(
|
|
||||||
self._data.update_one(
|
|
||||||
{"_id": key}, {"$setOnInsert": v}, upsert=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await asyncio.gather(*update_tasks)
|
|
||||||
else:
|
|
||||||
update_tasks = []
|
|
||||||
for k, v in data.items():
|
|
||||||
data[k]["_id"] = k
|
|
||||||
update_tasks.append(
|
|
||||||
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
|
||||||
)
|
|
||||||
await asyncio.gather(*update_tasks)
|
|
||||||
|
|
||||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
operations = []
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
current_time = int(time.time()) # Get current Unix timestamp
|
||||||
res = {}
|
|
||||||
v = await self._data.find_one({"_id": mode + "_" + id})
|
for k, v in data.items():
|
||||||
if v:
|
# For text_chunks namespace, ensure llm_cache_list field exists
|
||||||
res[id] = v
|
if self.namespace.endswith("text_chunks"):
|
||||||
logger.debug(f"llm_response_cache find one by:{id}")
|
if "llm_cache_list" not in v:
|
||||||
return res
|
v["llm_cache_list"] = []
|
||||||
else:
|
|
||||||
return None
|
# Create a copy of v for $set operation, excluding create_time to avoid conflicts
|
||||||
else:
|
v_for_set = v.copy()
|
||||||
return None
|
v_for_set["_id"] = k # Use flattened key as _id
|
||||||
|
v_for_set["update_time"] = current_time # Always update update_time
|
||||||
|
|
||||||
|
# Remove create_time from $set to avoid conflict with $setOnInsert
|
||||||
|
v_for_set.pop("create_time", None)
|
||||||
|
|
||||||
|
operations.append(
|
||||||
|
UpdateOne(
|
||||||
|
{"_id": k},
|
||||||
|
{
|
||||||
|
"$set": v_for_set, # Update all fields except create_time
|
||||||
|
"$setOnInsert": {
|
||||||
|
"create_time": current_time
|
||||||
|
}, # Set create_time only on insert
|
||||||
|
},
|
||||||
|
upsert=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if operations:
|
||||||
|
await self._data.bulk_write(operations)
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
# Mongo handles persistence automatically
|
# Mongo handles persistence automatically
|
||||||
|
|
@ -197,8 +215,8 @@ class MongoKVStorage(BaseKVStorage):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build regex pattern to match documents with the specified modes
|
# Build regex pattern to match flattened key format: mode:cache_type:hash
|
||||||
pattern = f"^({'|'.join(modes)})_"
|
pattern = f"^({'|'.join(modes)}):"
|
||||||
result = await self._data.delete_many({"_id": {"$regex": pattern}})
|
result = await self._data.delete_many({"_id": {"$regex": pattern}})
|
||||||
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
|
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
|
||||||
return True
|
return True
|
||||||
|
|
@ -262,11 +280,14 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
return data - existing_ids
|
return data - existing_ids
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
update_tasks: list[Any] = []
|
update_tasks: list[Any] = []
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
|
# Ensure chunks_list field exists and is an array
|
||||||
|
if "chunks_list" not in v:
|
||||||
|
v["chunks_list"] = []
|
||||||
data[k]["_id"] = k
|
data[k]["_id"] = k
|
||||||
update_tasks.append(
|
update_tasks.append(
|
||||||
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
||||||
|
|
@ -299,6 +320,7 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
updated_at=doc.get("updated_at"),
|
updated_at=doc.get("updated_at"),
|
||||||
chunks_count=doc.get("chunks_count", -1),
|
chunks_count=doc.get("chunks_count", -1),
|
||||||
file_path=doc.get("file_path", doc["_id"]),
|
file_path=doc.get("file_path", doc["_id"]),
|
||||||
|
chunks_list=doc.get("chunks_list", []),
|
||||||
)
|
)
|
||||||
for doc in result
|
for doc in result
|
||||||
}
|
}
|
||||||
|
|
@ -417,11 +439,21 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if there's a direct single-hop edge from source_node_id to target_node_id.
|
Check if there's a direct single-hop edge between source_node_id and target_node_id.
|
||||||
"""
|
"""
|
||||||
# Direct check if the target_node appears among the edges array.
|
|
||||||
doc = await self.edge_collection.find_one(
|
doc = await self.edge_collection.find_one(
|
||||||
{"source_node_id": source_node_id, "target_node_id": target_node_id},
|
{
|
||||||
|
"$or": [
|
||||||
|
{
|
||||||
|
"source_node_id": source_node_id,
|
||||||
|
"target_node_id": target_node_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_node_id": target_node_id,
|
||||||
|
"target_node_id": source_node_id,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
{"_id": 1},
|
{"_id": 1},
|
||||||
)
|
)
|
||||||
return doc is not None
|
return doc is not None
|
||||||
|
|
@ -651,7 +683,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
Upsert an edge between source_node_id and target_node_id with optional 'relation'.
|
||||||
If an edge with the same target exists, we remove it and re-insert with updated data.
|
If an edge with the same target exists, we remove it and re-insert with updated data.
|
||||||
"""
|
"""
|
||||||
# Ensure source node exists
|
# Ensure source node exists
|
||||||
|
|
@ -663,8 +695,22 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
GRAPH_FIELD_SEP
|
GRAPH_FIELD_SEP
|
||||||
)
|
)
|
||||||
|
|
||||||
|
edge_data["source_node_id"] = source_node_id
|
||||||
|
edge_data["target_node_id"] = target_node_id
|
||||||
|
|
||||||
await self.edge_collection.update_one(
|
await self.edge_collection.update_one(
|
||||||
{"source_node_id": source_node_id, "target_node_id": target_node_id},
|
{
|
||||||
|
"$or": [
|
||||||
|
{
|
||||||
|
"source_node_id": source_node_id,
|
||||||
|
"target_node_id": target_node_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_node_id": target_node_id,
|
||||||
|
"target_node_id": source_node_id,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
update_doc,
|
update_doc,
|
||||||
upsert=True,
|
upsert=True,
|
||||||
)
|
)
|
||||||
|
|
@ -678,7 +724,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
async def delete_node(self, node_id: str) -> None:
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
1) Remove node's doc entirely.
|
1) Remove node's doc entirely.
|
||||||
2) Remove inbound edges from any doc that references node_id.
|
2) Remove inbound & outbound edges from any doc that references node_id.
|
||||||
"""
|
"""
|
||||||
# Remove all edges
|
# Remove all edges
|
||||||
await self.edge_collection.delete_many(
|
await self.edge_collection.delete_many(
|
||||||
|
|
@ -709,141 +755,369 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
labels.append(doc["_id"])
|
labels.append(doc["_id"])
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
def _construct_graph_node(
|
||||||
|
self, node_id, node_data: dict[str, str]
|
||||||
|
) -> KnowledgeGraphNode:
|
||||||
|
return KnowledgeGraphNode(
|
||||||
|
id=node_id,
|
||||||
|
labels=[node_id],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in node_data.items()
|
||||||
|
if k
|
||||||
|
not in [
|
||||||
|
"_id",
|
||||||
|
"connected_edges",
|
||||||
|
"source_ids",
|
||||||
|
"edge_count",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]):
|
||||||
|
return KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type=edge.get("relationship", ""),
|
||||||
|
source=edge["source_node_id"],
|
||||||
|
target=edge["target_node_id"],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in edge.items()
|
||||||
|
if k
|
||||||
|
not in [
|
||||||
|
"_id",
|
||||||
|
"source_node_id",
|
||||||
|
"target_node_id",
|
||||||
|
"relationship",
|
||||||
|
"source_ids",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_knowledge_graph_all_by_degree(
|
||||||
|
self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
|
||||||
|
) -> KnowledgeGraph:
|
||||||
|
"""
|
||||||
|
It's possible that the node with one or multiple relationships is retrieved,
|
||||||
|
while its neighbor is not. Then this node might seem like disconnected in UI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
total_node_count = await self.collection.count_documents({})
|
||||||
|
result = KnowledgeGraph()
|
||||||
|
seen_edges = set()
|
||||||
|
|
||||||
|
result.is_truncated = total_node_count > max_nodes
|
||||||
|
if result.is_truncated:
|
||||||
|
# Get all node_ids ranked by degree if max_nodes exceeds total node count
|
||||||
|
pipeline = [
|
||||||
|
{"$project": {"source_node_id": 1, "_id": 0}},
|
||||||
|
{"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
|
||||||
|
{
|
||||||
|
"$unionWith": {
|
||||||
|
"coll": self._edge_collection_name,
|
||||||
|
"pipeline": [
|
||||||
|
{"$project": {"target_node_id": 1, "_id": 0}},
|
||||||
|
{
|
||||||
|
"$group": {
|
||||||
|
"_id": "$target_node_id",
|
||||||
|
"degree": {"$sum": 1},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
|
||||||
|
{"$sort": {"degree": -1}},
|
||||||
|
{"$limit": max_nodes},
|
||||||
|
]
|
||||||
|
cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
|
||||||
|
|
||||||
|
node_ids = []
|
||||||
|
async for doc in cursor:
|
||||||
|
node_id = str(doc["_id"])
|
||||||
|
node_ids.append(node_id)
|
||||||
|
|
||||||
|
cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
|
||||||
|
async for doc in cursor:
|
||||||
|
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
|
||||||
|
|
||||||
|
# As node count reaches the limit, only need to fetch the edges that directly connect to these nodes
|
||||||
|
edge_cursor = self.edge_collection.find(
|
||||||
|
{
|
||||||
|
"$and": [
|
||||||
|
{"source_node_id": {"$in": node_ids}},
|
||||||
|
{"target_node_id": {"$in": node_ids}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# All nodes and edges are needed
|
||||||
|
cursor = self.collection.find({}, {"source_ids": 0})
|
||||||
|
|
||||||
|
async for doc in cursor:
|
||||||
|
node_id = str(doc["_id"])
|
||||||
|
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
|
||||||
|
|
||||||
|
edge_cursor = self.edge_collection.find({})
|
||||||
|
|
||||||
|
async for edge in edge_cursor:
|
||||||
|
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
result.edges.append(self._construct_graph_edge(edge_id, edge))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _bidirectional_bfs_nodes(
|
||||||
|
self,
|
||||||
|
node_labels: list[str],
|
||||||
|
seen_nodes: set[str],
|
||||||
|
result: KnowledgeGraph,
|
||||||
|
depth: int = 0,
|
||||||
|
max_depth: int = 3,
|
||||||
|
max_nodes: int = MAX_GRAPH_NODES,
|
||||||
|
) -> KnowledgeGraph:
|
||||||
|
if depth > max_depth or len(result.nodes) > max_nodes:
|
||||||
|
return result
|
||||||
|
|
||||||
|
cursor = self.collection.find({"_id": {"$in": node_labels}})
|
||||||
|
|
||||||
|
async for node in cursor:
|
||||||
|
node_id = node["_id"]
|
||||||
|
if node_id not in seen_nodes:
|
||||||
|
seen_nodes.add(node_id)
|
||||||
|
result.nodes.append(self._construct_graph_node(node_id, node))
|
||||||
|
if len(result.nodes) > max_nodes:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Collect neighbors
|
||||||
|
# Get both inbound and outbound one hop nodes
|
||||||
|
cursor = self.edge_collection.find(
|
||||||
|
{
|
||||||
|
"$or": [
|
||||||
|
{"source_node_id": {"$in": node_labels}},
|
||||||
|
{"target_node_id": {"$in": node_labels}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
neighbor_nodes = []
|
||||||
|
async for edge in cursor:
|
||||||
|
if edge["source_node_id"] not in seen_nodes:
|
||||||
|
neighbor_nodes.append(edge["source_node_id"])
|
||||||
|
if edge["target_node_id"] not in seen_nodes:
|
||||||
|
neighbor_nodes.append(edge["target_node_id"])
|
||||||
|
|
||||||
|
if neighbor_nodes:
|
||||||
|
result = await self._bidirectional_bfs_nodes(
|
||||||
|
neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_knowledge_subgraph_bidirectional_bfs(
|
||||||
|
self,
|
||||||
|
node_label: str,
|
||||||
|
depth=0,
|
||||||
|
max_depth: int = 3,
|
||||||
|
max_nodes: int = MAX_GRAPH_NODES,
|
||||||
|
) -> KnowledgeGraph:
|
||||||
|
seen_nodes = set()
|
||||||
|
seen_edges = set()
|
||||||
|
result = KnowledgeGraph()
|
||||||
|
|
||||||
|
result = await self._bidirectional_bfs_nodes(
|
||||||
|
[node_label], seen_nodes, result, depth, max_depth, max_nodes
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all edges from seen_nodes
|
||||||
|
all_node_ids = list(seen_nodes)
|
||||||
|
cursor = self.edge_collection.find(
|
||||||
|
{
|
||||||
|
"$and": [
|
||||||
|
{"source_node_id": {"$in": all_node_ids}},
|
||||||
|
{"target_node_id": {"$in": all_node_ids}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async for edge in cursor:
|
||||||
|
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
result.edges.append(self._construct_graph_edge(edge_id, edge))
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_knowledge_subgraph_in_out_bound_bfs(
|
||||||
|
self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
|
||||||
|
) -> KnowledgeGraph:
|
||||||
|
seen_nodes = set()
|
||||||
|
seen_edges = set()
|
||||||
|
result = KnowledgeGraph()
|
||||||
|
project_doc = {
|
||||||
|
"source_ids": 0,
|
||||||
|
"created_at": 0,
|
||||||
|
"entity_type": 0,
|
||||||
|
"file_path": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Verify if starting node exists
|
||||||
|
start_node = await self.collection.find_one({"_id": node_label})
|
||||||
|
if not start_node:
|
||||||
|
logger.warning(f"Starting node with label {node_label} does not exist!")
|
||||||
|
return result
|
||||||
|
|
||||||
|
seen_nodes.add(node_label)
|
||||||
|
result.nodes.append(self._construct_graph_node(node_label, start_node))
|
||||||
|
|
||||||
|
if max_depth == 0:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# In MongoDB, depth = 0 means one-hop
|
||||||
|
max_depth = max_depth - 1
|
||||||
|
|
||||||
|
pipeline = [
|
||||||
|
{"$match": {"_id": node_label}},
|
||||||
|
{"$project": project_doc},
|
||||||
|
{
|
||||||
|
"$graphLookup": {
|
||||||
|
"from": self._edge_collection_name,
|
||||||
|
"startWith": "$_id",
|
||||||
|
"connectFromField": "target_node_id",
|
||||||
|
"connectToField": "source_node_id",
|
||||||
|
"maxDepth": max_depth,
|
||||||
|
"depthField": "depth",
|
||||||
|
"as": "connected_edges",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$unionWith": {
|
||||||
|
"coll": self._collection_name,
|
||||||
|
"pipeline": [
|
||||||
|
{"$match": {"_id": node_label}},
|
||||||
|
{"$project": project_doc},
|
||||||
|
{
|
||||||
|
"$graphLookup": {
|
||||||
|
"from": self._edge_collection_name,
|
||||||
|
"startWith": "$_id",
|
||||||
|
"connectFromField": "source_node_id",
|
||||||
|
"connectToField": "target_node_id",
|
||||||
|
"maxDepth": max_depth,
|
||||||
|
"depthField": "depth",
|
||||||
|
"as": "connected_edges",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
||||||
|
node_edges = []
|
||||||
|
|
||||||
|
# Two records for node_label are returned capturing outbound and inbound connected_edges
|
||||||
|
async for doc in cursor:
|
||||||
|
if doc.get("connected_edges", []):
|
||||||
|
node_edges.extend(doc.get("connected_edges"))
|
||||||
|
|
||||||
|
# Sort the connected edges by depth ascending and weight descending
|
||||||
|
# And stores the source_node_id and target_node_id in sequence to retrieve the neighbouring nodes
|
||||||
|
node_edges = sorted(
|
||||||
|
node_edges,
|
||||||
|
key=lambda x: (x["depth"], -x["weight"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# As order matters, we need to use another list to store the node_id
|
||||||
|
# And only take the first max_nodes ones
|
||||||
|
node_ids = []
|
||||||
|
for edge in node_edges:
|
||||||
|
if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes:
|
||||||
|
node_ids.append(edge["source_node_id"])
|
||||||
|
seen_nodes.add(edge["source_node_id"])
|
||||||
|
|
||||||
|
if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes:
|
||||||
|
node_ids.append(edge["target_node_id"])
|
||||||
|
seen_nodes.add(edge["target_node_id"])
|
||||||
|
|
||||||
|
# Filter out all the node whose id is same as node_label so that we do not check existence next step
|
||||||
|
cursor = self.collection.find({"_id": {"$in": node_ids}})
|
||||||
|
|
||||||
|
async for doc in cursor:
|
||||||
|
result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))
|
||||||
|
|
||||||
|
for edge in node_edges:
|
||||||
|
if (
|
||||||
|
edge["source_node_id"] not in seen_nodes
|
||||||
|
or edge["target_node_id"] not in seen_nodes
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
result.edges.append(self._construct_graph_edge(edge_id, edge))
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self,
|
self,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
max_depth: int = 5,
|
max_depth: int = 3,
|
||||||
max_nodes: int = MAX_GRAPH_NODES,
|
max_nodes: int = MAX_GRAPH_NODES,
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
"""
|
"""
|
||||||
Get complete connected subgraph for specified node (including the starting node itself)
|
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_label: Label of the nodes to start from
|
node_label: Label of the starting node, * means all nodes
|
||||||
max_depth: Maximum depth of traversal (default: 5)
|
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||||
|
max_nodes: Maxiumu nodes to return, Defaults to 1000
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeGraph object containing nodes and edges of the subgraph
|
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||||
|
indicating whether the graph was truncated due to max_nodes limit
|
||||||
|
|
||||||
|
If a graph is like this and starting from B:
|
||||||
|
A → B ← C ← F, B -> E, C → D
|
||||||
|
|
||||||
|
Outbound BFS:
|
||||||
|
B → E
|
||||||
|
|
||||||
|
Inbound BFS:
|
||||||
|
A → B
|
||||||
|
C → B
|
||||||
|
F → C
|
||||||
|
|
||||||
|
Bidirectional BFS:
|
||||||
|
A → B
|
||||||
|
B → E
|
||||||
|
F → C
|
||||||
|
C → B
|
||||||
|
C → D
|
||||||
"""
|
"""
|
||||||
label = node_label
|
|
||||||
result = KnowledgeGraph()
|
result = KnowledgeGraph()
|
||||||
seen_nodes = set()
|
start = time.perf_counter()
|
||||||
seen_edges = set()
|
|
||||||
node_edges = []
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Optimize pipeline to avoid memory issues with large datasets
|
# Optimize pipeline to avoid memory issues with large datasets
|
||||||
if label == "*":
|
if node_label == "*":
|
||||||
# For getting all nodes, use a simpler pipeline to avoid memory issues
|
result = await self.get_knowledge_graph_all_by_degree(
|
||||||
pipeline = [
|
max_depth, max_nodes
|
||||||
{"$limit": max_nodes}, # Limit early to reduce memory usage
|
)
|
||||||
{
|
elif GRAPH_BFS_MODE == "in_out_bound":
|
||||||
"$graphLookup": {
|
result = await self.get_knowledge_subgraph_in_out_bound_bfs(
|
||||||
"from": self._edge_collection_name,
|
node_label, max_depth, max_nodes
|
||||||
"startWith": "$_id",
|
)
|
||||||
"connectFromField": "target_node_id",
|
else:
|
||||||
"connectToField": "source_node_id",
|
result = await self.get_knowledge_subgraph_bidirectional_bfs(
|
||||||
"maxDepth": max_depth,
|
node_label, 0, max_depth, max_nodes
|
||||||
"depthField": "depth",
|
|
||||||
"as": "connected_edges",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Check if we need to set truncation flag
|
|
||||||
all_node_count = await self.collection.count_documents({})
|
|
||||||
result.is_truncated = all_node_count > max_nodes
|
|
||||||
else:
|
|
||||||
# Verify if starting node exists
|
|
||||||
start_node = await self.collection.find_one({"_id": label})
|
|
||||||
if not start_node:
|
|
||||||
logger.warning(f"Starting node with label {label} does not exist!")
|
|
||||||
return result
|
|
||||||
|
|
||||||
# For specific node queries, use the original pipeline but optimized
|
|
||||||
pipeline = [
|
|
||||||
{"$match": {"_id": label}},
|
|
||||||
{
|
|
||||||
"$graphLookup": {
|
|
||||||
"from": self._edge_collection_name,
|
|
||||||
"startWith": "$_id",
|
|
||||||
"connectFromField": "target_node_id",
|
|
||||||
"connectToField": "source_node_id",
|
|
||||||
"maxDepth": max_depth,
|
|
||||||
"depthField": "depth",
|
|
||||||
"as": "connected_edges",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
|
||||||
{"$sort": {"edge_count": -1}},
|
|
||||||
{"$limit": max_nodes},
|
|
||||||
]
|
|
||||||
|
|
||||||
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
|
||||||
nodes_processed = 0
|
|
||||||
|
|
||||||
async for doc in cursor:
|
|
||||||
# Add the start node
|
|
||||||
node_id = str(doc["_id"])
|
|
||||||
result.nodes.append(
|
|
||||||
KnowledgeGraphNode(
|
|
||||||
id=node_id,
|
|
||||||
labels=[node_id],
|
|
||||||
properties={
|
|
||||||
k: v
|
|
||||||
for k, v in doc.items()
|
|
||||||
if k
|
|
||||||
not in [
|
|
||||||
"_id",
|
|
||||||
"connected_edges",
|
|
||||||
"edge_count",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
seen_nodes.add(node_id)
|
|
||||||
if doc.get("connected_edges", []):
|
|
||||||
node_edges.extend(doc.get("connected_edges"))
|
|
||||||
|
|
||||||
nodes_processed += 1
|
duration = time.perf_counter() - start
|
||||||
|
|
||||||
# Additional safety check to prevent memory issues
|
|
||||||
if nodes_processed >= max_nodes:
|
|
||||||
result.is_truncated = True
|
|
||||||
break
|
|
||||||
|
|
||||||
for edge in node_edges:
|
|
||||||
if (
|
|
||||||
edge["source_node_id"] not in seen_nodes
|
|
||||||
or edge["target_node_id"] not in seen_nodes
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
|
|
||||||
if edge_id not in seen_edges:
|
|
||||||
result.edges.append(
|
|
||||||
KnowledgeGraphEdge(
|
|
||||||
id=edge_id,
|
|
||||||
type=edge.get("relationship", ""),
|
|
||||||
source=edge["source_node_id"],
|
|
||||||
target=edge["target_node_id"],
|
|
||||||
properties={
|
|
||||||
k: v
|
|
||||||
for k, v in edge.items()
|
|
||||||
if k
|
|
||||||
not in [
|
|
||||||
"_id",
|
|
||||||
"source_node_id",
|
|
||||||
"target_node_id",
|
|
||||||
"relationship",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_edges.add(edge_id)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
|
f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
|
|
@ -856,13 +1130,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
try:
|
try:
|
||||||
simple_cursor = self.collection.find({}).limit(max_nodes)
|
simple_cursor = self.collection.find({}).limit(max_nodes)
|
||||||
async for doc in simple_cursor:
|
async for doc in simple_cursor:
|
||||||
node_id = str(doc["_id"])
|
|
||||||
result.nodes.append(
|
result.nodes.append(
|
||||||
KnowledgeGraphNode(
|
self._construct_graph_node(str(doc["_id"]), doc)
|
||||||
id=node_id,
|
|
||||||
labels=[node_id],
|
|
||||||
properties={k: v for k, v in doc.items() if k != "_id"},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
result.is_truncated = True
|
result.is_truncated = True
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -1023,13 +1292,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
logger.debug("vector index already exist")
|
logger.debug("vector index already exist")
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Add current time as Unix timestamp
|
# Add current time as Unix timestamp
|
||||||
import time
|
|
||||||
|
|
||||||
current_time = int(time.time())
|
current_time = int(time.time())
|
||||||
|
|
||||||
list_data = [
|
list_data = [
|
||||||
|
|
@ -1114,7 +1381,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
Args:
|
Args:
|
||||||
ids: List of vector IDs to be deleted
|
ids: List of vector IDs to be deleted
|
||||||
"""
|
"""
|
||||||
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
|
||||||
if not ids:
|
if not ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,9 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
graph = await self._get_graph()
|
graph = await self._get_graph()
|
||||||
return graph.degree(src_id) + graph.degree(tgt_id)
|
src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0
|
||||||
|
tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0
|
||||||
|
return src_degree + tgt_degree
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
|
|
|
||||||
|
|
@ -136,6 +136,52 @@ class PostgreSQLDB:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}")
|
logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}")
|
||||||
|
|
||||||
|
async def _migrate_llm_cache_add_cache_type(self):
|
||||||
|
"""Add cache_type column to LIGHTRAG_LLM_CACHE table if it doesn't exist"""
|
||||||
|
try:
|
||||||
|
# Check if cache_type column exists
|
||||||
|
check_column_sql = """
|
||||||
|
SELECT column_name
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_name = 'lightrag_llm_cache'
|
||||||
|
AND column_name = 'cache_type'
|
||||||
|
"""
|
||||||
|
|
||||||
|
column_info = await self.query(check_column_sql)
|
||||||
|
if not column_info:
|
||||||
|
logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table")
|
||||||
|
add_column_sql = """
|
||||||
|
ALTER TABLE LIGHTRAG_LLM_CACHE
|
||||||
|
ADD COLUMN cache_type VARCHAR(32) NULL
|
||||||
|
"""
|
||||||
|
await self.execute(add_column_sql)
|
||||||
|
logger.info(
|
||||||
|
"Successfully added cache_type column to LIGHTRAG_LLM_CACHE table"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Migrate existing data: extract cache_type from flattened keys
|
||||||
|
logger.info(
|
||||||
|
"Migrating existing LLM cache data to populate cache_type field"
|
||||||
|
)
|
||||||
|
update_sql = """
|
||||||
|
UPDATE LIGHTRAG_LLM_CACHE
|
||||||
|
SET cache_type = CASE
|
||||||
|
WHEN id LIKE '%:%:%' THEN split_part(id, ':', 2)
|
||||||
|
ELSE 'extract'
|
||||||
|
END
|
||||||
|
WHERE cache_type IS NULL
|
||||||
|
"""
|
||||||
|
await self.execute(update_sql)
|
||||||
|
logger.info("Successfully migrated existing LLM cache data")
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"cache_type column already exists in LIGHTRAG_LLM_CACHE table"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to add cache_type column to LIGHTRAG_LLM_CACHE: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
async def _migrate_timestamp_columns(self):
|
async def _migrate_timestamp_columns(self):
|
||||||
"""Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time"""
|
"""Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time"""
|
||||||
# Tables and columns that need migration
|
# Tables and columns that need migration
|
||||||
|
|
@ -189,6 +235,239 @@ class PostgreSQLDB:
|
||||||
# Log error but don't interrupt the process
|
# Log error but don't interrupt the process
|
||||||
logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}")
|
logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}")
|
||||||
|
|
||||||
|
async def _migrate_doc_chunks_to_vdb_chunks(self):
|
||||||
|
"""
|
||||||
|
Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met.
|
||||||
|
This migration is intended for users who are upgrading and have an older table structure
|
||||||
|
where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty
|
||||||
|
vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS"
|
||||||
|
vdb_chunks_count_result = await self.query(vdb_chunks_count_sql)
|
||||||
|
if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0:
|
||||||
|
logger.info(
|
||||||
|
"Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. Check if `content_vector` column exists in the old table
|
||||||
|
check_column_sql = """
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector'
|
||||||
|
"""
|
||||||
|
column_exists = await self.query(check_column_sql)
|
||||||
|
if not column_exists:
|
||||||
|
logger.info(
|
||||||
|
"Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. Check if the old table LIGHTRAG_DOC_CHUNKS has data
|
||||||
|
doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS"
|
||||||
|
doc_chunks_count_result = await self.query(doc_chunks_count_sql)
|
||||||
|
if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0:
|
||||||
|
logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. Perform the migration
|
||||||
|
logger.info(
|
||||||
|
"Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..."
|
||||||
|
)
|
||||||
|
migration_sql = """
|
||||||
|
INSERT INTO LIGHTRAG_VDB_CHUNKS (
|
||||||
|
id, workspace, full_doc_id, chunk_order_index, tokens, content,
|
||||||
|
content_vector, file_path, create_time, update_time
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
id, workspace, full_doc_id, chunk_order_index, tokens, content,
|
||||||
|
content_vector, file_path, create_time, update_time
|
||||||
|
FROM LIGHTRAG_DOC_CHUNKS
|
||||||
|
ON CONFLICT (workspace, id) DO NOTHING;
|
||||||
|
"""
|
||||||
|
await self.execute(migration_sql)
|
||||||
|
logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
|
||||||
|
# Do not re-raise, to allow the application to start
|
||||||
|
|
||||||
|
async def _check_llm_cache_needs_migration(self):
|
||||||
|
"""Check if LLM cache data needs migration by examining the first record"""
|
||||||
|
try:
|
||||||
|
# Only query the first record to determine format
|
||||||
|
check_sql = """
|
||||||
|
SELECT id FROM LIGHTRAG_LLM_CACHE
|
||||||
|
ORDER BY create_time ASC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
result = await self.query(check_sql)
|
||||||
|
|
||||||
|
if result and result.get("id"):
|
||||||
|
# If id doesn't contain colon, it's old format
|
||||||
|
return ":" not in result["id"]
|
||||||
|
|
||||||
|
return False # No data or already new format
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to check LLM cache migration status: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _migrate_llm_cache_to_flattened_keys(self):
|
||||||
|
"""Migrate LLM cache to flattened key format, recalculating hash values"""
|
||||||
|
try:
|
||||||
|
# Get all old format data
|
||||||
|
old_data_sql = """
|
||||||
|
SELECT id, mode, original_prompt, return_value, chunk_id,
|
||||||
|
create_time, update_time
|
||||||
|
FROM LIGHTRAG_LLM_CACHE
|
||||||
|
WHERE id NOT LIKE '%:%'
|
||||||
|
"""
|
||||||
|
|
||||||
|
old_records = await self.query(old_data_sql, multirows=True)
|
||||||
|
|
||||||
|
if not old_records:
|
||||||
|
logger.info("No old format LLM cache data found, skipping migration")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Found {len(old_records)} old format cache records, starting migration..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import hash calculation function
|
||||||
|
from ..utils import compute_args_hash
|
||||||
|
|
||||||
|
migrated_count = 0
|
||||||
|
|
||||||
|
# Migrate data in batches
|
||||||
|
for record in old_records:
|
||||||
|
try:
|
||||||
|
# Recalculate hash using correct method
|
||||||
|
new_hash = compute_args_hash(
|
||||||
|
record["mode"], record["original_prompt"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine cache_type based on mode
|
||||||
|
cache_type = "extract" if record["mode"] == "default" else "unknown"
|
||||||
|
|
||||||
|
# Generate new flattened key
|
||||||
|
new_key = f"{record['mode']}:{cache_type}:{new_hash}"
|
||||||
|
|
||||||
|
# Insert new format data with cache_type field
|
||||||
|
insert_sql = """
|
||||||
|
INSERT INTO LIGHTRAG_LLM_CACHE
|
||||||
|
(workspace, id, mode, original_prompt, return_value, chunk_id, cache_type, create_time, update_time)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
|
ON CONFLICT (workspace, mode, id) DO NOTHING
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.execute(
|
||||||
|
insert_sql,
|
||||||
|
{
|
||||||
|
"workspace": self.workspace,
|
||||||
|
"id": new_key,
|
||||||
|
"mode": record["mode"],
|
||||||
|
"original_prompt": record["original_prompt"],
|
||||||
|
"return_value": record["return_value"],
|
||||||
|
"chunk_id": record["chunk_id"],
|
||||||
|
"cache_type": cache_type, # Add cache_type field
|
||||||
|
"create_time": record["create_time"],
|
||||||
|
"update_time": record["update_time"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete old data
|
||||||
|
delete_sql = """
|
||||||
|
DELETE FROM LIGHTRAG_LLM_CACHE
|
||||||
|
WHERE workspace=$1 AND mode=$2 AND id=$3
|
||||||
|
"""
|
||||||
|
await self.execute(
|
||||||
|
delete_sql,
|
||||||
|
{
|
||||||
|
"workspace": self.workspace,
|
||||||
|
"mode": record["mode"],
|
||||||
|
"id": record["id"], # Old id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
migrated_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to migrate cache record {record['id']}: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Successfully migrated {migrated_count} cache records to flattened format"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM cache migration failed: {e}")
|
||||||
|
# Don't raise exception, allow system to continue startup
|
||||||
|
|
||||||
|
async def _migrate_doc_status_add_chunks_list(self):
|
||||||
|
"""Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist"""
|
||||||
|
try:
|
||||||
|
# Check if chunks_list column exists
|
||||||
|
check_column_sql = """
|
||||||
|
SELECT column_name
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_name = 'lightrag_doc_status'
|
||||||
|
AND column_name = 'chunks_list'
|
||||||
|
"""
|
||||||
|
|
||||||
|
column_info = await self.query(check_column_sql)
|
||||||
|
if not column_info:
|
||||||
|
logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table")
|
||||||
|
add_column_sql = """
|
||||||
|
ALTER TABLE LIGHTRAG_DOC_STATUS
|
||||||
|
ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb
|
||||||
|
"""
|
||||||
|
await self.execute(add_column_sql)
|
||||||
|
logger.info(
|
||||||
|
"Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"chunks_list column already exists in LIGHTRAG_DOC_STATUS table"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _migrate_text_chunks_add_llm_cache_list(self):
|
||||||
|
"""Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist"""
|
||||||
|
try:
|
||||||
|
# Check if llm_cache_list column exists
|
||||||
|
check_column_sql = """
|
||||||
|
SELECT column_name
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_name = 'lightrag_doc_chunks'
|
||||||
|
AND column_name = 'llm_cache_list'
|
||||||
|
"""
|
||||||
|
|
||||||
|
column_info = await self.query(check_column_sql)
|
||||||
|
if not column_info:
|
||||||
|
logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table")
|
||||||
|
add_column_sql = """
|
||||||
|
ALTER TABLE LIGHTRAG_DOC_CHUNKS
|
||||||
|
ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb
|
||||||
|
"""
|
||||||
|
await self.execute(add_column_sql)
|
||||||
|
logger.info(
|
||||||
|
"Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
async def check_tables(self):
|
async def check_tables(self):
|
||||||
# First create all tables
|
# First create all tables
|
||||||
for k, v in TABLES.items():
|
for k, v in TABLES.items():
|
||||||
|
|
@ -240,6 +519,44 @@ class PostgreSQLDB:
|
||||||
logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}")
|
logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}")
|
||||||
# Don't throw an exception, allow the initialization process to continue
|
# Don't throw an exception, allow the initialization process to continue
|
||||||
|
|
||||||
|
# Migrate LLM cache table to add cache_type field if needed
|
||||||
|
try:
|
||||||
|
await self._migrate_llm_cache_add_cache_type()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"PostgreSQL, Failed to migrate LLM cache cache_type field: {e}"
|
||||||
|
)
|
||||||
|
# Don't throw an exception, allow the initialization process to continue
|
||||||
|
|
||||||
|
# Finally, attempt to migrate old doc chunks data if needed
|
||||||
|
try:
|
||||||
|
await self._migrate_doc_chunks_to_vdb_chunks()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
|
||||||
|
|
||||||
|
# Check and migrate LLM cache to flattened keys if needed
|
||||||
|
try:
|
||||||
|
if await self._check_llm_cache_needs_migration():
|
||||||
|
await self._migrate_llm_cache_to_flattened_keys()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
|
||||||
|
|
||||||
|
# Migrate doc status to add chunks_list field if needed
|
||||||
|
try:
|
||||||
|
await self._migrate_doc_status_add_chunks_list()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Migrate text chunks to add llm_cache_list field if needed
|
||||||
|
try:
|
||||||
|
await self._migrate_text_chunks_add_llm_cache_list()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
sql: str,
|
sql: str,
|
||||||
|
|
@ -423,74 +740,139 @@ class PGKVStorage(BaseKVStorage):
|
||||||
try:
|
try:
|
||||||
results = await self.db.query(sql, params, multirows=True)
|
results = await self.db.query(sql, params, multirows=True)
|
||||||
|
|
||||||
|
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||||
result_dict = {}
|
processed_results = {}
|
||||||
for row in results:
|
for row in results:
|
||||||
mode = row["mode"]
|
create_time = row.get("create_time", 0)
|
||||||
if mode not in result_dict:
|
update_time = row.get("update_time", 0)
|
||||||
result_dict[mode] = {}
|
# Map field names and add cache_type for compatibility
|
||||||
result_dict[mode][row["id"]] = row
|
processed_row = {
|
||||||
return result_dict
|
**row,
|
||||||
else:
|
"return": row.get("return_value", ""),
|
||||||
return {row["id"]: row for row in results}
|
"cache_type": row.get("original_prompt", "unknow"),
|
||||||
|
"original_prompt": row.get("original_prompt", ""),
|
||||||
|
"chunk_id": row.get("chunk_id"),
|
||||||
|
"mode": row.get("mode", "default"),
|
||||||
|
"create_time": create_time,
|
||||||
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
|
}
|
||||||
|
processed_results[row["id"]] = processed_row
|
||||||
|
return processed_results
|
||||||
|
|
||||||
|
# For text_chunks namespace, parse llm_cache_list JSON string back to list
|
||||||
|
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
|
processed_results = {}
|
||||||
|
for row in results:
|
||||||
|
llm_cache_list = row.get("llm_cache_list", [])
|
||||||
|
if isinstance(llm_cache_list, str):
|
||||||
|
try:
|
||||||
|
llm_cache_list = json.loads(llm_cache_list)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
llm_cache_list = []
|
||||||
|
row["llm_cache_list"] = llm_cache_list
|
||||||
|
create_time = row.get("create_time", 0)
|
||||||
|
update_time = row.get("update_time", 0)
|
||||||
|
row["create_time"] = create_time
|
||||||
|
row["update_time"] = (
|
||||||
|
create_time if update_time == 0 else update_time
|
||||||
|
)
|
||||||
|
processed_results[row["id"]] = row
|
||||||
|
return processed_results
|
||||||
|
|
||||||
|
# For other namespaces, return as-is
|
||||||
|
return {row["id"]: row for row in results}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving all data from {self.namespace}: {e}")
|
logger.error(f"Error retrieving all data from {self.namespace}: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
"""Get doc_full data by id."""
|
"""Get data by id."""
|
||||||
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||||
params = {"workspace": self.db.workspace, "id": id}
|
params = {"workspace": self.db.workspace, "id": id}
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
response = await self.db.query(sql, params)
|
||||||
array_res = await self.db.query(sql, params, multirows=True)
|
|
||||||
res = {}
|
|
||||||
for row in array_res:
|
|
||||||
res[row["id"]] = row
|
|
||||||
return res if res else None
|
|
||||||
else:
|
|
||||||
response = await self.db.query(sql, params)
|
|
||||||
return response if response else None
|
|
||||||
|
|
||||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
"""Specifically for llm_response_cache."""
|
# Parse llm_cache_list JSON string back to list
|
||||||
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
|
llm_cache_list = response.get("llm_cache_list", [])
|
||||||
params = {"workspace": self.db.workspace, "mode": mode, "id": id}
|
if isinstance(llm_cache_list, str):
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
try:
|
||||||
array_res = await self.db.query(sql, params, multirows=True)
|
llm_cache_list = json.loads(llm_cache_list)
|
||||||
res = {}
|
except json.JSONDecodeError:
|
||||||
for row in array_res:
|
llm_cache_list = []
|
||||||
res[row["id"]] = row
|
response["llm_cache_list"] = llm_cache_list
|
||||||
return res
|
create_time = response.get("create_time", 0)
|
||||||
else:
|
update_time = response.get("update_time", 0)
|
||||||
return None
|
response["create_time"] = create_time
|
||||||
|
response["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
|
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
|
||||||
|
if response and is_namespace(
|
||||||
|
self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
||||||
|
):
|
||||||
|
create_time = response.get("create_time", 0)
|
||||||
|
update_time = response.get("update_time", 0)
|
||||||
|
# Map field names and add cache_type for compatibility
|
||||||
|
response = {
|
||||||
|
**response,
|
||||||
|
"return": response.get("return_value", ""),
|
||||||
|
"cache_type": response.get("cache_type"),
|
||||||
|
"original_prompt": response.get("original_prompt", ""),
|
||||||
|
"chunk_id": response.get("chunk_id"),
|
||||||
|
"mode": response.get("mode", "default"),
|
||||||
|
"create_time": create_time,
|
||||||
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
return response if response else None
|
||||||
|
|
||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
"""Get doc_chunks data by id"""
|
"""Get data by ids"""
|
||||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
ids=",".join([f"'{id}'" for id in ids])
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
)
|
)
|
||||||
params = {"workspace": self.db.workspace}
|
params = {"workspace": self.db.workspace}
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
results = await self.db.query(sql, params, multirows=True)
|
||||||
array_res = await self.db.query(sql, params, multirows=True)
|
|
||||||
modes = set()
|
|
||||||
dict_res: dict[str, dict] = {}
|
|
||||||
for row in array_res:
|
|
||||||
modes.add(row["mode"])
|
|
||||||
for mode in modes:
|
|
||||||
if mode not in dict_res:
|
|
||||||
dict_res[mode] = {}
|
|
||||||
for row in array_res:
|
|
||||||
dict_res[row["mode"]][row["id"]] = row
|
|
||||||
return [{k: v} for k, v in dict_res.items()]
|
|
||||||
else:
|
|
||||||
return await self.db.query(sql, params, multirows=True)
|
|
||||||
|
|
||||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
"""Specifically for llm_response_cache."""
|
# Parse llm_cache_list JSON string back to list for each result
|
||||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
for result in results:
|
||||||
params = {"workspace": self.db.workspace, "status": status}
|
llm_cache_list = result.get("llm_cache_list", [])
|
||||||
return await self.db.query(SQL, params, multirows=True)
|
if isinstance(llm_cache_list, str):
|
||||||
|
try:
|
||||||
|
llm_cache_list = json.loads(llm_cache_list)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
llm_cache_list = []
|
||||||
|
result["llm_cache_list"] = llm_cache_list
|
||||||
|
create_time = result.get("create_time", 0)
|
||||||
|
update_time = result.get("update_time", 0)
|
||||||
|
result["create_time"] = create_time
|
||||||
|
result["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
|
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
|
||||||
|
if results and is_namespace(
|
||||||
|
self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
||||||
|
):
|
||||||
|
processed_results = []
|
||||||
|
for row in results:
|
||||||
|
create_time = row.get("create_time", 0)
|
||||||
|
update_time = row.get("update_time", 0)
|
||||||
|
# Map field names and add cache_type for compatibility
|
||||||
|
processed_row = {
|
||||||
|
**row,
|
||||||
|
"return": row.get("return_value", ""),
|
||||||
|
"cache_type": row.get("cache_type"),
|
||||||
|
"original_prompt": row.get("original_prompt", ""),
|
||||||
|
"chunk_id": row.get("chunk_id"),
|
||||||
|
"mode": row.get("mode", "default"),
|
||||||
|
"create_time": create_time,
|
||||||
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
|
}
|
||||||
|
processed_results.append(processed_row)
|
||||||
|
return processed_results
|
||||||
|
|
||||||
|
return results if results else []
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
|
|
@ -520,7 +902,22 @@ class PGKVStorage(BaseKVStorage):
|
||||||
return
|
return
|
||||||
|
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
pass
|
current_time = datetime.datetime.now(timezone.utc)
|
||||||
|
for k, v in data.items():
|
||||||
|
upsert_sql = SQL_TEMPLATES["upsert_text_chunk"]
|
||||||
|
_data = {
|
||||||
|
"workspace": self.db.workspace,
|
||||||
|
"id": k,
|
||||||
|
"tokens": v["tokens"],
|
||||||
|
"chunk_order_index": v["chunk_order_index"],
|
||||||
|
"full_doc_id": v["full_doc_id"],
|
||||||
|
"content": v["content"],
|
||||||
|
"file_path": v["file_path"],
|
||||||
|
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
|
||||||
|
"create_time": current_time,
|
||||||
|
"update_time": current_time,
|
||||||
|
}
|
||||||
|
await self.db.execute(upsert_sql, _data)
|
||||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
|
upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
|
||||||
|
|
@ -531,19 +928,21 @@ class PGKVStorage(BaseKVStorage):
|
||||||
}
|
}
|
||||||
await self.db.execute(upsert_sql, _data)
|
await self.db.execute(upsert_sql, _data)
|
||||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||||
for mode, items in data.items():
|
for k, v in data.items():
|
||||||
for k, v in items.items():
|
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
|
_data = {
|
||||||
_data = {
|
"workspace": self.db.workspace,
|
||||||
"workspace": self.db.workspace,
|
"id": k, # Use flattened key as id
|
||||||
"id": k,
|
"original_prompt": v["original_prompt"],
|
||||||
"original_prompt": v["original_prompt"],
|
"return_value": v["return"],
|
||||||
"return_value": v["return"],
|
"mode": v.get("mode", "default"), # Get mode from data
|
||||||
"mode": mode,
|
"chunk_id": v.get("chunk_id"),
|
||||||
"chunk_id": v.get("chunk_id"),
|
"cache_type": v.get(
|
||||||
}
|
"cache_type", "extract"
|
||||||
|
), # Get cache_type from data
|
||||||
|
}
|
||||||
|
|
||||||
await self.db.execute(upsert_sql, _data)
|
await self.db.execute(upsert_sql, _data)
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
# PG handles persistence automatically
|
# PG handles persistence automatically
|
||||||
|
|
@ -949,8 +1348,8 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
else:
|
else:
|
||||||
exist_keys = []
|
exist_keys = []
|
||||||
new_keys = set([s for s in keys if s not in exist_keys])
|
new_keys = set([s for s in keys if s not in exist_keys])
|
||||||
print(f"keys: {keys}")
|
# print(f"keys: {keys}")
|
||||||
print(f"new_keys: {new_keys}")
|
# print(f"new_keys: {new_keys}")
|
||||||
return new_keys
|
return new_keys
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -965,6 +1364,14 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
if result is None or result == []:
|
if result is None or result == []:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
# Parse chunks_list JSON string back to list
|
||||||
|
chunks_list = result[0].get("chunks_list", [])
|
||||||
|
if isinstance(chunks_list, str):
|
||||||
|
try:
|
||||||
|
chunks_list = json.loads(chunks_list)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
chunks_list = []
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
content=result[0]["content"],
|
content=result[0]["content"],
|
||||||
content_length=result[0]["content_length"],
|
content_length=result[0]["content_length"],
|
||||||
|
|
@ -974,6 +1381,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
created_at=result[0]["created_at"],
|
created_at=result[0]["created_at"],
|
||||||
updated_at=result[0]["updated_at"],
|
updated_at=result[0]["updated_at"],
|
||||||
file_path=result[0]["file_path"],
|
file_path=result[0]["file_path"],
|
||||||
|
chunks_list=chunks_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
|
|
@ -988,19 +1396,32 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return []
|
return []
|
||||||
return [
|
|
||||||
{
|
processed_results = []
|
||||||
"content": row["content"],
|
for row in results:
|
||||||
"content_length": row["content_length"],
|
# Parse chunks_list JSON string back to list
|
||||||
"content_summary": row["content_summary"],
|
chunks_list = row.get("chunks_list", [])
|
||||||
"status": row["status"],
|
if isinstance(chunks_list, str):
|
||||||
"chunks_count": row["chunks_count"],
|
try:
|
||||||
"created_at": row["created_at"],
|
chunks_list = json.loads(chunks_list)
|
||||||
"updated_at": row["updated_at"],
|
except json.JSONDecodeError:
|
||||||
"file_path": row["file_path"],
|
chunks_list = []
|
||||||
}
|
|
||||||
for row in results
|
processed_results.append(
|
||||||
]
|
{
|
||||||
|
"content": row["content"],
|
||||||
|
"content_length": row["content_length"],
|
||||||
|
"content_summary": row["content_summary"],
|
||||||
|
"status": row["status"],
|
||||||
|
"chunks_count": row["chunks_count"],
|
||||||
|
"created_at": row["created_at"],
|
||||||
|
"updated_at": row["updated_at"],
|
||||||
|
"file_path": row["file_path"],
|
||||||
|
"chunks_list": chunks_list,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return processed_results
|
||||||
|
|
||||||
async def get_status_counts(self) -> dict[str, int]:
|
async def get_status_counts(self) -> dict[str, int]:
|
||||||
"""Get counts of documents in each status"""
|
"""Get counts of documents in each status"""
|
||||||
|
|
@ -1021,8 +1442,18 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
|
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
|
||||||
params = {"workspace": self.db.workspace, "status": status.value}
|
params = {"workspace": self.db.workspace, "status": status.value}
|
||||||
result = await self.db.query(sql, params, True)
|
result = await self.db.query(sql, params, True)
|
||||||
docs_by_status = {
|
|
||||||
element["id"]: DocProcessingStatus(
|
docs_by_status = {}
|
||||||
|
for element in result:
|
||||||
|
# Parse chunks_list JSON string back to list
|
||||||
|
chunks_list = element.get("chunks_list", [])
|
||||||
|
if isinstance(chunks_list, str):
|
||||||
|
try:
|
||||||
|
chunks_list = json.loads(chunks_list)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
chunks_list = []
|
||||||
|
|
||||||
|
docs_by_status[element["id"]] = DocProcessingStatus(
|
||||||
content=element["content"],
|
content=element["content"],
|
||||||
content_summary=element["content_summary"],
|
content_summary=element["content_summary"],
|
||||||
content_length=element["content_length"],
|
content_length=element["content_length"],
|
||||||
|
|
@ -1031,9 +1462,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
updated_at=element["updated_at"],
|
updated_at=element["updated_at"],
|
||||||
chunks_count=element["chunks_count"],
|
chunks_count=element["chunks_count"],
|
||||||
file_path=element["file_path"],
|
file_path=element["file_path"],
|
||||||
|
chunks_list=chunks_list,
|
||||||
)
|
)
|
||||||
for element in result
|
|
||||||
}
|
|
||||||
return docs_by_status
|
return docs_by_status
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
|
|
@ -1097,10 +1528,10 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
logger.warning(f"Unable to parse datetime string: {dt_str}")
|
logger.warning(f"Unable to parse datetime string: {dt_str}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Modified SQL to include created_at and updated_at in both INSERT and UPDATE operations
|
# Modified SQL to include created_at, updated_at, and chunks_list in both INSERT and UPDATE operations
|
||||||
# Both fields are updated from the input data in both INSERT and UPDATE cases
|
# All fields are updated from the input data in both INSERT and UPDATE cases
|
||||||
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at)
|
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,chunks_list,created_at,updated_at)
|
||||||
values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
|
values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
|
||||||
on conflict(id,workspace) do update set
|
on conflict(id,workspace) do update set
|
||||||
content = EXCLUDED.content,
|
content = EXCLUDED.content,
|
||||||
content_summary = EXCLUDED.content_summary,
|
content_summary = EXCLUDED.content_summary,
|
||||||
|
|
@ -1108,6 +1539,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
chunks_count = EXCLUDED.chunks_count,
|
chunks_count = EXCLUDED.chunks_count,
|
||||||
status = EXCLUDED.status,
|
status = EXCLUDED.status,
|
||||||
file_path = EXCLUDED.file_path,
|
file_path = EXCLUDED.file_path,
|
||||||
|
chunks_list = EXCLUDED.chunks_list,
|
||||||
created_at = EXCLUDED.created_at,
|
created_at = EXCLUDED.created_at,
|
||||||
updated_at = EXCLUDED.updated_at"""
|
updated_at = EXCLUDED.updated_at"""
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
|
|
@ -1115,7 +1547,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
created_at = parse_datetime(v.get("created_at"))
|
created_at = parse_datetime(v.get("created_at"))
|
||||||
updated_at = parse_datetime(v.get("updated_at"))
|
updated_at = parse_datetime(v.get("updated_at"))
|
||||||
|
|
||||||
# chunks_count is optional
|
# chunks_count and chunks_list are optional
|
||||||
await self.db.execute(
|
await self.db.execute(
|
||||||
sql,
|
sql,
|
||||||
{
|
{
|
||||||
|
|
@ -1127,6 +1559,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
|
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
|
||||||
"status": v["status"],
|
"status": v["status"],
|
||||||
"file_path": v["file_path"],
|
"file_path": v["file_path"],
|
||||||
|
"chunks_list": json.dumps(v.get("chunks_list", [])),
|
||||||
"created_at": created_at, # Use the converted datetime object
|
"created_at": created_at, # Use the converted datetime object
|
||||||
"updated_at": updated_at, # Use the converted datetime object
|
"updated_at": updated_at, # Use the converted datetime object
|
||||||
},
|
},
|
||||||
|
|
@ -2409,7 +2842,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
NAMESPACE_TABLE_MAP = {
|
NAMESPACE_TABLE_MAP = {
|
||||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||||
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
||||||
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
|
||||||
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
|
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
|
||||||
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
|
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
|
||||||
NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
|
NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
|
||||||
|
|
@ -2438,6 +2871,21 @@ TABLES = {
|
||||||
},
|
},
|
||||||
"LIGHTRAG_DOC_CHUNKS": {
|
"LIGHTRAG_DOC_CHUNKS": {
|
||||||
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
||||||
|
id VARCHAR(255),
|
||||||
|
workspace VARCHAR(255),
|
||||||
|
full_doc_id VARCHAR(256),
|
||||||
|
chunk_order_index INTEGER,
|
||||||
|
tokens INTEGER,
|
||||||
|
content TEXT,
|
||||||
|
file_path VARCHAR(256),
|
||||||
|
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
|
||||||
|
create_time TIMESTAMP(0) WITH TIME ZONE,
|
||||||
|
update_time TIMESTAMP(0) WITH TIME ZONE,
|
||||||
|
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||||
|
)"""
|
||||||
|
},
|
||||||
|
"LIGHTRAG_VDB_CHUNKS": {
|
||||||
|
"ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
|
||||||
id VARCHAR(255),
|
id VARCHAR(255),
|
||||||
workspace VARCHAR(255),
|
workspace VARCHAR(255),
|
||||||
full_doc_id VARCHAR(256),
|
full_doc_id VARCHAR(256),
|
||||||
|
|
@ -2448,7 +2896,7 @@ TABLES = {
|
||||||
file_path VARCHAR(256),
|
file_path VARCHAR(256),
|
||||||
create_time TIMESTAMP(0) WITH TIME ZONE,
|
create_time TIMESTAMP(0) WITH TIME ZONE,
|
||||||
update_time TIMESTAMP(0) WITH TIME ZONE,
|
update_time TIMESTAMP(0) WITH TIME ZONE,
|
||||||
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
|
CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||||
)"""
|
)"""
|
||||||
},
|
},
|
||||||
"LIGHTRAG_VDB_ENTITY": {
|
"LIGHTRAG_VDB_ENTITY": {
|
||||||
|
|
@ -2503,6 +2951,7 @@ TABLES = {
|
||||||
chunks_count int4 NULL,
|
chunks_count int4 NULL,
|
||||||
status varchar(64) NULL,
|
status varchar(64) NULL,
|
||||||
file_path TEXT NULL,
|
file_path TEXT NULL,
|
||||||
|
chunks_list JSONB NULL DEFAULT '[]'::jsonb,
|
||||||
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
|
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
|
||||||
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
|
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
|
||||||
CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
|
CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
|
||||||
|
|
@ -2517,24 +2966,30 @@ SQL_TEMPLATES = {
|
||||||
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
|
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
|
||||||
""",
|
""",
|
||||||
"get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
"get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
||||||
chunk_order_index, full_doc_id, file_path
|
chunk_order_index, full_doc_id, file_path,
|
||||||
|
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
|
||||||
|
create_time, update_time
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
|
||||||
""",
|
""",
|
||||||
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
|
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
|
create_time, update_time
|
||||||
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
|
||||||
""",
|
""",
|
||||||
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
|
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
|
||||||
""",
|
""",
|
||||||
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
|
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
|
||||||
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
||||||
chunk_order_index, full_doc_id, file_path
|
chunk_order_index, full_doc_id, file_path,
|
||||||
|
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
|
||||||
|
create_time, update_time
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
|
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids})
|
create_time, update_time
|
||||||
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
||||||
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
|
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
|
||||||
|
|
@ -2542,16 +2997,31 @@ SQL_TEMPLATES = {
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
SET content = $2, update_time = CURRENT_TIMESTAMP
|
SET content = $2, update_time = CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id)
|
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
ON CONFLICT (workspace,mode,id) DO UPDATE
|
ON CONFLICT (workspace,mode,id) DO UPDATE
|
||||||
SET original_prompt = EXCLUDED.original_prompt,
|
SET original_prompt = EXCLUDED.original_prompt,
|
||||||
return_value=EXCLUDED.return_value,
|
return_value=EXCLUDED.return_value,
|
||||||
mode=EXCLUDED.mode,
|
mode=EXCLUDED.mode,
|
||||||
chunk_id=EXCLUDED.chunk_id,
|
chunk_id=EXCLUDED.chunk_id,
|
||||||
|
cache_type=EXCLUDED.cache_type,
|
||||||
update_time = CURRENT_TIMESTAMP
|
update_time = CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
"upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
|
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
|
||||||
|
chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
|
||||||
|
create_time, update_time)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||||
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
|
SET tokens=EXCLUDED.tokens,
|
||||||
|
chunk_order_index=EXCLUDED.chunk_order_index,
|
||||||
|
full_doc_id=EXCLUDED.full_doc_id,
|
||||||
|
content = EXCLUDED.content,
|
||||||
|
file_path=EXCLUDED.file_path,
|
||||||
|
llm_cache_list=EXCLUDED.llm_cache_list,
|
||||||
|
update_time = EXCLUDED.update_time
|
||||||
|
""",
|
||||||
|
# SQL for VectorStorage
|
||||||
|
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
||||||
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
||||||
create_time, update_time)
|
create_time, update_time)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||||
|
|
@ -2564,7 +3034,6 @@ SQL_TEMPLATES = {
|
||||||
file_path=EXCLUDED.file_path,
|
file_path=EXCLUDED.file_path,
|
||||||
update_time = EXCLUDED.update_time
|
update_time = EXCLUDED.update_time
|
||||||
""",
|
""",
|
||||||
# SQL for VectorStorage
|
|
||||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
||||||
content_vector, chunk_ids, file_path, create_time, update_time)
|
content_vector, chunk_ids, file_path, create_time, update_time)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
|
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
|
||||||
|
|
@ -2591,7 +3060,7 @@ SQL_TEMPLATES = {
|
||||||
"relationships": """
|
"relationships": """
|
||||||
WITH relevant_chunks AS (
|
WITH relevant_chunks AS (
|
||||||
SELECT id as chunk_id
|
SELECT id as chunk_id
|
||||||
FROM LIGHTRAG_DOC_CHUNKS
|
FROM LIGHTRAG_VDB_CHUNKS
|
||||||
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
|
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
|
||||||
)
|
)
|
||||||
SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
|
SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
|
||||||
|
|
@ -2608,7 +3077,7 @@ SQL_TEMPLATES = {
|
||||||
"entities": """
|
"entities": """
|
||||||
WITH relevant_chunks AS (
|
WITH relevant_chunks AS (
|
||||||
SELECT id as chunk_id
|
SELECT id as chunk_id
|
||||||
FROM LIGHTRAG_DOC_CHUNKS
|
FROM LIGHTRAG_VDB_CHUNKS
|
||||||
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
|
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
|
||||||
)
|
)
|
||||||
SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
|
SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
|
||||||
|
|
@ -2625,13 +3094,13 @@ SQL_TEMPLATES = {
|
||||||
"chunks": """
|
"chunks": """
|
||||||
WITH relevant_chunks AS (
|
WITH relevant_chunks AS (
|
||||||
SELECT id as chunk_id
|
SELECT id as chunk_id
|
||||||
FROM LIGHTRAG_DOC_CHUNKS
|
FROM LIGHTRAG_VDB_CHUNKS
|
||||||
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
|
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
|
||||||
)
|
)
|
||||||
SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
|
SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
|
||||||
(
|
(
|
||||||
SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
||||||
FROM LIGHTRAG_DOC_CHUNKS
|
FROM LIGHTRAG_VDB_CHUNKS
|
||||||
WHERE workspace=$1
|
WHERE workspace=$1
|
||||||
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
||||||
) as chunk_distances
|
) as chunk_distances
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, final
|
from typing import Any, final, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
import configparser
|
import configparser
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import threading
|
||||||
|
|
||||||
if not pm.is_installed("redis"):
|
if not pm.is_installed("redis"):
|
||||||
pm.install("redis")
|
pm.install("redis")
|
||||||
|
|
@ -13,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
|
||||||
from redis.exceptions import RedisError, ConnectionError # type: ignore
|
from redis.exceptions import RedisError, ConnectionError # type: ignore
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
|
|
||||||
from lightrag.base import BaseKVStorage
|
from lightrag.base import (
|
||||||
|
BaseKVStorage,
|
||||||
|
DocStatusStorage,
|
||||||
|
DocStatus,
|
||||||
|
DocProcessingStatus,
|
||||||
|
)
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -26,6 +32,41 @@ SOCKET_TIMEOUT = 5.0
|
||||||
SOCKET_CONNECT_TIMEOUT = 3.0
|
SOCKET_CONNECT_TIMEOUT = 3.0
|
||||||
|
|
||||||
|
|
||||||
|
class RedisConnectionManager:
|
||||||
|
"""Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
|
||||||
|
|
||||||
|
_pools = {}
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_pool(cls, redis_url: str) -> ConnectionPool:
|
||||||
|
"""Get or create a connection pool for the given Redis URL"""
|
||||||
|
if redis_url not in cls._pools:
|
||||||
|
with cls._lock:
|
||||||
|
if redis_url not in cls._pools:
|
||||||
|
cls._pools[redis_url] = ConnectionPool.from_url(
|
||||||
|
redis_url,
|
||||||
|
max_connections=MAX_CONNECTIONS,
|
||||||
|
decode_responses=True,
|
||||||
|
socket_timeout=SOCKET_TIMEOUT,
|
||||||
|
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
|
||||||
|
)
|
||||||
|
logger.info(f"Created shared Redis connection pool for {redis_url}")
|
||||||
|
return cls._pools[redis_url]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def close_all_pools(cls):
|
||||||
|
"""Close all connection pools (for cleanup)"""
|
||||||
|
with cls._lock:
|
||||||
|
for url, pool in cls._pools.items():
|
||||||
|
try:
|
||||||
|
pool.disconnect()
|
||||||
|
logger.info(f"Closed Redis connection pool for {url}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing Redis pool for {url}: {e}")
|
||||||
|
cls._pools.clear()
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class RedisKVStorage(BaseKVStorage):
|
class RedisKVStorage(BaseKVStorage):
|
||||||
|
|
@ -33,19 +74,28 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
redis_url = os.environ.get(
|
redis_url = os.environ.get(
|
||||||
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
||||||
)
|
)
|
||||||
# Create a connection pool with limits
|
# Use shared connection pool
|
||||||
self._pool = ConnectionPool.from_url(
|
self._pool = RedisConnectionManager.get_pool(redis_url)
|
||||||
redis_url,
|
|
||||||
max_connections=MAX_CONNECTIONS,
|
|
||||||
decode_responses=True,
|
|
||||||
socket_timeout=SOCKET_TIMEOUT,
|
|
||||||
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
|
|
||||||
)
|
|
||||||
self._redis = Redis(connection_pool=self._pool)
|
self._redis = Redis(connection_pool=self._pool)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
|
f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize Redis connection and migrate legacy cache structure if needed"""
|
||||||
|
# Test connection
|
||||||
|
try:
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
await redis.ping()
|
||||||
|
logger.info(f"Connected to Redis for namespace {self.namespace}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to Redis: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Migrate legacy cache structure if this is a cache namespace
|
||||||
|
if self.namespace.endswith("_cache"):
|
||||||
|
await self._migrate_legacy_cache_structure()
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _get_redis_connection(self):
|
async def _get_redis_connection(self):
|
||||||
"""Safe context manager for Redis operations."""
|
"""Safe context manager for Redis operations."""
|
||||||
|
|
@ -82,7 +132,13 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
try:
|
try:
|
||||||
data = await redis.get(f"{self.namespace}:{id}")
|
data = await redis.get(f"{self.namespace}:{id}")
|
||||||
return json.loads(data) if data else None
|
if data:
|
||||||
|
result = json.loads(data)
|
||||||
|
# Ensure time fields are present, provide default values for old data
|
||||||
|
result.setdefault("create_time", 0)
|
||||||
|
result.setdefault("update_time", 0)
|
||||||
|
return result
|
||||||
|
return None
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"JSON decode error for id {id}: {e}")
|
logger.error(f"JSON decode error for id {id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
@ -94,35 +150,113 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
for id in ids:
|
for id in ids:
|
||||||
pipe.get(f"{self.namespace}:{id}")
|
pipe.get(f"{self.namespace}:{id}")
|
||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
return [json.loads(result) if result else None for result in results]
|
|
||||||
|
processed_results = []
|
||||||
|
for result in results:
|
||||||
|
if result:
|
||||||
|
data = json.loads(result)
|
||||||
|
# Ensure time fields are present for all documents
|
||||||
|
data.setdefault("create_time", 0)
|
||||||
|
data.setdefault("update_time", 0)
|
||||||
|
processed_results.append(data)
|
||||||
|
else:
|
||||||
|
processed_results.append(None)
|
||||||
|
|
||||||
|
return processed_results
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"JSON decode error in batch get: {e}")
|
logger.error(f"JSON decode error in batch get: {e}")
|
||||||
return [None] * len(ids)
|
return [None] * len(ids)
|
||||||
|
|
||||||
|
async def get_all(self) -> dict[str, Any]:
|
||||||
|
"""Get all data from storage
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing all stored data
|
||||||
|
"""
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
# Get all keys for this namespace
|
||||||
|
keys = await redis.keys(f"{self.namespace}:*")
|
||||||
|
|
||||||
|
if not keys:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Get all values in batch
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.get(key)
|
||||||
|
values = await pipe.execute()
|
||||||
|
|
||||||
|
# Build result dictionary
|
||||||
|
result = {}
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
if value:
|
||||||
|
# Extract the ID part (after namespace:)
|
||||||
|
key_id = key.split(":", 1)[1]
|
||||||
|
try:
|
||||||
|
data = json.loads(value)
|
||||||
|
# Ensure time fields are present for all documents
|
||||||
|
data.setdefault("create_time", 0)
|
||||||
|
data.setdefault("update_time", 0)
|
||||||
|
result[key_id] = data
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON decode error for key {key}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting all data from Redis: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
pipe = redis.pipeline()
|
pipe = redis.pipeline()
|
||||||
for key in keys:
|
keys_list = list(keys) # Convert set to list for indexing
|
||||||
|
for key in keys_list:
|
||||||
pipe.exists(f"{self.namespace}:{key}")
|
pipe.exists(f"{self.namespace}:{key}")
|
||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
|
|
||||||
existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
|
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
|
||||||
return set(keys) - existing_ids
|
return set(keys) - existing_ids
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Inserting {len(data)} items to {self.namespace}")
|
import time
|
||||||
|
|
||||||
|
current_time = int(time.time()) # Get current Unix timestamp
|
||||||
|
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
try:
|
try:
|
||||||
|
# Check which keys already exist to determine create vs update
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for k in data.keys():
|
||||||
|
pipe.exists(f"{self.namespace}:{k}")
|
||||||
|
exists_results = await pipe.execute()
|
||||||
|
|
||||||
|
# Add timestamps to data
|
||||||
|
for i, (k, v) in enumerate(data.items()):
|
||||||
|
# For text_chunks namespace, ensure llm_cache_list field exists
|
||||||
|
if "text_chunks" in self.namespace:
|
||||||
|
if "llm_cache_list" not in v:
|
||||||
|
v["llm_cache_list"] = []
|
||||||
|
|
||||||
|
# Add timestamps based on whether key exists
|
||||||
|
if exists_results[i]: # Key exists, only update update_time
|
||||||
|
v["update_time"] = current_time
|
||||||
|
else: # New key, set both create_time and update_time
|
||||||
|
v["create_time"] = current_time
|
||||||
|
v["update_time"] = current_time
|
||||||
|
|
||||||
|
v["_id"] = k
|
||||||
|
|
||||||
|
# Store the data
|
||||||
pipe = redis.pipeline()
|
pipe = redis.pipeline()
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
|
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
|
||||||
await pipe.execute()
|
await pipe.execute()
|
||||||
|
|
||||||
for k in data:
|
|
||||||
data[k]["_id"] = k
|
|
||||||
except json.JSONEncodeError as e:
|
except json.JSONEncodeError as e:
|
||||||
logger.error(f"JSON encode error during upsert: {e}")
|
logger.error(f"JSON encode error during upsert: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
@ -148,13 +282,13 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||||
"""Delete specific records from storage by by cache mode
|
"""Delete specific records from storage by cache mode
|
||||||
|
|
||||||
Importance notes for Redis storage:
|
Importance notes for Redis storage:
|
||||||
1. This will immediately delete the specified cache modes from Redis
|
1. This will immediately delete the specified cache modes from Redis
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
modes (list[str]): List of cache mode to be drop from storage
|
modes (list[str]): List of cache modes to be dropped from storage
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True: if the cache drop successfully
|
True: if the cache drop successfully
|
||||||
|
|
@ -164,9 +298,47 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.delete(modes)
|
async with self._get_redis_connection() as redis:
|
||||||
|
keys_to_delete = []
|
||||||
|
|
||||||
|
# Find matching keys for each mode using SCAN
|
||||||
|
for mode in modes:
|
||||||
|
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
|
||||||
|
pattern = f"{self.namespace}:{mode}:*"
|
||||||
|
cursor = 0
|
||||||
|
mode_keys = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(
|
||||||
|
cursor, match=pattern, count=1000
|
||||||
|
)
|
||||||
|
if keys:
|
||||||
|
mode_keys.extend(keys)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
keys_to_delete.extend(mode_keys)
|
||||||
|
logger.info(
|
||||||
|
f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if keys_to_delete:
|
||||||
|
# Batch delete
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys_to_delete:
|
||||||
|
pipe.delete(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count = sum(results)
|
||||||
|
logger.info(
|
||||||
|
f"Dropped {deleted_count} cache entries for modes: {modes}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"No cache entries found for modes: {modes}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(f"Error dropping cache by modes in Redis: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
|
|
@ -177,24 +349,370 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
"""
|
"""
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
try:
|
try:
|
||||||
keys = await redis.keys(f"{self.namespace}:*")
|
# Use SCAN to find all keys with the namespace prefix
|
||||||
|
pattern = f"{self.namespace}:*"
|
||||||
|
cursor = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
if keys:
|
while True:
|
||||||
pipe = redis.pipeline()
|
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
||||||
for key in keys:
|
if keys:
|
||||||
pipe.delete(key)
|
# Delete keys in batches
|
||||||
results = await pipe.execute()
|
pipe = redis.pipeline()
|
||||||
deleted_count = sum(results)
|
for key in keys:
|
||||||
|
pipe.delete(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count += sum(results)
|
||||||
|
|
||||||
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
|
if cursor == 0:
|
||||||
return {
|
break
|
||||||
"status": "success",
|
|
||||||
"message": f"{deleted_count} keys dropped",
|
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
|
||||||
}
|
return {
|
||||||
else:
|
"status": "success",
|
||||||
logger.info(f"No keys found to drop in {self.namespace}")
|
"message": f"{deleted_count} keys dropped",
|
||||||
return {"status": "success", "message": "no keys to drop"}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
async def _migrate_legacy_cache_structure(self):
|
||||||
|
"""Migrate legacy nested cache structure to flattened structure for Redis
|
||||||
|
|
||||||
|
Redis already stores data in a flattened way, but we need to check for
|
||||||
|
legacy keys that might contain nested JSON structures and migrate them.
|
||||||
|
|
||||||
|
Early exit if any flattened key is found (indicating migration already done).
|
||||||
|
"""
|
||||||
|
from lightrag.utils import generate_cache_key
|
||||||
|
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
# Get all keys for this namespace
|
||||||
|
keys = await redis.keys(f"{self.namespace}:*")
|
||||||
|
|
||||||
|
if not keys:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if we have any flattened keys already - if so, skip migration
|
||||||
|
has_flattened_keys = False
|
||||||
|
keys_to_migrate = []
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
# Extract the ID part (after namespace:)
|
||||||
|
key_id = key.split(":", 1)[1]
|
||||||
|
|
||||||
|
# Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
|
||||||
|
if ":" in key_id and len(key_id.split(":")) == 3:
|
||||||
|
has_flattened_keys = True
|
||||||
|
break # Early exit - migration already done
|
||||||
|
|
||||||
|
# Get the data to check if it's a legacy nested structure
|
||||||
|
data = await redis.get(key)
|
||||||
|
if data:
|
||||||
|
try:
|
||||||
|
parsed_data = json.loads(data)
|
||||||
|
# Check if this looks like a legacy cache mode with nested structure
|
||||||
|
if isinstance(parsed_data, dict) and all(
|
||||||
|
isinstance(v, dict) and "return" in v
|
||||||
|
for v in parsed_data.values()
|
||||||
|
):
|
||||||
|
keys_to_migrate.append((key, key_id, parsed_data))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If we found any flattened keys, assume migration is already done
|
||||||
|
if has_flattened_keys:
|
||||||
|
logger.debug(
|
||||||
|
f"Found flattened cache keys in {self.namespace}, skipping migration"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not keys_to_migrate:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Perform migration
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
migration_count = 0
|
||||||
|
|
||||||
|
for old_key, mode, nested_data in keys_to_migrate:
|
||||||
|
# Delete the old key
|
||||||
|
pipe.delete(old_key)
|
||||||
|
|
||||||
|
# Create new flattened keys
|
||||||
|
for cache_hash, cache_entry in nested_data.items():
|
||||||
|
cache_type = cache_entry.get("cache_type", "extract")
|
||||||
|
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
|
||||||
|
full_key = f"{self.namespace}:{flattened_key}"
|
||||||
|
pipe.set(full_key, json.dumps(cache_entry))
|
||||||
|
migration_count += 1
|
||||||
|
|
||||||
|
await pipe.execute()
|
||||||
|
|
||||||
|
if migration_count > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
@dataclass
|
||||||
|
class RedisDocStatusStorage(DocStatusStorage):
|
||||||
|
"""Redis implementation of document status storage"""
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
redis_url = os.environ.get(
|
||||||
|
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
||||||
|
)
|
||||||
|
# Use shared connection pool
|
||||||
|
self._pool = RedisConnectionManager.get_pool(redis_url)
|
||||||
|
self._redis = Redis(connection_pool=self._pool)
|
||||||
|
logger.info(
|
||||||
|
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize Redis connection"""
|
||||||
|
try:
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
await redis.ping()
|
||||||
|
logger.info(
|
||||||
|
f"Connected to Redis for doc status namespace {self.namespace}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to Redis for doc status: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _get_redis_connection(self):
|
||||||
|
"""Safe context manager for Redis operations."""
|
||||||
|
try:
|
||||||
|
yield self._redis
|
||||||
|
except ConnectionError as e:
|
||||||
|
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
|
||||||
|
raise
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the Redis connection."""
|
||||||
|
if hasattr(self, "_redis") and self._redis:
|
||||||
|
await self._redis.close()
|
||||||
|
logger.debug(f"Closed Redis connection for doc status {self.namespace}")
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Support for async context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Ensure Redis resources are cleaned up when exiting context."""
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
|
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
keys_list = list(keys)
|
||||||
|
for key in keys_list:
|
||||||
|
pipe.exists(f"{self.namespace}:{key}")
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
|
||||||
|
return set(keys) - existing_ids
|
||||||
|
|
||||||
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
|
result: list[dict[str, Any]] = []
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for id in ids:
|
||||||
|
pipe.get(f"{self.namespace}:{id}")
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
for result_data in results:
|
||||||
|
if result_data:
|
||||||
|
try:
|
||||||
|
result.append(json.loads(result_data))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON decode error in get_by_ids: {e}")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_by_ids: {e}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_status_counts(self) -> dict[str, int]:
|
||||||
|
"""Get counts of documents in each status"""
|
||||||
|
counts = {status.value: 0 for status in DocStatus}
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
# Use SCAN to iterate through all keys in the namespace
|
||||||
|
cursor = 0
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(
|
||||||
|
cursor, match=f"{self.namespace}:*", count=1000
|
||||||
|
)
|
||||||
|
if keys:
|
||||||
|
# Get all values in batch
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.get(key)
|
||||||
|
values = await pipe.execute()
|
||||||
|
|
||||||
|
# Count statuses
|
||||||
|
for value in values:
|
||||||
|
if value:
|
||||||
|
try:
|
||||||
|
doc_data = json.loads(value)
|
||||||
|
status = doc_data.get("status")
|
||||||
|
if status in counts:
|
||||||
|
counts[status] += 1
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting status counts: {e}")
|
||||||
|
|
||||||
|
return counts
|
||||||
|
|
||||||
|
async def get_docs_by_status(
|
||||||
|
self, status: DocStatus
|
||||||
|
) -> dict[str, DocProcessingStatus]:
|
||||||
|
"""Get all documents with a specific status"""
|
||||||
|
result = {}
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
# Use SCAN to iterate through all keys in the namespace
|
||||||
|
cursor = 0
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(
|
||||||
|
cursor, match=f"{self.namespace}:*", count=1000
|
||||||
|
)
|
||||||
|
if keys:
|
||||||
|
# Get all values in batch
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.get(key)
|
||||||
|
values = await pipe.execute()
|
||||||
|
|
||||||
|
# Filter by status and create DocProcessingStatus objects
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
if value:
|
||||||
|
try:
|
||||||
|
doc_data = json.loads(value)
|
||||||
|
if doc_data.get("status") == status.value:
|
||||||
|
# Extract document ID from key
|
||||||
|
doc_id = key.split(":", 1)[1]
|
||||||
|
|
||||||
|
# Make a copy of the data to avoid modifying the original
|
||||||
|
data = doc_data.copy()
|
||||||
|
# If content is missing, use content_summary as content
|
||||||
|
if (
|
||||||
|
"content" not in data
|
||||||
|
and "content_summary" in data
|
||||||
|
):
|
||||||
|
data["content"] = data["content_summary"]
|
||||||
|
# If file_path is not in data, use document id as file path
|
||||||
|
if "file_path" not in data:
|
||||||
|
data["file_path"] = "no-file-path"
|
||||||
|
|
||||||
|
result[doc_id] = DocProcessingStatus(**data)
|
||||||
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error processing document {key}: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting docs by status: {e}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
"""Redis handles persistence automatically"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
|
"""Insert or update document status data"""
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
# Ensure chunks_list field exists for new documents
|
||||||
|
for doc_id, doc_data in data.items():
|
||||||
|
if "chunks_list" not in doc_data:
|
||||||
|
doc_data["chunks_list"] = []
|
||||||
|
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for k, v in data.items():
|
||||||
|
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
|
||||||
|
await pipe.execute()
|
||||||
|
except json.JSONEncodeError as e:
|
||||||
|
logger.error(f"JSON encode error during upsert: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
data = await redis.get(f"{self.namespace}:{id}")
|
||||||
|
return json.loads(data) if data else None
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON decode error for id {id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete(self, doc_ids: list[str]) -> None:
|
||||||
|
"""Delete specific records from storage by their IDs"""
|
||||||
|
if not doc_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for doc_id in doc_ids:
|
||||||
|
pipe.delete(f"{self.namespace}:{doc_id}")
|
||||||
|
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count = sum(results)
|
||||||
|
logger.info(
|
||||||
|
f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def drop(self) -> dict[str, str]:
|
||||||
|
"""Drop all document status data from storage and clean up resources"""
|
||||||
|
try:
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
# Use SCAN to find all keys with the namespace prefix
|
||||||
|
pattern = f"{self.namespace}:*"
|
||||||
|
cursor = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
||||||
|
if keys:
|
||||||
|
# Delete keys in batches
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.delete(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count += sum(results)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Dropped {deleted_count} doc status keys from {self.namespace}"
|
||||||
|
)
|
||||||
|
return {"status": "success", "message": "data dropped"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error dropping doc status {self.namespace}: {e}")
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
)
|
)
|
||||||
from lightrag.constants import (
|
from lightrag.constants import (
|
||||||
|
DEFAULT_MAX_GLEANING,
|
||||||
DEFAULT_MAX_TOKEN_SUMMARY,
|
DEFAULT_MAX_TOKEN_SUMMARY,
|
||||||
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
||||||
)
|
)
|
||||||
|
|
@ -124,7 +125,9 @@ class LightRAG:
|
||||||
# Entity extraction
|
# Entity extraction
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
entity_extract_max_gleaning: int = field(default=1)
|
entity_extract_max_gleaning: int = field(
|
||||||
|
default=get_env_value("MAX_GLEANING", DEFAULT_MAX_GLEANING, int)
|
||||||
|
)
|
||||||
"""Maximum number of entity extraction attempts for ambiguous content."""
|
"""Maximum number of entity extraction attempts for ambiguous content."""
|
||||||
|
|
||||||
summary_to_max_tokens: int = field(
|
summary_to_max_tokens: int = field(
|
||||||
|
|
@ -346,6 +349,7 @@ class LightRAG:
|
||||||
|
|
||||||
# Fix global_config now
|
# Fix global_config now
|
||||||
global_config = asdict(self)
|
global_config = asdict(self)
|
||||||
|
|
||||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
||||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||||
|
|
||||||
|
|
@ -394,13 +398,13 @@ class LightRAG:
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: deprecating, text_chunks is redundant with chunks_vdb
|
|
||||||
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||||
namespace=make_namespace(
|
namespace=make_namespace(
|
||||||
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
|
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
|
||||||
),
|
),
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
|
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
|
||||||
namespace=make_namespace(
|
namespace=make_namespace(
|
||||||
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
|
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
|
||||||
|
|
@ -949,6 +953,7 @@ class LightRAG:
|
||||||
**dp,
|
**dp,
|
||||||
"full_doc_id": doc_id,
|
"full_doc_id": doc_id,
|
||||||
"file_path": file_path, # Add file path to each chunk
|
"file_path": file_path, # Add file path to each chunk
|
||||||
|
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
|
||||||
}
|
}
|
||||||
for dp in self.chunking_func(
|
for dp in self.chunking_func(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
|
|
@ -960,14 +965,17 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Process document (text chunks and full docs) in parallel
|
# Process document in two stages
|
||||||
# Create tasks with references for potential cancellation
|
# Stage 1: Process text chunks and docs (parallel execution)
|
||||||
doc_status_task = asyncio.create_task(
|
doc_status_task = asyncio.create_task(
|
||||||
self.doc_status.upsert(
|
self.doc_status.upsert(
|
||||||
{
|
{
|
||||||
doc_id: {
|
doc_id: {
|
||||||
"status": DocStatus.PROCESSING,
|
"status": DocStatus.PROCESSING,
|
||||||
"chunks_count": len(chunks),
|
"chunks_count": len(chunks),
|
||||||
|
"chunks_list": list(
|
||||||
|
chunks.keys()
|
||||||
|
), # Save chunks list
|
||||||
"content": status_doc.content,
|
"content": status_doc.content,
|
||||||
"content_summary": status_doc.content_summary,
|
"content_summary": status_doc.content_summary,
|
||||||
"content_length": status_doc.content_length,
|
"content_length": status_doc.content_length,
|
||||||
|
|
@ -983,11 +991,6 @@ class LightRAG:
|
||||||
chunks_vdb_task = asyncio.create_task(
|
chunks_vdb_task = asyncio.create_task(
|
||||||
self.chunks_vdb.upsert(chunks)
|
self.chunks_vdb.upsert(chunks)
|
||||||
)
|
)
|
||||||
entity_relation_task = asyncio.create_task(
|
|
||||||
self._process_entity_relation_graph(
|
|
||||||
chunks, pipeline_status, pipeline_status_lock
|
|
||||||
)
|
|
||||||
)
|
|
||||||
full_docs_task = asyncio.create_task(
|
full_docs_task = asyncio.create_task(
|
||||||
self.full_docs.upsert(
|
self.full_docs.upsert(
|
||||||
{doc_id: {"content": status_doc.content}}
|
{doc_id: {"content": status_doc.content}}
|
||||||
|
|
@ -996,14 +999,26 @@ class LightRAG:
|
||||||
text_chunks_task = asyncio.create_task(
|
text_chunks_task = asyncio.create_task(
|
||||||
self.text_chunks.upsert(chunks)
|
self.text_chunks.upsert(chunks)
|
||||||
)
|
)
|
||||||
tasks = [
|
|
||||||
|
# First stage tasks (parallel execution)
|
||||||
|
first_stage_tasks = [
|
||||||
doc_status_task,
|
doc_status_task,
|
||||||
chunks_vdb_task,
|
chunks_vdb_task,
|
||||||
entity_relation_task,
|
|
||||||
full_docs_task,
|
full_docs_task,
|
||||||
text_chunks_task,
|
text_chunks_task,
|
||||||
]
|
]
|
||||||
await asyncio.gather(*tasks)
|
entity_relation_task = None
|
||||||
|
|
||||||
|
# Execute first stage tasks
|
||||||
|
await asyncio.gather(*first_stage_tasks)
|
||||||
|
|
||||||
|
# Stage 2: Process entity relation graph (after text_chunks are saved)
|
||||||
|
entity_relation_task = asyncio.create_task(
|
||||||
|
self._process_entity_relation_graph(
|
||||||
|
chunks, pipeline_status, pipeline_status_lock
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await entity_relation_task
|
||||||
file_extraction_stage_ok = True
|
file_extraction_stage_ok = True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1018,14 +1033,14 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
pipeline_status["history_messages"].append(error_msg)
|
pipeline_status["history_messages"].append(error_msg)
|
||||||
|
|
||||||
# Cancel other tasks as they are no longer meaningful
|
# Cancel tasks that are not yet completed
|
||||||
for task in [
|
all_tasks = first_stage_tasks + (
|
||||||
chunks_vdb_task,
|
[entity_relation_task]
|
||||||
entity_relation_task,
|
if entity_relation_task
|
||||||
full_docs_task,
|
else []
|
||||||
text_chunks_task,
|
)
|
||||||
]:
|
for task in all_tasks:
|
||||||
if not task.done():
|
if task and not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
# Persistent llm cache
|
# Persistent llm cache
|
||||||
|
|
@ -1075,6 +1090,9 @@ class LightRAG:
|
||||||
doc_id: {
|
doc_id: {
|
||||||
"status": DocStatus.PROCESSED,
|
"status": DocStatus.PROCESSED,
|
||||||
"chunks_count": len(chunks),
|
"chunks_count": len(chunks),
|
||||||
|
"chunks_list": list(
|
||||||
|
chunks.keys()
|
||||||
|
), # 保留 chunks_list
|
||||||
"content": status_doc.content,
|
"content": status_doc.content,
|
||||||
"content_summary": status_doc.content_summary,
|
"content_summary": status_doc.content_summary,
|
||||||
"content_length": status_doc.content_length,
|
"content_length": status_doc.content_length,
|
||||||
|
|
@ -1193,6 +1211,7 @@ class LightRAG:
|
||||||
pipeline_status=pipeline_status,
|
pipeline_status=pipeline_status,
|
||||||
pipeline_status_lock=pipeline_status_lock,
|
pipeline_status_lock=pipeline_status_lock,
|
||||||
llm_response_cache=self.llm_response_cache,
|
llm_response_cache=self.llm_response_cache,
|
||||||
|
text_chunks_storage=self.text_chunks,
|
||||||
)
|
)
|
||||||
return chunk_results
|
return chunk_results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1723,28 +1742,10 @@ class LightRAG:
|
||||||
file_path="",
|
file_path="",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Get all chunks related to this document
|
# 2. Get chunk IDs from document status
|
||||||
try:
|
chunk_ids = set(doc_status_data.get("chunks_list", []))
|
||||||
all_chunks = await self.text_chunks.get_all()
|
|
||||||
related_chunks = {
|
|
||||||
chunk_id: chunk_data
|
|
||||||
for chunk_id, chunk_data in all_chunks.items()
|
|
||||||
if isinstance(chunk_data, dict)
|
|
||||||
and chunk_data.get("full_doc_id") == doc_id
|
|
||||||
}
|
|
||||||
|
|
||||||
# Update pipeline status after getting chunks count
|
if not chunk_ids:
|
||||||
async with pipeline_status_lock:
|
|
||||||
log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks"
|
|
||||||
logger.info(log_message)
|
|
||||||
pipeline_status["latest_message"] = log_message
|
|
||||||
pipeline_status["history_messages"].append(log_message)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}")
|
|
||||||
raise Exception(f"Failed to retrieve document chunks: {e}") from e
|
|
||||||
|
|
||||||
if not related_chunks:
|
|
||||||
logger.warning(f"No chunks found for document {doc_id}")
|
logger.warning(f"No chunks found for document {doc_id}")
|
||||||
# Mark that deletion operations have started
|
# Mark that deletion operations have started
|
||||||
deletion_operations_started = True
|
deletion_operations_started = True
|
||||||
|
|
@ -1775,7 +1776,6 @@ class LightRAG:
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk_ids = set(related_chunks.keys())
|
|
||||||
# Mark that deletion operations have started
|
# Mark that deletion operations have started
|
||||||
deletion_operations_started = True
|
deletion_operations_started = True
|
||||||
|
|
||||||
|
|
@ -1799,26 +1799,12 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update pipeline status after getting affected_nodes
|
|
||||||
async with pipeline_status_lock:
|
|
||||||
log_message = f"Found {len(affected_nodes)} affected entities"
|
|
||||||
logger.info(log_message)
|
|
||||||
pipeline_status["latest_message"] = log_message
|
|
||||||
pipeline_status["history_messages"].append(log_message)
|
|
||||||
|
|
||||||
affected_edges = (
|
affected_edges = (
|
||||||
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
|
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
|
||||||
list(chunk_ids)
|
list(chunk_ids)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update pipeline status after getting affected_edges
|
|
||||||
async with pipeline_status_lock:
|
|
||||||
log_message = f"Found {len(affected_edges)} affected relations"
|
|
||||||
logger.info(log_message)
|
|
||||||
pipeline_status["latest_message"] = log_message
|
|
||||||
pipeline_status["history_messages"].append(log_message)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to analyze affected graph elements: {e}")
|
logger.error(f"Failed to analyze affected graph elements: {e}")
|
||||||
raise Exception(f"Failed to analyze graph dependencies: {e}") from e
|
raise Exception(f"Failed to analyze graph dependencies: {e}") from e
|
||||||
|
|
@ -1836,6 +1822,14 @@ class LightRAG:
|
||||||
elif remaining_sources != sources:
|
elif remaining_sources != sources:
|
||||||
entities_to_rebuild[node_label] = remaining_sources
|
entities_to_rebuild[node_label] = remaining_sources
|
||||||
|
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
log_message = (
|
||||||
|
f"Found {len(entities_to_rebuild)} affected entities"
|
||||||
|
)
|
||||||
|
logger.info(log_message)
|
||||||
|
pipeline_status["latest_message"] = log_message
|
||||||
|
pipeline_status["history_messages"].append(log_message)
|
||||||
|
|
||||||
# Process relationships
|
# Process relationships
|
||||||
for edge_data in affected_edges:
|
for edge_data in affected_edges:
|
||||||
src = edge_data.get("source")
|
src = edge_data.get("source")
|
||||||
|
|
@ -1857,6 +1851,14 @@ class LightRAG:
|
||||||
elif remaining_sources != sources:
|
elif remaining_sources != sources:
|
||||||
relationships_to_rebuild[edge_tuple] = remaining_sources
|
relationships_to_rebuild[edge_tuple] = remaining_sources
|
||||||
|
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
log_message = (
|
||||||
|
f"Found {len(relationships_to_rebuild)} affected relations"
|
||||||
|
)
|
||||||
|
logger.info(log_message)
|
||||||
|
pipeline_status["latest_message"] = log_message
|
||||||
|
pipeline_status["history_messages"].append(log_message)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process graph analysis results: {e}")
|
logger.error(f"Failed to process graph analysis results: {e}")
|
||||||
raise Exception(f"Failed to process graph dependencies: {e}") from e
|
raise Exception(f"Failed to process graph dependencies: {e}") from e
|
||||||
|
|
@ -1940,17 +1942,13 @@ class LightRAG:
|
||||||
knowledge_graph_inst=self.chunk_entity_relation_graph,
|
knowledge_graph_inst=self.chunk_entity_relation_graph,
|
||||||
entities_vdb=self.entities_vdb,
|
entities_vdb=self.entities_vdb,
|
||||||
relationships_vdb=self.relationships_vdb,
|
relationships_vdb=self.relationships_vdb,
|
||||||
text_chunks=self.text_chunks,
|
text_chunks_storage=self.text_chunks,
|
||||||
llm_response_cache=self.llm_response_cache,
|
llm_response_cache=self.llm_response_cache,
|
||||||
global_config=asdict(self),
|
global_config=asdict(self),
|
||||||
|
pipeline_status=pipeline_status,
|
||||||
|
pipeline_status_lock=pipeline_status_lock,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with pipeline_status_lock:
|
|
||||||
log_message = f"Successfully rebuilt {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relations"
|
|
||||||
logger.info(log_message)
|
|
||||||
pipeline_status["latest_message"] = log_message
|
|
||||||
pipeline_status["history_messages"].append(log_message)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to rebuild knowledge from chunks: {e}")
|
logger.error(f"Failed to rebuild knowledge from chunks: {e}")
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from .utils import (
|
||||||
CacheData,
|
CacheData,
|
||||||
get_conversation_turns,
|
get_conversation_turns,
|
||||||
use_llm_func_with_cache,
|
use_llm_func_with_cache,
|
||||||
|
update_chunk_cache_list,
|
||||||
)
|
)
|
||||||
from .base import (
|
from .base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
|
|
@ -103,8 +104,6 @@ async def _handle_entity_relation_summary(
|
||||||
entity_or_relation_name: str,
|
entity_or_relation_name: str,
|
||||||
description: str,
|
description: str,
|
||||||
global_config: dict,
|
global_config: dict,
|
||||||
pipeline_status: dict = None,
|
|
||||||
pipeline_status_lock=None,
|
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle entity relation summary
|
"""Handle entity relation summary
|
||||||
|
|
@ -247,9 +246,11 @@ async def _rebuild_knowledge_from_chunks(
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
entities_vdb: BaseVectorStorage,
|
entities_vdb: BaseVectorStorage,
|
||||||
relationships_vdb: BaseVectorStorage,
|
relationships_vdb: BaseVectorStorage,
|
||||||
text_chunks: BaseKVStorage,
|
text_chunks_storage: BaseKVStorage,
|
||||||
llm_response_cache: BaseKVStorage,
|
llm_response_cache: BaseKVStorage,
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
|
pipeline_status: dict | None = None,
|
||||||
|
pipeline_status_lock=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Rebuild entity and relationship descriptions from cached extraction results
|
"""Rebuild entity and relationship descriptions from cached extraction results
|
||||||
|
|
||||||
|
|
@ -259,9 +260,12 @@ async def _rebuild_knowledge_from_chunks(
|
||||||
Args:
|
Args:
|
||||||
entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
|
entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
|
||||||
relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
|
relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
|
||||||
|
text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data}
|
||||||
"""
|
"""
|
||||||
if not entities_to_rebuild and not relationships_to_rebuild:
|
if not entities_to_rebuild and not relationships_to_rebuild:
|
||||||
return
|
return
|
||||||
|
rebuilt_entities_count = 0
|
||||||
|
rebuilt_relationships_count = 0
|
||||||
|
|
||||||
# Get all referenced chunk IDs
|
# Get all referenced chunk IDs
|
||||||
all_referenced_chunk_ids = set()
|
all_referenced_chunk_ids = set()
|
||||||
|
|
@ -270,36 +274,74 @@ async def _rebuild_knowledge_from_chunks(
|
||||||
for chunk_ids in relationships_to_rebuild.values():
|
for chunk_ids in relationships_to_rebuild.values():
|
||||||
all_referenced_chunk_ids.update(chunk_ids)
|
all_referenced_chunk_ids.update(chunk_ids)
|
||||||
|
|
||||||
logger.debug(
|
status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
|
||||||
f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
|
logger.info(status_message)
|
||||||
)
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
|
|
||||||
# Get cached extraction results for these chunks
|
# Get cached extraction results for these chunks using storage
|
||||||
|
# cached_results: chunk_id -> [list of extraction result from LLM cache sorted by created_at]
|
||||||
cached_results = await _get_cached_extraction_results(
|
cached_results = await _get_cached_extraction_results(
|
||||||
llm_response_cache, all_referenced_chunk_ids
|
llm_response_cache,
|
||||||
|
all_referenced_chunk_ids,
|
||||||
|
text_chunks_storage=text_chunks_storage,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not cached_results:
|
if not cached_results:
|
||||||
logger.warning("No cached extraction results found, cannot rebuild")
|
status_message = "No cached extraction results found, cannot rebuild"
|
||||||
|
logger.warning(status_message)
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Process cached results to get entities and relationships for each chunk
|
# Process cached results to get entities and relationships for each chunk
|
||||||
chunk_entities = {} # chunk_id -> {entity_name: [entity_data]}
|
chunk_entities = {} # chunk_id -> {entity_name: [entity_data]}
|
||||||
chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]}
|
chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]}
|
||||||
|
|
||||||
for chunk_id, extraction_result in cached_results.items():
|
for chunk_id, extraction_results in cached_results.items():
|
||||||
try:
|
try:
|
||||||
entities, relationships = await _parse_extraction_result(
|
# Handle multiple extraction results per chunk
|
||||||
text_chunks=text_chunks,
|
chunk_entities[chunk_id] = defaultdict(list)
|
||||||
extraction_result=extraction_result,
|
chunk_relationships[chunk_id] = defaultdict(list)
|
||||||
chunk_id=chunk_id,
|
|
||||||
)
|
# process multiple LLM extraction results for a single chunk_id
|
||||||
chunk_entities[chunk_id] = entities
|
for extraction_result in extraction_results:
|
||||||
chunk_relationships[chunk_id] = relationships
|
entities, relationships = await _parse_extraction_result(
|
||||||
|
text_chunks_storage=text_chunks_storage,
|
||||||
|
extraction_result=extraction_result,
|
||||||
|
chunk_id=chunk_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge entities and relationships from this extraction result
|
||||||
|
# Only keep the first occurrence of each entity_name in the same chunk_id
|
||||||
|
for entity_name, entity_list in entities.items():
|
||||||
|
if (
|
||||||
|
entity_name not in chunk_entities[chunk_id]
|
||||||
|
or len(chunk_entities[chunk_id][entity_name]) == 0
|
||||||
|
):
|
||||||
|
chunk_entities[chunk_id][entity_name].extend(entity_list)
|
||||||
|
|
||||||
|
# Only keep the first occurrence of each rel_key in the same chunk_id
|
||||||
|
for rel_key, rel_list in relationships.items():
|
||||||
|
if (
|
||||||
|
rel_key not in chunk_relationships[chunk_id]
|
||||||
|
or len(chunk_relationships[chunk_id][rel_key]) == 0
|
||||||
|
):
|
||||||
|
chunk_relationships[chunk_id][rel_key].extend(rel_list)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
status_message = (
|
||||||
f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
|
f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
|
||||||
)
|
)
|
||||||
|
logger.info(status_message) # Per requirement, change to info
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Rebuild entities
|
# Rebuild entities
|
||||||
|
|
@ -314,11 +356,22 @@ async def _rebuild_knowledge_from_chunks(
|
||||||
llm_response_cache=llm_response_cache,
|
llm_response_cache=llm_response_cache,
|
||||||
global_config=global_config,
|
global_config=global_config,
|
||||||
)
|
)
|
||||||
logger.debug(
|
rebuilt_entities_count += 1
|
||||||
f"Rebuilt entity {entity_name} from {len(chunk_ids)} cached extractions"
|
status_message = (
|
||||||
|
f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks"
|
||||||
)
|
)
|
||||||
|
logger.info(status_message)
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to rebuild entity {entity_name}: {e}")
|
status_message = f"Failed to rebuild entity {entity_name}: {e}"
|
||||||
|
logger.info(status_message) # Per requirement, change to info
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
|
|
||||||
# Rebuild relationships
|
# Rebuild relationships
|
||||||
for (src, tgt), chunk_ids in relationships_to_rebuild.items():
|
for (src, tgt), chunk_ids in relationships_to_rebuild.items():
|
||||||
|
|
@ -333,53 +386,112 @@ async def _rebuild_knowledge_from_chunks(
|
||||||
llm_response_cache=llm_response_cache,
|
llm_response_cache=llm_response_cache,
|
||||||
global_config=global_config,
|
global_config=global_config,
|
||||||
)
|
)
|
||||||
logger.debug(
|
rebuilt_relationships_count += 1
|
||||||
f"Rebuilt relationship {src}-{tgt} from {len(chunk_ids)} cached extractions"
|
status_message = (
|
||||||
|
f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks"
|
||||||
)
|
)
|
||||||
|
logger.info(status_message)
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to rebuild relationship {src}-{tgt}: {e}")
|
status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}"
|
||||||
|
logger.info(status_message)
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
|
|
||||||
logger.debug("Completed rebuilding knowledge from cached extractions")
|
status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships."
|
||||||
|
logger.info(status_message)
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = status_message
|
||||||
|
pipeline_status["history_messages"].append(status_message)
|
||||||
|
|
||||||
|
|
||||||
async def _get_cached_extraction_results(
|
async def _get_cached_extraction_results(
|
||||||
llm_response_cache: BaseKVStorage, chunk_ids: set[str]
|
llm_response_cache: BaseKVStorage,
|
||||||
) -> dict[str, str]:
|
chunk_ids: set[str],
|
||||||
|
text_chunks_storage: BaseKVStorage,
|
||||||
|
) -> dict[str, list[str]]:
|
||||||
"""Get cached extraction results for specific chunk IDs
|
"""Get cached extraction results for specific chunk IDs
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
llm_response_cache: LLM response cache storage
|
||||||
chunk_ids: Set of chunk IDs to get cached results for
|
chunk_ids: Set of chunk IDs to get cached results for
|
||||||
|
text_chunks_data: Pre-loaded chunk data (optional, for performance)
|
||||||
|
text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping chunk_id -> extraction_result_text
|
Dict mapping chunk_id -> list of extraction_result_text
|
||||||
"""
|
"""
|
||||||
cached_results = {}
|
cached_results = {}
|
||||||
|
|
||||||
# Get all cached data for "default" mode (entity extraction cache)
|
# Collect all LLM cache IDs from chunks
|
||||||
default_cache = await llm_response_cache.get_by_id("default") or {}
|
all_cache_ids = set()
|
||||||
|
|
||||||
for cache_key, cache_entry in default_cache.items():
|
# Read from storage
|
||||||
|
chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
|
||||||
|
for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
|
||||||
|
if chunk_data and isinstance(chunk_data, dict):
|
||||||
|
llm_cache_list = chunk_data.get("llm_cache_list", [])
|
||||||
|
if llm_cache_list:
|
||||||
|
all_cache_ids.update(llm_cache_list)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not all_cache_ids:
|
||||||
|
logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
|
||||||
|
return cached_results
|
||||||
|
|
||||||
|
# Batch get LLM cache entries
|
||||||
|
cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids))
|
||||||
|
|
||||||
|
# Process cache entries and group by chunk_id
|
||||||
|
valid_entries = 0
|
||||||
|
for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
|
||||||
if (
|
if (
|
||||||
isinstance(cache_entry, dict)
|
cache_entry is not None
|
||||||
|
and isinstance(cache_entry, dict)
|
||||||
and cache_entry.get("cache_type") == "extract"
|
and cache_entry.get("cache_type") == "extract"
|
||||||
and cache_entry.get("chunk_id") in chunk_ids
|
and cache_entry.get("chunk_id") in chunk_ids
|
||||||
):
|
):
|
||||||
chunk_id = cache_entry["chunk_id"]
|
chunk_id = cache_entry["chunk_id"]
|
||||||
extraction_result = cache_entry["return"]
|
extraction_result = cache_entry["return"]
|
||||||
cached_results[chunk_id] = extraction_result
|
create_time = cache_entry.get(
|
||||||
|
"create_time", 0
|
||||||
|
) # Get creation time, default to 0
|
||||||
|
valid_entries += 1
|
||||||
|
|
||||||
logger.debug(
|
# Support multiple LLM caches per chunk
|
||||||
f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs"
|
if chunk_id not in cached_results:
|
||||||
|
cached_results[chunk_id] = []
|
||||||
|
# Store tuple with extraction result and creation time for sorting
|
||||||
|
cached_results[chunk_id].append((extraction_result, create_time))
|
||||||
|
|
||||||
|
# Sort extraction results by create_time for each chunk
|
||||||
|
for chunk_id in cached_results:
|
||||||
|
# Sort by create_time (x[1]), then extract only extraction_result (x[0])
|
||||||
|
cached_results[chunk_id].sort(key=lambda x: x[1])
|
||||||
|
cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
|
||||||
)
|
)
|
||||||
return cached_results
|
return cached_results
|
||||||
|
|
||||||
|
|
||||||
async def _parse_extraction_result(
|
async def _parse_extraction_result(
|
||||||
text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str
|
text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
"""Parse cached extraction result using the same logic as extract_entities
|
"""Parse cached extraction result using the same logic as extract_entities
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
text_chunks_storage: Text chunks storage to get chunk data
|
||||||
extraction_result: The cached LLM extraction result
|
extraction_result: The cached LLM extraction result
|
||||||
chunk_id: The chunk ID for source tracking
|
chunk_id: The chunk ID for source tracking
|
||||||
|
|
||||||
|
|
@ -387,8 +499,8 @@ async def _parse_extraction_result(
|
||||||
Tuple of (entities_dict, relationships_dict)
|
Tuple of (entities_dict, relationships_dict)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get chunk data for file_path
|
# Get chunk data for file_path from storage
|
||||||
chunk_data = await text_chunks.get_by_id(chunk_id)
|
chunk_data = await text_chunks_storage.get_by_id(chunk_id)
|
||||||
file_path = (
|
file_path = (
|
||||||
chunk_data.get("file_path", "unknown_source")
|
chunk_data.get("file_path", "unknown_source")
|
||||||
if chunk_data
|
if chunk_data
|
||||||
|
|
@ -761,8 +873,6 @@ async def _merge_nodes_then_upsert(
|
||||||
entity_name,
|
entity_name,
|
||||||
description,
|
description,
|
||||||
global_config,
|
global_config,
|
||||||
pipeline_status,
|
|
||||||
pipeline_status_lock,
|
|
||||||
llm_response_cache,
|
llm_response_cache,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -925,8 +1035,6 @@ async def _merge_edges_then_upsert(
|
||||||
f"({src_id}, {tgt_id})",
|
f"({src_id}, {tgt_id})",
|
||||||
description,
|
description,
|
||||||
global_config,
|
global_config,
|
||||||
pipeline_status,
|
|
||||||
pipeline_status_lock,
|
|
||||||
llm_response_cache,
|
llm_response_cache,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -1102,6 +1210,7 @@ async def extract_entities(
|
||||||
pipeline_status: dict = None,
|
pipeline_status: dict = None,
|
||||||
pipeline_status_lock=None,
|
pipeline_status_lock=None,
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
|
text_chunks_storage: BaseKVStorage | None = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
use_llm_func: callable = global_config["llm_model_func"]
|
use_llm_func: callable = global_config["llm_model_func"]
|
||||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||||
|
|
@ -1208,6 +1317,9 @@ async def extract_entities(
|
||||||
# Get file path from chunk data or use default
|
# Get file path from chunk data or use default
|
||||||
file_path = chunk_dp.get("file_path", "unknown_source")
|
file_path = chunk_dp.get("file_path", "unknown_source")
|
||||||
|
|
||||||
|
# Create cache keys collector for batch processing
|
||||||
|
cache_keys_collector = []
|
||||||
|
|
||||||
# Get initial extraction
|
# Get initial extraction
|
||||||
hint_prompt = entity_extract_prompt.format(
|
hint_prompt = entity_extract_prompt.format(
|
||||||
**{**context_base, "input_text": content}
|
**{**context_base, "input_text": content}
|
||||||
|
|
@ -1219,7 +1331,10 @@ async def extract_entities(
|
||||||
llm_response_cache=llm_response_cache,
|
llm_response_cache=llm_response_cache,
|
||||||
cache_type="extract",
|
cache_type="extract",
|
||||||
chunk_id=chunk_key,
|
chunk_id=chunk_key,
|
||||||
|
cache_keys_collector=cache_keys_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
|
||||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
||||||
|
|
||||||
# Process initial extraction with file path
|
# Process initial extraction with file path
|
||||||
|
|
@ -1236,6 +1351,7 @@ async def extract_entities(
|
||||||
history_messages=history,
|
history_messages=history,
|
||||||
cache_type="extract",
|
cache_type="extract",
|
||||||
chunk_id=chunk_key,
|
chunk_id=chunk_key,
|
||||||
|
cache_keys_collector=cache_keys_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
|
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
|
||||||
|
|
@ -1266,11 +1382,21 @@ async def extract_entities(
|
||||||
llm_response_cache=llm_response_cache,
|
llm_response_cache=llm_response_cache,
|
||||||
history_messages=history,
|
history_messages=history,
|
||||||
cache_type="extract",
|
cache_type="extract",
|
||||||
|
cache_keys_collector=cache_keys_collector,
|
||||||
)
|
)
|
||||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||||
if if_loop_result != "yes":
|
if if_loop_result != "yes":
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Batch update chunk's llm_cache_list with all collected cache keys
|
||||||
|
if cache_keys_collector and text_chunks_storage:
|
||||||
|
await update_chunk_cache_list(
|
||||||
|
chunk_key,
|
||||||
|
text_chunks_storage,
|
||||||
|
cache_keys_collector,
|
||||||
|
"entity_extraction",
|
||||||
|
)
|
||||||
|
|
||||||
processed_chunks += 1
|
processed_chunks += 1
|
||||||
entities_count = len(maybe_nodes)
|
entities_count = len(maybe_nodes)
|
||||||
relations_count = len(maybe_edges)
|
relations_count = len(maybe_edges)
|
||||||
|
|
@ -1343,7 +1469,7 @@ async def kg_query(
|
||||||
use_model_func = partial(use_model_func, _priority=5)
|
use_model_func = partial(use_model_func, _priority=5)
|
||||||
|
|
||||||
# Handle cache
|
# Handle cache
|
||||||
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
args_hash = compute_args_hash(query_param.mode, query)
|
||||||
cached_response, quantized, min_val, max_val = await handle_cache(
|
cached_response, quantized, min_val, max_val = await handle_cache(
|
||||||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||||||
)
|
)
|
||||||
|
|
@ -1390,7 +1516,7 @@ async def kg_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_param.only_need_context:
|
if query_param.only_need_context:
|
||||||
return context
|
return context if context is not None else PROMPTS["fail_response"]
|
||||||
if context is None:
|
if context is None:
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
|
|
@ -1502,7 +1628,7 @@ async def extract_keywords_only(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. Handle cache if needed - add cache type for keywords
|
# 1. Handle cache if needed - add cache type for keywords
|
||||||
args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
|
args_hash = compute_args_hash(param.mode, text)
|
||||||
cached_response, quantized, min_val, max_val = await handle_cache(
|
cached_response, quantized, min_val, max_val = await handle_cache(
|
||||||
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
|
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
|
||||||
)
|
)
|
||||||
|
|
@ -1647,7 +1773,7 @@ async def _get_vector_context(
|
||||||
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
|
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
|
f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not maybe_trun_chunks:
|
if not maybe_trun_chunks:
|
||||||
|
|
@ -1871,7 +1997,7 @@ async def _get_node_data(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
|
f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
|
||||||
)
|
)
|
||||||
|
|
||||||
# build prompt
|
# build prompt
|
||||||
|
|
@ -2180,7 +2306,7 @@ async def _get_edge_data(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
|
f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
|
||||||
)
|
)
|
||||||
|
|
||||||
relations_context = []
|
relations_context = []
|
||||||
|
|
@ -2369,7 +2495,7 @@ async def naive_query(
|
||||||
use_model_func = partial(use_model_func, _priority=5)
|
use_model_func = partial(use_model_func, _priority=5)
|
||||||
|
|
||||||
# Handle cache
|
# Handle cache
|
||||||
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
args_hash = compute_args_hash(query_param.mode, query)
|
||||||
cached_response, quantized, min_val, max_val = await handle_cache(
|
cached_response, quantized, min_val, max_val = await handle_cache(
|
||||||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||||||
)
|
)
|
||||||
|
|
@ -2485,7 +2611,7 @@ async def kg_query_with_keywords(
|
||||||
# Apply higher priority (5) to query relation LLM function
|
# Apply higher priority (5) to query relation LLM function
|
||||||
use_model_func = partial(use_model_func, _priority=5)
|
use_model_func = partial(use_model_func, _priority=5)
|
||||||
|
|
||||||
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
args_hash = compute_args_hash(query_param.mode, query)
|
||||||
cached_response, quantized, min_val, max_val = await handle_cache(
|
cached_response, quantized, min_val, max_val = await handle_cache(
|
||||||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from functools import wraps
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lightrag.prompt import PROMPTS
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from lightrag.constants import (
|
from lightrag.constants import (
|
||||||
DEFAULT_LOG_MAX_BYTES,
|
DEFAULT_LOG_MAX_BYTES,
|
||||||
|
|
@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
|
||||||
raise e from None
|
raise e from None
|
||||||
|
|
||||||
|
|
||||||
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
|
def compute_args_hash(*args: Any) -> str:
|
||||||
"""Compute a hash for the given arguments.
|
"""Compute a hash for the given arguments.
|
||||||
Args:
|
Args:
|
||||||
*args: Arguments to hash
|
*args: Arguments to hash
|
||||||
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Hash string
|
str: Hash string
|
||||||
"""
|
"""
|
||||||
|
|
@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
|
||||||
|
|
||||||
# Convert all arguments to strings and join them
|
# Convert all arguments to strings and join them
|
||||||
args_str = "".join([str(arg) for arg in args])
|
args_str = "".join([str(arg) for arg in args])
|
||||||
if cache_type:
|
|
||||||
args_str = f"{cache_type}:{args_str}"
|
|
||||||
|
|
||||||
# Compute MD5 hash
|
# Compute MD5 hash
|
||||||
return hashlib.md5(args_str.encode()).hexdigest()
|
return hashlib.md5(args_str.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
|
||||||
|
"""Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode: Cache mode (e.g., 'default', 'local', 'global')
|
||||||
|
cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
|
||||||
|
hash_value: Hash value from compute_args_hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Flattened cache key
|
||||||
|
"""
|
||||||
|
return f"{mode}:{cache_type}:{hash_value}"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
|
||||||
|
"""Parse a flattened cache key back into its components
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
|
||||||
|
"""
|
||||||
|
parts = cache_key.split(":", 2)
|
||||||
|
if len(parts) == 3:
|
||||||
|
return parts[0], parts[1], parts[2]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def compute_mdhash_id(content: str, prefix: str = "") -> str:
|
def compute_mdhash_id(content: str, prefix: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
Compute a unique ID for a given content string.
|
Compute a unique ID for a given content string.
|
||||||
|
|
@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
|
||||||
return combined_data
|
return combined_data
|
||||||
|
|
||||||
|
|
||||||
async def get_best_cached_response(
|
|
||||||
hashing_kv,
|
|
||||||
current_embedding,
|
|
||||||
similarity_threshold=0.95,
|
|
||||||
mode="default",
|
|
||||||
use_llm_check=False,
|
|
||||||
llm_func=None,
|
|
||||||
original_prompt=None,
|
|
||||||
cache_type=None,
|
|
||||||
) -> str | None:
|
|
||||||
logger.debug(
|
|
||||||
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
|
|
||||||
)
|
|
||||||
mode_cache = await hashing_kv.get_by_id(mode)
|
|
||||||
if not mode_cache:
|
|
||||||
return None
|
|
||||||
|
|
||||||
best_similarity = -1
|
|
||||||
best_response = None
|
|
||||||
best_prompt = None
|
|
||||||
best_cache_id = None
|
|
||||||
|
|
||||||
# Only iterate through cache entries for this mode
|
|
||||||
for cache_id, cache_data in mode_cache.items():
|
|
||||||
# Skip if cache_type doesn't match
|
|
||||||
if cache_type and cache_data.get("cache_type") != cache_type:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if cache data is valid
|
|
||||||
if cache_data["embedding"] is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Safely convert cached embedding
|
|
||||||
cached_quantized = np.frombuffer(
|
|
||||||
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
|
|
||||||
).reshape(cache_data["embedding_shape"])
|
|
||||||
|
|
||||||
# Ensure min_val and max_val are valid float values
|
|
||||||
embedding_min = cache_data.get("embedding_min")
|
|
||||||
embedding_max = cache_data.get("embedding_max")
|
|
||||||
|
|
||||||
if (
|
|
||||||
embedding_min is None
|
|
||||||
or embedding_max is None
|
|
||||||
or embedding_min >= embedding_max
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
cached_embedding = dequantize_embedding(
|
|
||||||
cached_quantized,
|
|
||||||
embedding_min,
|
|
||||||
embedding_max,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error processing cached embedding: {str(e)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
similarity = cosine_similarity(current_embedding, cached_embedding)
|
|
||||||
if similarity > best_similarity:
|
|
||||||
best_similarity = similarity
|
|
||||||
best_response = cache_data["return"]
|
|
||||||
best_prompt = cache_data["original_prompt"]
|
|
||||||
best_cache_id = cache_id
|
|
||||||
|
|
||||||
if best_similarity > similarity_threshold:
|
|
||||||
# If LLM check is enabled and all required parameters are provided
|
|
||||||
if (
|
|
||||||
use_llm_check
|
|
||||||
and llm_func
|
|
||||||
and original_prompt
|
|
||||||
and best_prompt
|
|
||||||
and best_response is not None
|
|
||||||
):
|
|
||||||
compare_prompt = PROMPTS["similarity_check"].format(
|
|
||||||
original_prompt=original_prompt, cached_prompt=best_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
llm_result = await llm_func(compare_prompt)
|
|
||||||
llm_result = llm_result.strip()
|
|
||||||
llm_similarity = float(llm_result)
|
|
||||||
|
|
||||||
# Replace vector similarity with LLM similarity score
|
|
||||||
best_similarity = llm_similarity
|
|
||||||
if best_similarity < similarity_threshold:
|
|
||||||
log_data = {
|
|
||||||
"event": "cache_rejected_by_llm",
|
|
||||||
"type": cache_type,
|
|
||||||
"mode": mode,
|
|
||||||
"original_question": original_prompt[:100] + "..."
|
|
||||||
if len(original_prompt) > 100
|
|
||||||
else original_prompt,
|
|
||||||
"cached_question": best_prompt[:100] + "..."
|
|
||||||
if len(best_prompt) > 100
|
|
||||||
else best_prompt,
|
|
||||||
"similarity_score": round(best_similarity, 4),
|
|
||||||
"threshold": similarity_threshold,
|
|
||||||
}
|
|
||||||
logger.debug(json.dumps(log_data, ensure_ascii=False))
|
|
||||||
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
|
|
||||||
return None
|
|
||||||
except Exception as e: # Catch all possible exceptions
|
|
||||||
logger.warning(f"LLM similarity check failed: {e}")
|
|
||||||
return None # Return None directly when LLM check fails
|
|
||||||
|
|
||||||
prompt_display = (
|
|
||||||
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
|
|
||||||
)
|
|
||||||
log_data = {
|
|
||||||
"event": "cache_hit",
|
|
||||||
"type": cache_type,
|
|
||||||
"mode": mode,
|
|
||||||
"similarity": round(best_similarity, 4),
|
|
||||||
"cache_id": best_cache_id,
|
|
||||||
"original_prompt": prompt_display,
|
|
||||||
}
|
|
||||||
logger.debug(json.dumps(log_data, ensure_ascii=False))
|
|
||||||
return best_response
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2):
|
||||||
"""Calculate cosine similarity between two vectors"""
|
"""Calculate cosine similarity between two vectors"""
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
|
|
@ -957,7 +857,7 @@ async def handle_cache(
|
||||||
mode="default",
|
mode="default",
|
||||||
cache_type=None,
|
cache_type=None,
|
||||||
):
|
):
|
||||||
"""Generic cache handling function"""
|
"""Generic cache handling function with flattened cache keys"""
|
||||||
if hashing_kv is None:
|
if hashing_kv is None:
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
|
|
@ -968,15 +868,14 @@ async def handle_cache(
|
||||||
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
|
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
if exists_func(hashing_kv, "get_by_mode_and_id"):
|
# Use flattened cache key format: {mode}:{cache_type}:{hash}
|
||||||
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
|
flattened_key = generate_cache_key(mode, cache_type, args_hash)
|
||||||
else:
|
cache_entry = await hashing_kv.get_by_id(flattened_key)
|
||||||
mode_cache = await hashing_kv.get_by_id(mode) or {}
|
if cache_entry:
|
||||||
if args_hash in mode_cache:
|
logger.debug(f"Flattened cache hit(key:{flattened_key})")
|
||||||
logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
|
return cache_entry["return"], None, None, None
|
||||||
return mode_cache[args_hash]["return"], None, None, None
|
|
||||||
|
|
||||||
logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
|
logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -994,7 +893,7 @@ class CacheData:
|
||||||
|
|
||||||
|
|
||||||
async def save_to_cache(hashing_kv, cache_data: CacheData):
|
async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||||
"""Save data to cache, with improved handling for streaming responses and duplicate content.
|
"""Save data to cache using flattened key structure.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hashing_kv: The key-value storage for caching
|
hashing_kv: The key-value storage for caching
|
||||||
|
|
@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||||
logger.debug("Streaming response detected, skipping cache")
|
logger.debug("Streaming response detected, skipping cache")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get existing cache data
|
# Use flattened cache key format: {mode}:{cache_type}:{hash}
|
||||||
if exists_func(hashing_kv, "get_by_mode_and_id"):
|
flattened_key = generate_cache_key(
|
||||||
mode_cache = (
|
cache_data.mode, cache_data.cache_type, cache_data.args_hash
|
||||||
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
|
)
|
||||||
or {}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
|
|
||||||
|
|
||||||
# Check if we already have identical content cached
|
# Check if we already have identical content cached
|
||||||
if cache_data.args_hash in mode_cache:
|
existing_cache = await hashing_kv.get_by_id(flattened_key)
|
||||||
existing_content = mode_cache[cache_data.args_hash].get("return")
|
if existing_cache:
|
||||||
|
existing_content = existing_cache.get("return")
|
||||||
if existing_content == cache_data.content:
|
if existing_content == cache_data.content:
|
||||||
logger.info(
|
logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
|
||||||
f"Cache content unchanged for {cache_data.args_hash}, skipping update"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Update cache with new content
|
# Create cache entry with flattened structure
|
||||||
mode_cache[cache_data.args_hash] = {
|
cache_entry = {
|
||||||
"return": cache_data.content,
|
"return": cache_data.content,
|
||||||
"cache_type": cache_data.cache_type,
|
"cache_type": cache_data.cache_type,
|
||||||
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
|
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
|
||||||
|
|
@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||||
"original_prompt": cache_data.prompt,
|
"original_prompt": cache_data.prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")
|
logger.info(f" == LLM cache == saving: {flattened_key}")
|
||||||
|
|
||||||
# Only upsert if there's actual new content
|
# Save using flattened key
|
||||||
await hashing_kv.upsert({cache_data.mode: mode_cache})
|
await hashing_kv.upsert({flattened_key: cache_entry})
|
||||||
|
|
||||||
|
|
||||||
def safe_unicode_decode(content):
|
def safe_unicode_decode(content):
|
||||||
|
|
@ -1529,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
|
||||||
return import_class
|
return import_class
|
||||||
|
|
||||||
|
|
||||||
|
async def update_chunk_cache_list(
|
||||||
|
chunk_id: str,
|
||||||
|
text_chunks_storage: "BaseKVStorage",
|
||||||
|
cache_keys: list[str],
|
||||||
|
cache_scenario: str = "batch_update",
|
||||||
|
) -> None:
|
||||||
|
"""Update chunk's llm_cache_list with the given cache keys
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_id: Chunk identifier
|
||||||
|
text_chunks_storage: Text chunks storage instance
|
||||||
|
cache_keys: List of cache keys to add to the list
|
||||||
|
cache_scenario: Description of the cache scenario for logging
|
||||||
|
"""
|
||||||
|
if not cache_keys:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk_data = await text_chunks_storage.get_by_id(chunk_id)
|
||||||
|
if chunk_data:
|
||||||
|
# Ensure llm_cache_list exists
|
||||||
|
if "llm_cache_list" not in chunk_data:
|
||||||
|
chunk_data["llm_cache_list"] = []
|
||||||
|
|
||||||
|
# Add cache keys to the list if not already present
|
||||||
|
existing_keys = set(chunk_data["llm_cache_list"])
|
||||||
|
new_keys = [key for key in cache_keys if key not in existing_keys]
|
||||||
|
|
||||||
|
if new_keys:
|
||||||
|
chunk_data["llm_cache_list"].extend(new_keys)
|
||||||
|
|
||||||
|
# Update the chunk in storage
|
||||||
|
await text_chunks_storage.upsert({chunk_id: chunk_data})
|
||||||
|
logger.debug(
|
||||||
|
f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def use_llm_func_with_cache(
|
async def use_llm_func_with_cache(
|
||||||
input_text: str,
|
input_text: str,
|
||||||
use_llm_func: callable,
|
use_llm_func: callable,
|
||||||
|
|
@ -1537,6 +1473,7 @@ async def use_llm_func_with_cache(
|
||||||
history_messages: list[dict[str, str]] = None,
|
history_messages: list[dict[str, str]] = None,
|
||||||
cache_type: str = "extract",
|
cache_type: str = "extract",
|
||||||
chunk_id: str | None = None,
|
chunk_id: str | None = None,
|
||||||
|
cache_keys_collector: list = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call LLM function with cache support
|
"""Call LLM function with cache support
|
||||||
|
|
||||||
|
|
@ -1551,6 +1488,8 @@ async def use_llm_func_with_cache(
|
||||||
history_messages: History messages list
|
history_messages: History messages list
|
||||||
cache_type: Type of cache
|
cache_type: Type of cache
|
||||||
chunk_id: Chunk identifier to store in cache
|
chunk_id: Chunk identifier to store in cache
|
||||||
|
text_chunks_storage: Text chunks storage to update llm_cache_list
|
||||||
|
cache_keys_collector: Optional list to collect cache keys for batch processing
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLM response text
|
LLM response text
|
||||||
|
|
@ -1563,6 +1502,9 @@ async def use_llm_func_with_cache(
|
||||||
_prompt = input_text
|
_prompt = input_text
|
||||||
|
|
||||||
arg_hash = compute_args_hash(_prompt)
|
arg_hash = compute_args_hash(_prompt)
|
||||||
|
# Generate cache key for this LLM call
|
||||||
|
cache_key = generate_cache_key("default", cache_type, arg_hash)
|
||||||
|
|
||||||
cached_return, _1, _2, _3 = await handle_cache(
|
cached_return, _1, _2, _3 = await handle_cache(
|
||||||
llm_response_cache,
|
llm_response_cache,
|
||||||
arg_hash,
|
arg_hash,
|
||||||
|
|
@ -1573,6 +1515,11 @@ async def use_llm_func_with_cache(
|
||||||
if cached_return:
|
if cached_return:
|
||||||
logger.debug(f"Found cache for {arg_hash}")
|
logger.debug(f"Found cache for {arg_hash}")
|
||||||
statistic_data["llm_cache"] += 1
|
statistic_data["llm_cache"] += 1
|
||||||
|
|
||||||
|
# Add cache key to collector if provided
|
||||||
|
if cache_keys_collector is not None:
|
||||||
|
cache_keys_collector.append(cache_key)
|
||||||
|
|
||||||
return cached_return
|
return cached_return
|
||||||
statistic_data["llm_call"] += 1
|
statistic_data["llm_call"] += 1
|
||||||
|
|
||||||
|
|
@ -1597,6 +1544,10 @@ async def use_llm_func_with_cache(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add cache key to collector if provided
|
||||||
|
if cache_keys_collector is not None:
|
||||||
|
cache_keys_collector.append(cache_key)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# When cache is disabled, directly call LLM
|
# When cache is disabled, directly call LLM
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Any, cast
|
||||||
|
|
||||||
from .base import DeletionResult
|
from .base import DeletionResult
|
||||||
from .kg.shared_storage import get_graph_db_lock
|
from .kg.shared_storage import get_graph_db_lock
|
||||||
from .prompt import GRAPH_FIELD_SEP
|
from .constants import GRAPH_FIELD_SEP
|
||||||
from .utils import compute_mdhash_id, logger
|
from .utils import compute_mdhash_id, logger
|
||||||
from .base import StorageNameSpace
|
from .base import StorageNameSpace
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,10 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
||||||
"Winner": "[Answer 1 or Answer 2]",
|
"Winner": "[Answer 1 or Answer 2]",
|
||||||
"Explanation": "[Provide explanation here]"
|
"Explanation": "[Provide explanation here]"
|
||||||
}},
|
}},
|
||||||
|
"Diversity": {{
|
||||||
|
"Winner": "[Answer 1 or Answer 2]",
|
||||||
|
"Explanation": "[Provide explanation here]"
|
||||||
|
}},
|
||||||
"Empowerment": {{
|
"Empowerment": {{
|
||||||
"Winner": "[Answer 1 or Answer 2]",
|
"Winner": "[Answer 1 or Answer 2]",
|
||||||
"Explanation": "[Provide explanation here]"
|
"Explanation": "[Provide explanation here]"
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
支持的图存储类型包括:
|
支持的图存储类型包括:
|
||||||
- NetworkXStorage
|
- NetworkXStorage
|
||||||
- Neo4JStorage
|
- Neo4JStorage
|
||||||
|
- MongoDBStorage
|
||||||
- PGGraphStorage
|
- PGGraphStorage
|
||||||
- MemgraphStorage
|
- MemgraphStorage
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue