From cb75e6631e9712c99794943b9266e44835bf233d Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 5 Aug 2025 17:58:34 +0800 Subject: [PATCH 1/8] Remove quantized embedding info from LLM cache - Delete quantize_embedding function - Delete dequantize_embedding function - Remove embedding fields from CacheData - Update save_to_cache to exclude embedding data - Clean up unused quantization-related code --- lightrag/operate.py | 12 ------------ lightrag/utils.py | 45 --------------------------------------------- 2 files changed, 57 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index ca21881b..254dfdac 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1833,9 +1833,6 @@ async def kg_query( args_hash=args_hash, content=response, prompt=query, - quantized=quantized, - min_val=min_val, - max_val=max_val, mode=query_param.mode, cache_type="query", ), @@ -1972,9 +1969,6 @@ async def extract_keywords_only( args_hash=args_hash, content=json.dumps(cache_data), prompt=text, - quantized=quantized, - min_val=min_val, - max_val=max_val, mode=param.mode, cache_type="keywords", ), @@ -3105,9 +3099,6 @@ async def naive_query( args_hash=args_hash, content=response, prompt=query, - quantized=quantized, - min_val=min_val, - max_val=max_val, mode=query_param.mode, cache_type="query", ), @@ -3231,9 +3222,6 @@ async def kg_query_with_keywords( args_hash=args_hash, content=response, prompt=query, - quantized=quantized, - min_val=min_val, - max_val=max_val, mode=query_param.mode, cache_type="query", ), diff --git a/lightrag/utils.py b/lightrag/utils.py index 354ca0a3..9e818d6b 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -756,40 +756,6 @@ def cosine_similarity(v1, v2): return dot_product / (norm1 * norm2) -def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple: - """Quantize embedding to specified bits""" - # Convert list to numpy array if needed - if isinstance(embedding, list): - embedding = np.array(embedding) - - # Calculate min/max values for reconstruction - min_val = embedding.min() - max_val = embedding.max() - - if min_val == max_val: - # handle constant vector - quantized = np.zeros_like(embedding, dtype=np.uint8) - return quantized, min_val, max_val - - # Quantize to 0-255 range - scale = (2**bits - 1) / (max_val - min_val) - quantized = np.round((embedding - min_val) * scale).astype(np.uint8) - - return quantized, min_val, max_val - - -def dequantize_embedding( - quantized: np.ndarray, min_val: float, max_val: float, bits=8 -) -> np.ndarray: - """Restore quantized embedding""" - if min_val == max_val: - # handle constant vector - return np.full_like(quantized, min_val, dtype=np.float32) - - scale = (max_val - min_val) / (2**bits - 1) - return (quantized * scale + min_val).astype(np.float32) - - async def handle_cache( hashing_kv, args_hash, @@ -824,9 +790,6 @@ class CacheData: args_hash: str content: str prompt: str - quantized: np.ndarray | None = None - min_val: float | None = None - max_val: float | None = None mode: str = "default" cache_type: str = "query" chunk_id: str | None = None @@ -866,14 +829,6 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): "return": cache_data.content, "cache_type": cache_data.cache_type, "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, - "embedding": cache_data.quantized.tobytes().hex() - if cache_data.quantized is not None - else None, - "embedding_shape": cache_data.quantized.shape - if cache_data.quantized is not None - else None, - "embedding_min": 
cache_data.min_val, - "embedding_max": cache_data.max_val, "original_prompt": cache_data.prompt, } From 0463963520dad23ea1df4b7ef1d2d1da2f70b6c1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 5 Aug 2025 18:03:10 +0800 Subject: [PATCH 2/8] fix: include all query parameters in LLM cache hash key generation - Add missing query parameters (top_k, enable_rerank, max_tokens, etc.) to cache key generation in kg_query, naive_query, and extract_keywords_only functions - Add queryparam field to CacheData structure and PostgreSQL storage for debugging - Update PostgreSQL schema with automatic migration for queryparam JSONB column - Prevent incorrect cache hits between queries with different parameters Fixes issue where different query parameters incorrectly shared the same cached results. --- lightrag/kg/postgres_impl.py | 51 +++++++++++++++++--- lightrag/operate.py | 92 ++++++++++++++++++++++++++++++++++-- lightrag/utils.py | 4 ++ 3 files changed, 135 insertions(+), 12 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 66f0dd6c..06079d71 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -225,14 +225,14 @@ class PostgreSQLDB: pass async def _migrate_llm_cache_add_columns(self): - """Add chunk_id and cache_type columns to LIGHTRAG_LLM_CACHE table if they don't exist""" + """Add chunk_id, cache_type, and queryparam columns to LIGHTRAG_LLM_CACHE table if they don't exist""" try: - # Check if both columns exist + # Check if all columns exist check_columns_sql = """ SELECT column_name FROM information_schema.columns WHERE table_name = 'lightrag_llm_cache' - AND column_name IN ('chunk_id', 'cache_type') + AND column_name IN ('chunk_id', 'cache_type', 'queryparam') """ existing_columns = await self.query(check_columns_sql, multirows=True) @@ -289,6 +289,22 @@ class PostgreSQLDB: "cache_type column already exists in LIGHTRAG_LLM_CACHE table" ) + # Add missing queryparam column + if "queryparam" not in existing_column_names: + logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table") + add_queryparam_sql = """ + ALTER TABLE LIGHTRAG_LLM_CACHE + ADD COLUMN queryparam JSONB NULL + """ + await self.execute(add_queryparam_sql) + logger.info( + "Successfully added queryparam column to LIGHTRAG_LLM_CACHE table" + ) + else: + logger.info( + "queryparam column already exists in LIGHTRAG_LLM_CACHE table" + ) + except Exception as e: logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}") @@ -1379,6 +1395,13 @@ class PGKVStorage(BaseKVStorage): ): create_time = response.get("create_time", 0) update_time = response.get("update_time", 0) + # Parse queryparam JSON string back to dict + queryparam = response.get("queryparam") + if isinstance(queryparam, str): + try: + queryparam = json.loads(queryparam) + except json.JSONDecodeError: + queryparam = None # Map field names and add cache_type for compatibility response = { **response, @@ -1387,6 +1410,7 @@ class PGKVStorage(BaseKVStorage): "original_prompt": response.get("original_prompt", ""), "chunk_id": response.get("chunk_id"), "mode": response.get("mode", "default"), + "queryparam": queryparam, "create_time": create_time, "update_time": create_time if update_time == 0 else update_time, } @@ -1455,6 +1479,13 @@ class PGKVStorage(BaseKVStorage): for row in results: create_time = row.get("create_time", 0) update_time = row.get("update_time", 0) + # Parse queryparam JSON string back to dict + queryparam = row.get("queryparam") + if isinstance(queryparam, str): + try: + queryparam = 
json.loads(queryparam) + except json.JSONDecodeError: + queryparam = None # Map field names and add cache_type for compatibility processed_row = { **row, @@ -1463,6 +1494,7 @@ class PGKVStorage(BaseKVStorage): "original_prompt": row.get("original_prompt", ""), "chunk_id": row.get("chunk_id"), "mode": row.get("mode", "default"), + "queryparam": queryparam, "create_time": create_time, "update_time": create_time if update_time == 0 else update_time, } @@ -1570,6 +1602,9 @@ class PGKVStorage(BaseKVStorage): "cache_type": v.get( "cache_type", "extract" ), # Get cache_type from data + "queryparam": json.dumps(v.get("queryparam")) + if v.get("queryparam") + else None, } await self.db.execute(upsert_sql, _data) @@ -4054,6 +4089,7 @@ TABLES = { return_value TEXT, chunk_id VARCHAR(255) NULL, cache_type VARCHAR(32), + queryparam JSONB NULL, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) @@ -4114,7 +4150,7 @@ SQL_TEMPLATES = { EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 """, - "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, + "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 @@ -4132,7 +4168,7 @@ SQL_TEMPLATES = { EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) """, - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) @@ -4163,14 +4199,15 @@ SQL_TEMPLATES = { ON CONFLICT (workspace,id) DO UPDATE SET content = $2, update_time = CURRENT_TIMESTAMP """, - "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type) - VALUES ($1, $2, $3, $4, $5, $6, $7) + "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type,queryparam) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (workspace,mode,id) DO UPDATE SET original_prompt = EXCLUDED.original_prompt, return_value=EXCLUDED.return_value, mode=EXCLUDED.mode, chunk_id=EXCLUDED.chunk_id, cache_type=EXCLUDED.cache_type, + queryparam=EXCLUDED.queryparam, update_time = CURRENT_TIMESTAMP """, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, diff --git a/lightrag/operate.py b/lightrag/operate.py index 254dfdac..1d398ad3 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1727,7 +1727,20 @@ async def kg_query( use_model_func = partial(use_model_func, _priority=5) # Handle cache - args_hash = compute_args_hash(query_param.mode, query) + args_hash = compute_args_hash( + query_param.mode, + query, + query_param.response_type, + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + query_param.hl_keywords or [], + 
query_param.ll_keywords or [], + query_param.user_prompt or "", + query_param.enable_rerank, + ) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) @@ -1826,7 +1839,20 @@ async def kg_query( ) if hashing_kv.global_config.get("enable_llm_cache"): - # Save to cache + # Save to cache with query parameters + queryparam_dict = { + "mode": query_param.mode, + "response_type": query_param.response_type, + "top_k": query_param.top_k, + "chunk_top_k": query_param.chunk_top_k, + "max_entity_tokens": query_param.max_entity_tokens, + "max_relation_tokens": query_param.max_relation_tokens, + "max_total_tokens": query_param.max_total_tokens, + "hl_keywords": query_param.hl_keywords or [], + "ll_keywords": query_param.ll_keywords or [], + "user_prompt": query_param.user_prompt or "", + "enable_rerank": query_param.enable_rerank, + } await save_to_cache( hashing_kv, CacheData( @@ -1835,6 +1861,7 @@ async def kg_query( prompt=query, mode=query_param.mode, cache_type="query", + queryparam=queryparam_dict, ), ) @@ -1886,7 +1913,20 @@ async def extract_keywords_only( """ # 1. Handle cache if needed - add cache type for keywords - args_hash = compute_args_hash(param.mode, text) + args_hash = compute_args_hash( + param.mode, + text, + param.response_type, + param.top_k, + param.chunk_top_k, + param.max_entity_tokens, + param.max_relation_tokens, + param.max_total_tokens, + param.hl_keywords or [], + param.ll_keywords or [], + param.user_prompt or "", + param.enable_rerank, + ) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, text, param.mode, cache_type="keywords" ) @@ -1963,6 +2003,20 @@ async def extract_keywords_only( "low_level_keywords": ll_keywords, } if hashing_kv.global_config.get("enable_llm_cache"): + # Save to cache with query parameters + queryparam_dict = { + "mode": param.mode, + "response_type": param.response_type, + "top_k": param.top_k, + "chunk_top_k": param.chunk_top_k, + "max_entity_tokens": param.max_entity_tokens, + "max_relation_tokens": param.max_relation_tokens, + "max_total_tokens": param.max_total_tokens, + "hl_keywords": param.hl_keywords or [], + "ll_keywords": param.ll_keywords or [], + "user_prompt": param.user_prompt or "", + "enable_rerank": param.enable_rerank, + } await save_to_cache( hashing_kv, CacheData( @@ -1971,6 +2025,7 @@ async def extract_keywords_only( prompt=text, mode=param.mode, cache_type="keywords", + queryparam=queryparam_dict, ), ) @@ -2945,7 +3000,20 @@ async def naive_query( use_model_func = partial(use_model_func, _priority=5) # Handle cache - args_hash = compute_args_hash(query_param.mode, query) + args_hash = compute_args_hash( + query_param.mode, + query, + query_param.response_type, + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + query_param.hl_keywords or [], + query_param.ll_keywords or [], + query_param.user_prompt or "", + query_param.enable_rerank, + ) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) @@ -3092,7 +3160,20 @@ async def naive_query( ) if hashing_kv.global_config.get("enable_llm_cache"): - # Save to cache + # Save to cache with query parameters + queryparam_dict = { + "mode": query_param.mode, + "response_type": query_param.response_type, + "top_k": query_param.top_k, + "chunk_top_k": query_param.chunk_top_k, + 
"max_entity_tokens": query_param.max_entity_tokens, + "max_relation_tokens": query_param.max_relation_tokens, + "max_total_tokens": query_param.max_total_tokens, + "hl_keywords": query_param.hl_keywords or [], + "ll_keywords": query_param.ll_keywords or [], + "user_prompt": query_param.user_prompt or "", + "enable_rerank": query_param.enable_rerank, + } await save_to_cache( hashing_kv, CacheData( @@ -3101,6 +3182,7 @@ async def naive_query( prompt=query, mode=query_param.mode, cache_type="query", + queryparam=queryparam_dict, ), ) diff --git a/lightrag/utils.py b/lightrag/utils.py index 9e818d6b..1eb8c98d 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -793,6 +793,7 @@ class CacheData: mode: str = "default" cache_type: str = "query" chunk_id: str | None = None + queryparam: dict | None = None async def save_to_cache(hashing_kv, cache_data: CacheData): @@ -830,6 +831,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): "cache_type": cache_data.cache_type, "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, "original_prompt": cache_data.prompt, + "queryparam": cache_data.queryparam + if cache_data.queryparam is not None + else None, } logger.info(f" == LLM cache == saving: {flattened_key}") From 0b5c708660572a3df561c06e67af911392335e44 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 5 Aug 2025 18:03:51 +0800 Subject: [PATCH 3/8] Update storage implementation documentation - Add detailed storage type descriptions - Remove Chroma from vector storage options - Include recommended PostgreSQL version - Add Memgraph to graph storage options - Update performance comparison notes --- README-zh.md | 49 +++++++++++++++++++++++++++++++++++++- README.md | 50 ++++++++++++++++++++++++++++++++++++++- lightrag/api/README-zh.md | 1 - lightrag/api/README.md | 1 - 4 files changed, 97 insertions(+), 4 deletions(-) diff --git a/README-zh.md b/README-zh.md index 2d08ff51..4c74f6e2 100644 --- a/README-zh.md +++ b/README-zh.md @@ -737,7 +737,54 @@ rag.insert(documents, file_paths=file_paths) ### 存储 -LightRAG使用到4种类型的存储,每一种存储都有多种实现方案。在初始化LightRAG的时候可以通过参数设定这四类存储的实现方案。详情请参看前面的LightRAG初始化参数。 +LightRAG 使用 4 种类型的存储用于不同目的: + +* KV_STORAGE:llm 响应缓存、文本块、文档信息 +* VECTOR_STORAGE:实体向量、关系向量、块向量 +* GRAPH_STORAGE:实体关系图 +* DOC_STATUS_STORAGE:文档索引状态 + +每种存储类型都有几种实现: + +* KV_STORAGE 支持的实现名称 + +``` +JsonKVStorage JsonFile(默认) +PGKVStorage Postgres +RedisKVStorage Redis +MongoKVStorage MogonDB +``` + +* GRAPH_STORAGE 支持的实现名称 + +``` +NetworkXStorage NetworkX(默认) +Neo4JStorage Neo4J +PGGraphStorage PostgreSQL with AGE plugin +``` + +> 在测试中Neo4j图形数据库相比PostgreSQL AGE有更好的性能表现。 + +* VECTOR_STORAGE 支持的实现名称 + +``` +NanoVectorDBStorage NanoVector(默认) +PGVectorStorage Postgres +MilvusVectorDBStorge Milvus +FaissVectorDBStorage Faiss +QdrantVectorDBStorage Qdrant +MongoVectorDBStorage MongoDB +``` + +* DOC_STATUS_STORAGE 支持的实现名称 + +``` +JsonDocStatusStorage JsonFile(默认) +PGDocStatusStorage Postgres +MongoDocStatusStorage MongoDB +``` + +每一种存储类型的链接配置范例可以在 `env.example` 文件中找到。链接字符串中的数据库实例是需要你预先在数据库服务器上创建好的,LightRAG 仅负责在数据库实例中创建数据表,不负责创建数据库实例。如果使用 Redis 作为存储,记得给 Redis 配置自动持久化数据规则,否则 Redis 服务重启后数据会丢失。如果使用PostgreSQL数据库,推荐使用16.6版本或以上。
使用Neo4J进行存储 diff --git a/README.md b/README.md index bf319763..ce606554 100644 --- a/README.md +++ b/README.md @@ -747,7 +747,55 @@ rag.insert(documents, file_paths=file_paths) ### Storage -LightRAG uses four types of storage, each of which has multiple implementation options. When initializing LightRAG, the implementation schemes for these four types of storage can be set through parameters. For details, please refer to the previous LightRAG initialization parameters. +LightRAG uses 4 types of storage for different purposes: + +* KV_STORAGE: llm response cache, text chunks, document information +* VECTOR_STORAGE: entities vectors, relation vectors, chunks vectors +* GRAPH_STORAGE: entity relation graph +* DOC_STATUS_STORAGE: document indexing status + +Each storage type has several implementations: + +* KV_STORAGE supported implementations: + +``` +JsonKVStorage JsonFile (default) +PGKVStorage Postgres +RedisKVStorage Redis +MongoKVStorage MongoDB +``` + +* GRAPH_STORAGE supported implementations: + +``` +NetworkXStorage NetworkX (default) +Neo4JStorage Neo4J +PGGraphStorage PostgreSQL with AGE plugin +MemgraphStorage. Memgraph +``` + +> Testing has shown that Neo4J delivers superior performance in production environments compared to PostgreSQL with AGE plugin. + +* VECTOR_STORAGE supported implementations: + +``` +NanoVectorDBStorage NanoVector (default) +PGVectorStorage Postgres +MilvusVectorDBStorage Milvus +FaissVectorDBStorage Faiss +QdrantVectorDBStorage Qdrant +MongoVectorDBStorage MongoDB +``` + +* DOC_STATUS_STORAGE: supported implementations: + +``` +JsonDocStatusStorage JsonFile (default) +PGDocStatusStorage Postgres +MongoDocStatusStorage MongoDB +``` + +Example connection configurations for each storage type can be found in the `env.example` file. The database instance in the connection string needs to be created by you on the database server beforehand. LightRAG is only responsible for creating tables within the database instance, not for creating the database instance itself. If using Redis as storage, remember to configure automatic data persistence rules for Redis, otherwise data will be lost after the Redis service restarts. If using PostgreSQL, it is recommended to use version 16.6 or above.
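+
+For example, the implementation for each storage type can be selected through the LightRAG initialization parameters (an illustrative sketch; the parameter names follow the LightRAG initialization parameters described earlier, and connection settings are taken from the environment variables shown in `env.example`):
+
+```python
+rag = LightRAG(
+    working_dir="./rag_storage",
+    kv_storage="PGKVStorage",                 # KV_STORAGE implementation
+    vector_storage="PGVectorStorage",         # VECTOR_STORAGE implementation
+    graph_storage="Neo4JStorage",             # GRAPH_STORAGE implementation
+    doc_status_storage="PGDocStatusStorage",  # DOC_STATUS_STORAGE implementation
+    embedding_func=embedding_func,            # your embedding function
+    llm_model_func=llm_model_func,            # your LLM model function
+)
+```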
Using Neo4J for Storage diff --git a/lightrag/api/README-zh.md b/lightrag/api/README-zh.md index b80419e1..6fe5f86c 100644 --- a/lightrag/api/README-zh.md +++ b/lightrag/api/README-zh.md @@ -409,7 +409,6 @@ PGGraphStorage PostgreSQL with AGE plugin NanoVectorDBStorage NanoVector(默认) PGVectorStorage Postgres MilvusVectorDBStorge Milvus -ChromaVectorDBStorage Chroma FaissVectorDBStorage Faiss QdrantVectorDBStorage Qdrant MongoVectorDBStorage MongoDB diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 6e0a59de..ce27baff 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -412,7 +412,6 @@ MemgraphStorage. Memgraph NanoVectorDBStorage NanoVector (default) PGVectorStorage Postgres MilvusVectorDBStorage Milvus -ChromaVectorDBStorage Chroma FaissVectorDBStorage Faiss QdrantVectorDBStorage Qdrant MongoVectorDBStorage MongoDB From 8294d6d1b775d6adefe41eff44a090a6cf8eca4f Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 5 Aug 2025 23:18:54 +0800 Subject: [PATCH 4/8] Remove deprecated mode field from LLM cache schema - Drop mode column from LLM cache table - Update primary key to exclude mode - Remove mode from all SQL queries - Deprecate mode-related methods - Update schema migration logic --- lightrag/kg/postgres_impl.py | 99 ++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 06079d71..3cc0ac9a 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -224,15 +224,15 @@ class PostgreSQLDB: ): pass - async def _migrate_llm_cache_add_columns(self): - """Add chunk_id, cache_type, and queryparam columns to LIGHTRAG_LLM_CACHE table if they don't exist""" + async def _migrate_llm_cache_schema(self): + """Migrate LLM cache schema: add new columns and remove deprecated mode field""" try: # Check if all columns exist check_columns_sql = """ SELECT column_name FROM information_schema.columns WHERE table_name = 'lightrag_llm_cache' - AND column_name IN ('chunk_id', 'cache_type', 'queryparam') + AND column_name IN ('chunk_id', 'cache_type', 'queryparam', 'mode') """ existing_columns = await self.query(check_columns_sql, multirows=True) @@ -305,8 +305,38 @@ class PostgreSQLDB: "queryparam column already exists in LIGHTRAG_LLM_CACHE table" ) + # Remove deprecated mode field if it exists + if "mode" in existing_column_names: + logger.info("Removing deprecated mode column from LIGHTRAG_LLM_CACHE table") + + # First, drop the primary key constraint that includes mode + drop_pk_sql = """ + ALTER TABLE LIGHTRAG_LLM_CACHE + DROP CONSTRAINT IF EXISTS LIGHTRAG_LLM_CACHE_PK + """ + await self.execute(drop_pk_sql) + logger.info("Dropped old primary key constraint") + + # Drop the mode column + drop_mode_sql = """ + ALTER TABLE LIGHTRAG_LLM_CACHE + DROP COLUMN mode + """ + await self.execute(drop_mode_sql) + logger.info("Successfully removed mode column from LIGHTRAG_LLM_CACHE table") + + # Create new primary key constraint without mode + add_pk_sql = """ + ALTER TABLE LIGHTRAG_LLM_CACHE + ADD CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id) + """ + await self.execute(add_pk_sql) + logger.info("Created new primary key constraint (workspace, id)") + else: + logger.info("mode column does not exist in LIGHTRAG_LLM_CACHE table") + except Exception as e: - logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}") + logger.warning(f"Failed to migrate LLM cache schema: {e}") async def _migrate_timestamp_columns(self): """Migrate timestamp 
columns in tables to witimezone-free types, assuming original data is in UTC time""" @@ -872,11 +902,11 @@ class PostgreSQLDB: logger.error(f"PostgreSQL, Failed to migrate timestamp columns: {e}") # Don't throw an exception, allow the initialization process to continue - # Migrate LLM cache table to add chunk_id and cache_type columns if needed + # Migrate LLM cache schema: add new columns and remove deprecated mode field try: - await self._migrate_llm_cache_add_columns() + await self._migrate_llm_cache_schema() except Exception as e: - logger.error(f"PostgreSQL, Failed to migrate LLM cache columns: {e}") + logger.error(f"PostgreSQL, Failed to migrate LLM cache schema: {e}") # Don't throw an exception, allow the initialization process to continue # Finally, attempt to migrate old doc chunks data if needed @@ -1402,14 +1432,13 @@ class PGKVStorage(BaseKVStorage): queryparam = json.loads(queryparam) except json.JSONDecodeError: queryparam = None - # Map field names and add cache_type for compatibility + # Map field names for compatibility (mode field removed) response = { **response, "return": response.get("return_value", ""), "cache_type": response.get("cache_type"), "original_prompt": response.get("original_prompt", ""), "chunk_id": response.get("chunk_id"), - "mode": response.get("mode", "default"), "queryparam": queryparam, "create_time": create_time, "update_time": create_time if update_time == 0 else update_time, @@ -1486,14 +1515,13 @@ class PGKVStorage(BaseKVStorage): queryparam = json.loads(queryparam) except json.JSONDecodeError: queryparam = None - # Map field names and add cache_type for compatibility + # Map field names for compatibility (mode field removed) processed_row = { **row, "return": row.get("return_value", ""), "cache_type": row.get("cache_type"), "original_prompt": row.get("original_prompt", ""), "chunk_id": row.get("chunk_id"), - "mode": row.get("mode", "default"), "queryparam": queryparam, "create_time": create_time, "update_time": create_time if update_time == 0 else update_time, @@ -1597,7 +1625,6 @@ class PGKVStorage(BaseKVStorage): "id": k, # Use flattened key as id "original_prompt": v["original_prompt"], "return_value": v["return"], - "mode": v.get("mode", "default"), # Get mode from data "chunk_id": v.get("chunk_id"), "cache_type": v.get( "cache_type", "extract" @@ -1671,37 +1698,16 @@ class PGKVStorage(BaseKVStorage): logger.error(f"Error while deleting records from {self.namespace}: {e}") async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode + """Delete specific records from storage by cache mode (deprecated - mode field removed) Args: - modes (list[str]): List of cache modes to be dropped from storage + modes (list[str]): List of cache modes (deprecated, no longer used) Returns: - bool: True if successful, False otherwise + bool: False (method deprecated due to mode field removal) """ - if not modes: - return False - - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return False - - if table_name != "LIGHTRAG_LLM_CACHE": - return False - - sql = f""" - DELETE FROM {table_name} - WHERE workspace = $1 AND mode = ANY($2) - """ - params = {"workspace": self.db.workspace, "modes": modes} - - logger.info(f"Deleting cache by modes: {modes}") - await self.db.execute(sql, params) - return True - except Exception as e: - logger.error(f"Error deleting cache by modes {modes}: {e}") - return False + logger.warning("drop_cache_by_modes is deprecated: mode 
field has been removed from LLM cache") + return False async def drop(self) -> dict[str, str]: """Drop the storage""" @@ -4084,7 +4090,6 @@ TABLES = { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( workspace varchar(255) NOT NULL, id varchar(255) NOT NULL, - mode varchar(32) NOT NULL, original_prompt TEXT, return_value TEXT, chunk_id VARCHAR(255) NULL, @@ -4092,7 +4097,7 @@ TABLES = { queryparam JSONB NULL, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) + CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id) )""" }, "LIGHTRAG_DOC_STATUS": { @@ -4150,14 +4155,11 @@ SQL_TEMPLATES = { EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 """, - "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam, + "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 """, - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 - """, "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) """, @@ -4168,7 +4170,7 @@ SQL_TEMPLATES = { EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) """, - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam, + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) @@ -4199,12 +4201,11 @@ SQL_TEMPLATES = { ON CONFLICT (workspace,id) DO UPDATE SET content = $2, update_time = CURRENT_TIMESTAMP """, - "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type,queryparam) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT (workspace,mode,id) DO UPDATE + "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,chunk_id,cache_type,queryparam) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (workspace,id) DO UPDATE SET original_prompt = EXCLUDED.original_prompt, return_value=EXCLUDED.return_value, - mode=EXCLUDED.mode, chunk_id=EXCLUDED.chunk_id, cache_type=EXCLUDED.cache_type, queryparam=EXCLUDED.queryparam, From cc1f7118e7ba9f015e0d5a9a4d74b7a75c6b3f3a Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 5 Aug 2025 23:20:26 +0800 Subject: [PATCH 5/8] Remove deprecated cache_by_modes functionality from all storage --- lightrag/base.py | 19 ------ lightrag/kg/deprecated/tidb_impl.py | 35 ----------- lightrag/kg/json_kv_impl.py | 90 ----------------------------- lightrag/kg/mongo_impl.py | 22 ------- lightrag/kg/postgres_impl.py | 26 +++------ lightrag/kg/redis_impl.py | 60 ------------------- 6 files changed, 9 insertions(+), 243 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 9e88fb12..0e651f7b 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ 
-331,21 +331,6 @@ class BaseKVStorage(StorageNameSpace, ABC): None """ - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode - - Importance notes for in-memory storage: - 1. Changes will be persisted to disk during the next index_done_callback - 2. update flags to notify other processes that data persistence is needed - - Args: - modes (list[str]): List of cache modes to be dropped from storage - - Returns: - True: if the cache drop successfully - False: if the cache drop failed, or the cache mode is not supported - """ - @dataclass class BaseGraphStorage(StorageNameSpace, ABC): @@ -761,10 +746,6 @@ class DocStatusStorage(BaseKVStorage, ABC): Dictionary mapping status names to counts """ - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Drop cache is not supported for Doc Status storage""" - return False - class StoragesStatus(str, Enum): """Storages status""" diff --git a/lightrag/kg/deprecated/tidb_impl.py b/lightrag/kg/deprecated/tidb_impl.py index d60bb1f6..0d5dfca3 100644 --- a/lightrag/kg/deprecated/tidb_impl.py +++ b/lightrag/kg/deprecated/tidb_impl.py @@ -347,41 +347,6 @@ class TiDBKVStorage(BaseKVStorage): except Exception as e: logger.error(f"Error deleting records from {self.namespace}: {e}") - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode - - Args: - modes (list[str]): List of cache modes to be dropped from storage - - Returns: - bool: True if successful, False otherwise - """ - if not modes: - return False - - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return False - - if table_name != "LIGHTRAG_LLM_CACHE": - return False - - # Build MySQL style IN query - modes_list = ", ".join([f"'{mode}'" for mode in modes]) - sql = f""" - DELETE FROM {table_name} - WHERE workspace = :workspace - AND mode IN ({modes_list}) - """ - - logger.info(f"Deleting cache by modes: {modes}") - await self.db.execute(sql, {"workspace": self.db.workspace}) - return True - except Exception as e: - logger.error(f"Error deleting cache by modes {modes}: {e}") - return False - async def drop(self) -> dict[str, str]: """Drop the storage""" try: diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 70a265fe..d6d80079 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -195,96 +195,6 @@ class JsonKVStorage(BaseKVStorage): if any_deleted: await set_all_update_flags(self.final_namespace) - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode - - Importance notes for in-memory storage: - 1. Changes will be persisted to disk during the next index_done_callback - 2. 
update flags to notify other processes that data persistence is needed - - Args: - modes (list[str]): List of cache modes to be dropped from storage - - Returns: - True: if the cache drop successfully - False: if the cache drop failed - """ - if not modes: - return False - - try: - async with self._storage_lock: - keys_to_delete = [] - modes_set = set(modes) # Convert to set for efficient lookup - - for key in list(self._data.keys()): - # Parse flattened cache key: mode:cache_type:hash - parts = key.split(":", 2) - if len(parts) == 3 and parts[0] in modes_set: - keys_to_delete.append(key) - - # Batch delete - for key in keys_to_delete: - self._data.pop(key, None) - - if keys_to_delete: - await set_all_update_flags(self.final_namespace) - logger.info( - f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}" - ) - - return True - except Exception as e: - logger.error(f"Error dropping cache by modes: {e}") - return False - - # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool: - # """Delete specific cache records from storage by chunk IDs - - # Importance notes for in-memory storage: - # 1. Changes will be persisted to disk during the next index_done_callback - # 2. update flags to notify other processes that data persistence is needed - - # Args: - # chunk_ids (list[str]): List of chunk IDs to be dropped from storage - - # Returns: - # True: if the cache drop successfully - # False: if the cache drop failed - # """ - # if not chunk_ids: - # return False - - # try: - # async with self._storage_lock: - # # Iterate through all cache modes to find entries with matching chunk_ids - # for mode_key, mode_data in list(self._data.items()): - # if isinstance(mode_data, dict): - # # Check each cached entry in this mode - # for cache_key, cache_entry in list(mode_data.items()): - # if ( - # isinstance(cache_entry, dict) - # and cache_entry.get("chunk_id") in chunk_ids - # ): - # # Remove this cache entry - # del mode_data[cache_key] - # logger.debug( - # f"Removed cache entry {cache_key} for chunk {cache_entry.get('chunk_id')}" - # ) - - # # If the mode is now empty, remove it entirely - # if not mode_data: - # del self._data[mode_key] - - # # Set update flags to notify persistence is needed - # await set_all_update_flags(self.final_namespace) - - # logger.info(f"Cleared cache for {len(chunk_ids)} chunk IDs") - # return True - # except Exception as e: - # logger.error(f"Error clearing cache by chunk IDs: {e}") - # return False - async def drop(self) -> dict[str, str]: """Drop all data from storage and clean up resources This action will persistent the data to disk immediately. 
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 9e2847f2..64622127 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -232,28 +232,6 @@ class MongoKVStorage(BaseKVStorage): except PyMongoError as e: logger.error(f"Error deleting documents from {self.namespace}: {e}") - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode - - Args: - modes (list[str]): List of cache modes to be dropped from storage - - Returns: - bool: True if successful, False otherwise - """ - if not modes: - return False - - try: - # Build regex pattern to match flattened key format: mode:cache_type:hash - pattern = f"^({'|'.join(modes)}):" - result = await self._data.delete_many({"_id": {"$regex": pattern}}) - logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}") - return True - except Exception as e: - logger.error(f"Error deleting cache by modes {modes}: {e}") - return False - async def drop(self) -> dict[str, str]: """Drop the storage by removing all documents in the collection. diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 3cc0ac9a..804454f7 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -307,8 +307,10 @@ class PostgreSQLDB: # Remove deprecated mode field if it exists if "mode" in existing_column_names: - logger.info("Removing deprecated mode column from LIGHTRAG_LLM_CACHE table") - + logger.info( + "Removing deprecated mode column from LIGHTRAG_LLM_CACHE table" + ) + # First, drop the primary key constraint that includes mode drop_pk_sql = """ ALTER TABLE LIGHTRAG_LLM_CACHE @@ -316,15 +318,17 @@ class PostgreSQLDB: """ await self.execute(drop_pk_sql) logger.info("Dropped old primary key constraint") - + # Drop the mode column drop_mode_sql = """ ALTER TABLE LIGHTRAG_LLM_CACHE DROP COLUMN mode """ await self.execute(drop_mode_sql) - logger.info("Successfully removed mode column from LIGHTRAG_LLM_CACHE table") - + logger.info( + "Successfully removed mode column from LIGHTRAG_LLM_CACHE table" + ) + # Create new primary key constraint without mode add_pk_sql = """ ALTER TABLE LIGHTRAG_LLM_CACHE @@ -1697,18 +1701,6 @@ class PGKVStorage(BaseKVStorage): except Exception as e: logger.error(f"Error while deleting records from {self.namespace}: {e}") - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode (deprecated - mode field removed) - - Args: - modes (list[str]): List of cache modes (deprecated, no longer used) - - Returns: - bool: False (method deprecated due to mode field removal) - """ - logger.warning("drop_cache_by_modes is deprecated: mode field has been removed from LLM cache") - return False - async def drop(self) -> dict[str, str]: """Drop the storage""" try: diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index ae18242f..1c8d3c68 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -397,66 +397,6 @@ class RedisKVStorage(BaseKVStorage): f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" ) - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode - - Importance notes for Redis storage: - 1. 
This will immediately delete the specified cache modes from Redis - - Args: - modes (list[str]): List of cache modes to be dropped from storage - - Returns: - True: if the cache drop successfully - False: if the cache drop failed - """ - if not modes: - return False - - try: - async with self._get_redis_connection() as redis: - keys_to_delete = [] - - # Find matching keys for each mode using SCAN - for mode in modes: - # Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash} - pattern = f"{self.namespace}:{mode}:*" - cursor = 0 - mode_keys = [] - - while True: - cursor, keys = await redis.scan( - cursor, match=pattern, count=1000 - ) - if keys: - mode_keys.extend(keys) - - if cursor == 0: - break - - keys_to_delete.extend(mode_keys) - logger.info( - f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'" - ) - - if keys_to_delete: - # Batch delete - pipe = redis.pipeline() - for key in keys_to_delete: - pipe.delete(key) - results = await pipe.execute() - deleted_count = sum(results) - logger.info( - f"Dropped {deleted_count} cache entries for modes: {modes}" - ) - else: - logger.warning(f"No cache entries found for modes: {modes}") - - return True - except Exception as e: - logger.error(f"Error dropping cache by modes in Redis: {e}") - return False - async def drop(self) -> dict[str, str]: """Drop the storage by removing all keys under the current namespace. From c22315ea6daecd87f6840e2b623e925a45b89702 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 5 Aug 2025 23:51:51 +0800 Subject: [PATCH 6/8] refactor: remove selective LLM cache clearing functionality - Remove optional 'modes' parameter from aclear_cache() and clear_cache() methods - Replace deprecated drop_cache_by_modes() with drop() method for complete cache clearing - Update API endpoint to ignore mode-specific parameters and clear all cache - Simplify frontend clearCache() function to send empty request body This change ensures all LLM cache is cleared together. --- lightrag/api/routers/document_routes.py | 48 ++++++------------------- lightrag/lightrag.py | 43 ++++++---------------- lightrag_webui/src/api/lightrag.ts | 4 +-- 3 files changed, 23 insertions(+), 72 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 6cc33fa5..84eaed4d 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -258,19 +258,12 @@ class ClearDocumentsResponse(BaseModel): class ClearCacheRequest(BaseModel): """Request model for clearing cache - Attributes: - modes: Optional list of cache modes to clear + This model is kept for API compatibility but no longer accepts any parameters. + All cache will be cleared regardless of the request content. """ - modes: Optional[ - List[Literal["default", "naive", "local", "global", "hybrid", "mix"]] - ] = Field( - default=None, - description="Modes of cache to clear. If None, clears all cache.", - ) - class Config: - json_schema_extra = {"example": {"modes": ["default", "naive"]}} + json_schema_extra = {"example": {}} class ClearCacheResponse(BaseModel): @@ -1820,47 +1813,28 @@ def create_document_routes( ) async def clear_cache(request: ClearCacheRequest): """ - Clear cache data from the LLM response cache storage. + Clear all cache data from the LLM response cache storage. - This endpoint allows clearing specific modes of cache or all cache if no modes are specified. - Valid modes include: "default", "naive", "local", "global", "hybrid", "mix". 
- - "default" represents extraction cache. - - Other modes correspond to different query modes. + This endpoint clears all cached LLM responses regardless of mode. + The request body is accepted for API compatibility but is ignored. Args: - request (ClearCacheRequest): The request body containing optional modes to clear. + request (ClearCacheRequest): The request body (ignored for compatibility). Returns: ClearCacheResponse: A response object containing the status and message. Raises: - HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors). + HTTPException: If an error occurs during cache clearing (500). """ try: - # Validate modes if provided - valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"] - if request.modes and not all(mode in valid_modes for mode in request.modes): - invalid_modes = [ - mode for mode in request.modes if mode not in valid_modes - ] - raise HTTPException( - status_code=400, - detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}", - ) - - # Call the aclear_cache method - await rag.aclear_cache(request.modes) + # Call the aclear_cache method (no modes parameter) + await rag.aclear_cache() # Prepare success message - if request.modes: - message = f"Successfully cleared cache for modes: {request.modes}" - else: - message = "Successfully cleared all cache" + message = "Successfully cleared all cache" return ClearCacheResponse(status="success", message=message) - except HTTPException: - # Re-raise HTTP exceptions - raise except Exception as e: logger.error(f"Error clearing cache: {str(e)}") logger.error(traceback.format_exc()) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5fae4226..467265f0 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1915,58 +1915,35 @@ class LightRAG: async def _query_done(self): await self.llm_response_cache.index_done_callback() - async def aclear_cache(self, modes: list[str] | None = None) -> None: - """Clear cache data from the LLM response cache storage. + async def aclear_cache(self) -> None: + """Clear all cache data from the LLM response cache storage. - Args: - modes (list[str] | None): Modes of cache to clear. Options: ["default", "naive", "local", "global", "hybrid", "mix"]. - "default" represents extraction cache. - If None, clears all cache. + This method clears all cached LLM responses regardless of mode. Example: # Clear all cache await rag.aclear_cache() - - # Clear local mode cache - await rag.aclear_cache(modes=["local"]) - - # Clear extraction cache - await rag.aclear_cache(modes=["default"]) """ if not self.llm_response_cache: logger.warning("No cache storage configured") return - valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"] - - # Validate input - if modes and not all(mode in valid_modes for mode in modes): - raise ValueError(f"Invalid mode. 
Valid modes are: {valid_modes}") - try: - # Reset the cache storage for specified mode - if modes: - success = await self.llm_response_cache.drop_cache_by_modes(modes) - if success: - logger.info(f"Cleared cache for modes: {modes}") - else: - logger.warning(f"Failed to clear cache for modes: {modes}") + # Clear all cache using drop method + success = await self.llm_response_cache.drop() + if success: + logger.info("Cleared all cache") else: - # Clear all modes - success = await self.llm_response_cache.drop_cache_by_modes(valid_modes) - if success: - logger.info("Cleared all cache") - else: - logger.warning("Failed to clear all cache") + logger.warning("Failed to clear all cache") await self.llm_response_cache.index_done_callback() except Exception as e: logger.error(f"Error while clearing cache: {e}") - def clear_cache(self, modes: list[str] | None = None) -> None: + def clear_cache(self) -> None: """Synchronous version of aclear_cache.""" - return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes)) + return always_get_an_event_loop().run_until_complete(self.aclear_cache()) async def get_docs_by_status( self, status: DocStatus diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index b2bf1bf5..98fc59ca 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -586,11 +586,11 @@ export const clearDocuments = async (): Promise => { return response.data } -export const clearCache = async (modes?: string[]): Promise<{ +export const clearCache = async (): Promise<{ status: 'success' | 'fail' message: string }> => { - const response = await axiosInstance.post('/documents/clear_cache', { modes }) + const response = await axiosInstance.post('/documents/clear_cache', {}) return response.data } From a04c11a59828e4655d669d20194851e7ed34f136 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 6 Aug 2025 00:02:50 +0800 Subject: [PATCH 7/8] Remove deprecated storage --- lightrag/kg/__init__.py | 13 - lightrag/kg/deprecated/age_impl.py | 867 ----------------- lightrag/kg/deprecated/gremlin_impl.py | 686 ------------- lightrag/kg/deprecated/tidb_impl.py | 1230 ------------------------ 4 files changed, 2796 deletions(-) delete mode 100644 lightrag/kg/deprecated/age_impl.py delete mode 100644 lightrag/kg/deprecated/gremlin_impl.py delete mode 100644 lightrag/kg/deprecated/tidb_impl.py diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index b2a93e82..8d42441a 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -5,7 +5,6 @@ STORAGE_IMPLEMENTATIONS = { "RedisKVStorage", "PGKVStorage", "MongoKVStorage", - # "TiDBKVStorage", ], "required_methods": ["get_by_id", "upsert"], }, @@ -16,9 +15,6 @@ STORAGE_IMPLEMENTATIONS = { "PGGraphStorage", "MongoGraphStorage", "MemgraphStorage", - # "AGEStorage", - # "TiDBGraphStorage", - # "GremlinStorage", ], "required_methods": ["upsert_node", "upsert_edge"], }, @@ -31,7 +27,6 @@ STORAGE_IMPLEMENTATIONS = { "QdrantVectorDBStorage", "MongoVectorDBStorage", # "ChromaVectorDBStorage", - # "TiDBVectorDBStorage", ], "required_methods": ["query", "upsert"], }, @@ -52,20 +47,17 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "JsonKVStorage": [], "MongoKVStorage": [], "RedisKVStorage": ["REDIS_URI"], - # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], # Graph Storage Implementations "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], 
"MongoGraphStorage": [], "MemgraphStorage": ["MEMGRAPH_URI"], - # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "AGEStorage": [ "AGE_POSTGRES_DB", "AGE_POSTGRES_USER", "AGE_POSTGRES_PASSWORD", ], - # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], "PGGraphStorage": [ "POSTGRES_USER", "POSTGRES_PASSWORD", @@ -75,7 +67,6 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "NanoVectorDBStorage": [], "MilvusVectorDBStorage": [], "ChromaVectorDBStorage": [], - # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "FaissVectorDBStorage": [], "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None @@ -102,14 +93,10 @@ STORAGES = { "RedisKVStorage": ".kg.redis_impl", "RedisDocStatusStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", - # "TiDBKVStorage": ".kg.tidb_impl", - # "TiDBVectorDBStorage": ".kg.tidb_impl", - # "TiDBGraphStorage": ".kg.tidb_impl", "PGKVStorage": ".kg.postgres_impl", "PGVectorStorage": ".kg.postgres_impl", "AGEStorage": ".kg.age_impl", "PGGraphStorage": ".kg.postgres_impl", - # "GremlinStorage": ".kg.gremlin_impl", "PGDocStatusStorage": ".kg.postgres_impl", "FaissVectorDBStorage": ".kg.faiss_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl", diff --git a/lightrag/kg/deprecated/age_impl.py b/lightrag/kg/deprecated/age_impl.py deleted file mode 100644 index 097b7b0b..00000000 --- a/lightrag/kg/deprecated/age_impl.py +++ /dev/null @@ -1,867 +0,0 @@ -import asyncio -import inspect -import json -import os -import sys -from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Union, final -import pipmaster as pm -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge - -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from lightrag.utils import logger - -from ..base import BaseGraphStorage - -if sys.platform.startswith("win"): - import asyncio.windows_events - - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - -if not pm.is_installed("psycopg-pool"): - pm.install("psycopg-pool") - pm.install("psycopg[binary,pool]") - -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - -import psycopg # type: ignore -from psycopg.rows import namedtuple_row # type: ignore -from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore - - -class AGEQueryException(Exception): - """Exception for the AGE queries.""" - - def __init__(self, exception: Union[str, Dict]) -> None: - if isinstance(exception, dict): - self.message = exception["message"] if "message" in exception else "unknown" - self.details = exception["details"] if "details" in exception else "unknown" - else: - self.message = exception - self.details = "unknown" - - def get_message(self) -> str: - return self.message - - def get_details(self) -> Any: - return self.details - - -@final -@dataclass -class AGEStorage(BaseGraphStorage): - @staticmethod - def load_nx_graph(file_name): - print("no preloading of graph with AGE in production") - - def __init__(self, namespace, global_config, embedding_func): - super().__init__( - namespace=namespace, - global_config=global_config, - embedding_func=embedding_func, - ) - self._driver = None - self._driver_lock = asyncio.Lock() - DB = os.environ["AGE_POSTGRES_DB"].replace("\\", "\\\\").replace("'", "\\'") - 
USER = os.environ["AGE_POSTGRES_USER"].replace("\\", "\\\\").replace("'", "\\'") - PASSWORD = ( - os.environ["AGE_POSTGRES_PASSWORD"] - .replace("\\", "\\\\") - .replace("'", "\\'") - ) - HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'") - PORT = os.environ.get("AGE_POSTGRES_PORT", "8529") - self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") - - connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" - - self._driver = AsyncConnectionPool(connection_string, open=False) - - return None - - async def close(self): - if self._driver: - await self._driver.close() - self._driver = None - - async def __aexit__(self, exc_type, exc, tb): - if self._driver: - await self._driver.close() - - @staticmethod - def _record_to_dict(record: NamedTuple) -> Dict[str, Any]: - """ - Convert a record returned from an age query to a dictionary - - Args: - record (): a record from an age query result - - Returns: - Dict[str, Any]: a dictionary representation of the record where - the dictionary key is the field name and the value is the - value converted to a python type - """ - # result holder - d = {} - - # prebuild a mapping of vertex_id to vertex mappings to be used - # later to build edges - vertices = {} - for k in record._fields: - v = getattr(record, k) - # agtype comes back '{key: value}::type' which must be parsed - if isinstance(v, str) and "::" in v: - dtype = v.split("::")[-1] - v = v.split("::")[0] - if dtype == "vertex": - vertex = json.loads(v) - vertices[vertex["id"]] = vertex.get("properties") - - # iterate returned fields and parse appropriately - for k in record._fields: - v = getattr(record, k) - if isinstance(v, str) and "::" in v: - dtype = v.split("::")[-1] - v = v.split("::")[0] - else: - dtype = "" - - if dtype == "vertex": - vertex = json.loads(v) - field = json.loads(v).get("properties") - if not field: - field = {} - field["label"] = AGEStorage._decode_graph_label(vertex["label"]) - d[k] = field - # convert edge from id-label->id by replacing id with node information - # we only do this if the vertex was also returned in the query - # this is an attempt to be consistent with neo4j implementation - elif dtype == "edge": - edge = json.loads(v) - d[k] = ( - vertices.get(edge["start_id"], {}), - edge[ - "label" - ], # we don't use decode_graph_label(), since edge label is always "DIRECTED" - vertices.get(edge["end_id"], {}), - ) - else: - d[k] = json.loads(v) if isinstance(v, str) else v - - return d - - @staticmethod - def _format_properties( - properties: Dict[str, Any], _id: Union[str, None] = None - ) -> str: - """ - Convert a dictionary of properties to a string representation that - can be used in a cypher query insert/merge statement. 
- - Args: - properties (Dict[str,str]): a dictionary containing node/edge properties - id (Union[str, None]): the id of the node or None if none exists - - Returns: - str: the properties dictionary as a properly formatted string - """ - props = [] - # wrap property key in backticks to escape - for k, v in properties.items(): - prop = f"`{k}`: {json.dumps(v)}" - props.append(prop) - if _id is not None and "id" not in properties: - props.append( - f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" - ) - return "{" + ", ".join(props) + "}" - - @staticmethod - def _encode_graph_label(label: str) -> str: - """ - Since AGE suports only alphanumerical labels, we will encode generic label as HEX string - - Args: - label (str): the original label - - Returns: - str: the encoded label - """ - return "x" + label.encode().hex() - - @staticmethod - def _decode_graph_label(encoded_label: str) -> str: - """ - Since AGE suports only alphanumerical labels, we will encode generic label as HEX string - - Args: - encoded_label (str): the encoded label - - Returns: - str: the decoded label - """ - return bytes.fromhex(encoded_label.removeprefix("x")).decode() - - @staticmethod - def _get_col_name(field: str, idx: int) -> str: - """ - Convert a cypher return field to a pgsql select field - If possible keep the cypher column name, but create a generic name if necessary - - Args: - field (str): a return field from a cypher query to be formatted for pgsql - idx (int): the position of the field in the return statement - - Returns: - str: the field to be used in the pgsql select statement - """ - # remove white space - field = field.strip() - # if an alias is provided for the field, use it - if " as " in field: - return field.split(" as ")[-1].strip() - # if the return value is an unnamed primitive, give it a generic name - if field.isnumeric() or field in ("true", "false", "null"): - return f"column_{idx}" - # otherwise return the value stripping out some common special chars - return field.replace("(", "_").replace(")", "") - - @staticmethod - def _wrap_query(query: str, graph_name: str, **params: str) -> str: - """ - Convert a cypher query to an Apache Age compatible - sql query by wrapping the cypher query in ag_catalog.cypher, - casting results to agtype and building a select statement - - Args: - query (str): a valid cypher query - graph_name (str): the name of the graph to query - params (dict): parameters for the query - - Returns: - str: an equivalent pgsql query - """ - - # pgsql template - template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$ - {query} - $$) AS ({fields});""" - - # if there are any returned fields they must be added to the pgsql query - if "return" in query.lower(): - # parse return statement to identify returned fields - fields = ( - query.lower() - .split("return")[-1] - .split("distinct")[-1] - .split("order by")[0] - .split("skip")[0] - .split("limit")[0] - .split(",") - ) - - # raise exception if RETURN * is found as we can't resolve the fields - if "*" in [x.strip() for x in fields]: - raise ValueError( - "AGE graph does not support 'RETURN *'" - + " statements in Cypher queries" - ) - - # get pgsql formatted field names - fields = [ - AGEStorage._get_col_name(field, idx) for idx, field in enumerate(fields) - ] - - # build resulting pgsql relation - fields_str = ", ".join( - [field.split(".")[-1] + " agtype" for field in fields] - ) - - # if no return statement we still need to return a single field of type agtype - else: - fields_str = "a 
agtype" - - select_str = "*" - - return template.format( - graph_name=graph_name, - query=query.format(**params), - fields=fields_str, - projection=select_str, - ) - - async def _query(self, query: str, **params: str) -> List[Dict[str, Any]]: - """ - Query the graph by taking a cypher query, converting it to an - age compatible query, executing it and converting the result - - Args: - query (str): a cypher query to be executed - params (dict): parameters for the query - - Returns: - List[Dict[str, Any]]: a list of dictionaries containing the result set - """ - # convert cypher query to pgsql/age query - wrapped_query = self._wrap_query(query, self.graph_name, **params) - - await self._driver.open() - - # create graph if it doesn't exist - async with self._get_pool_connection() as conn: - async with conn.cursor() as curs: - try: - await curs.execute('SET search_path = ag_catalog, "$user", public') - await curs.execute(f"SELECT create_graph('{self.graph_name}')") - await conn.commit() - except ( - psycopg.errors.InvalidSchemaName, - psycopg.errors.UniqueViolation, - ): - await conn.rollback() - - # execute the query, rolling back on an error - async with self._get_pool_connection() as conn: - async with conn.cursor(row_factory=namedtuple_row) as curs: - try: - await curs.execute('SET search_path = ag_catalog, "$user", public') - await curs.execute(wrapped_query) - await conn.commit() - except psycopg.Error as e: - await conn.rollback() - raise AGEQueryException( - { - "message": f"Error executing graph query: {query.format(**params)}", - "detail": str(e), - } - ) from e - - data = await curs.fetchall() - if data is None: - result = [] - # decode records - else: - result = [AGEStorage._record_to_dict(d) for d in data] - - return result - - async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') - - query = """ - MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - single_result = (await self._query(query, **params))[0] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - single_result["node_exists"], - ) - - return single_result["node_exists"] - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') - - query = """ - MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) - RETURN COUNT(r) > 0 AS edge_exists - """ - params = { - "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), - } - single_result = (await self._query(query, **params))[0] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - single_result["edge_exists"], - ) - return single_result["edge_exists"] - - async def get_node(self, node_id: str) -> dict[str, str] | None: - entity_name_label = node_id.strip('"') - query = """ - MATCH (n:`{label}`) RETURN n - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - record = await self._query(query, **params) - if record: - node = record[0] - node_dict = node["n"] - logger.debug( - "{%s}: query: {%s}, result: {%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - node_dict, - ) - return node_dict - return None - - async def node_degree(self, node_id: str) -> int: - entity_name_label = 
node_id.strip('"') - - query = """ - MATCH (n:`{label}`)-[]->(x) - RETURN count(x) AS total_edge_count - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - record = (await self._query(query, **params))[0] - if record: - edge_count = int(record["total_edge_count"]) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - edge_count, - ) - return edge_count - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - entity_name_label_source = src_id.strip('"') - entity_name_label_target = tgt_id.strip('"') - src_degree = await self.node_degree(entity_name_label_source) - trg_degree = await self.node_degree(entity_name_label_target) - - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - logger.debug( - "{%s}:query:src_Degree+trg_degree:result:{%s}", - inspect.currentframe().f_code.co_name, - degrees, - ) - return degrees - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') - - query = """ - MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) - RETURN properties(r) as edge_properties - LIMIT 1 - """ - params = { - "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), - } - record = await self._query(query, **params) - if record and record[0] and record[0]["edge_properties"]: - result = record[0]["edge_properties"] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - result, - ) - return result - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - """ - Retrieves all edges (relationships) for a particular node identified by its label. - :return: List of dictionaries containing edge information - """ - node_label = source_node_id.strip('"') - - query = """ - MATCH (n:`{label}`) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected - """ - params = {"label": AGEStorage._encode_graph_label(node_label)} - results = await self._query(query, **params) - edges = [] - for record in results: - source_node = record["n"] if record["n"] else None - connected_node = record["connected"] if record["connected"] else None - - source_label = ( - source_node["label"] if source_node and source_node["label"] else None - ) - target_label = ( - connected_node["label"] - if connected_node and connected_node["label"] - else None - ) - - if source_label and target_label: - edges.append((source_label, target_label)) - - return edges - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((AGEQueryException,)), - ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - """ - Upsert a node in the AGE database. 
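# Illustrative only: a round-trip of the hex label encoding used by the queries above,
# assuming labels are arbitrary UTF-8 strings (sketch, not part of the original patch).
def encode_label_sketch(label: str) -> str:
    return "x" + label.encode().hex()

def decode_label_sketch(encoded: str) -> str:
    return bytes.fromhex(encoded.removeprefix("x")).decode()

# encode_label_sketch("Person")        -> "x506572736f6e"
# decode_label_sketch("x506572736f6e") -> "Person"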
- - Args: - node_id: The unique identifier for the node (used as label) - node_data: Dictionary of node properties - """ - label = node_id.strip('"') - properties = node_data - - query = """ - MERGE (n:`{label}`) - SET n += {properties} - """ - params = { - "label": AGEStorage._encode_graph_label(label), - "properties": AGEStorage._format_properties(properties), - } - try: - await self._query(query, **params) - logger.debug( - "Upserted node with label '{%s}' and properties: {%s}", - label, - properties, - ) - except Exception as e: - logger.error("Error during upsert: {%s}", e) - raise - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((AGEQueryException,)), - ) - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - """ - Upsert an edge and its properties between two nodes identified by their labels. - - Args: - source_node_id (str): Label of the source node (used as identifier) - target_node_id (str): Label of the target node (used as identifier) - edge_data (dict): Dictionary of properties to set on the edge - """ - source_node_label = source_node_id.strip('"') - target_node_label = target_node_id.strip('"') - edge_properties = edge_data - - query = """ - MATCH (source:`{src_label}`) - WITH source - MATCH (target:`{tgt_label}`) - MERGE (source)-[r:DIRECTED]->(target) - SET r += {properties} - RETURN r - """ - params = { - "src_label": AGEStorage._encode_graph_label(source_node_label), - "tgt_label": AGEStorage._encode_graph_label(target_node_label), - "properties": AGEStorage._format_properties(edge_properties), - } - try: - await self._query(query, **params) - logger.debug( - "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", - source_node_label, - target_node_label, - edge_properties, - ) - except Exception as e: - logger.error("Error during edge upsert: {%s}", e) - raise - - @asynccontextmanager - async def _get_pool_connection(self, timeout: Optional[float] = None): - """Workaround for a psycopg_pool bug""" - - try: - connection = await self._driver.getconn(timeout=timeout) - except PoolTimeout: - await self._driver._add_connection(None) # workaround... 
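# Illustrative only: roughly the SQL that the cypher-wrapping described earlier is
# expected to emit for the has_node query (graph name and hex label are example values).
#
#   SELECT * FROM ag_catalog.cypher('lightrag', $$
#       MATCH (n:`x506572736f6e`) RETURN count(n) > 0 AS node_exists
#   $$) AS (node_exists agtype);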
- connection = await self._driver.getconn(timeout=timeout) - - try: - async with connection: - yield connection - finally: - await self._driver.putconn(connection) - - async def delete_node(self, node_id: str) -> None: - """Delete a node with the specified label - - Args: - node_id: The label of the node to delete - """ - entity_name_label = node_id.strip('"') - - query = """ - MATCH (n:`{label}`) - DETACH DELETE n - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - try: - await self._query(query, **params) - logger.debug(f"Deleted node with label '{entity_name_label}'") - except Exception as e: - logger.error(f"Error during node deletion: {str(e)}") - raise - - async def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node labels to be deleted - """ - for node in nodes: - await self.delete_node(node) - - async def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to be deleted, each edge is a (source, target) tuple - """ - for source, target in edges: - entity_name_label_source = source.strip('"') - entity_name_label_target = target.strip('"') - - query = """ - MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`) - DELETE r - """ - params = { - "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), - } - try: - await self._query(query, **params) - logger.debug( - f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'" - ) - except Exception as e: - logger.error(f"Error during edge deletion: {str(e)}") - raise - - async def get_all_labels(self) -> list[str]: - """Get all node labels in the database - - Returns: - ["label1", "label2", ...] # Alphabetically sorted label list - """ - query = """ - MATCH (n) - RETURN DISTINCT labels(n) AS node_labels - """ - results = await self._query(query) - - all_labels = [] - for record in results: - if record and "node_labels" in record: - for label in record["node_labels"]: - if label: - # Decode label - decoded_label = AGEStorage._decode_graph_label(label) - all_labels.append(decoded_label) - - # Remove duplicates and sort - return sorted(list(set(all_labels))) - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: - """ - Retrieve a connected subgraph of nodes where the label includes the specified 'node_label'. - Maximum number of nodes is constrained by the environment variable 'MAX_GRAPH_NODES' (default: 1000). - When reducing the number of nodes, the prioritization criteria are as follows: - 1. Label matching nodes take precedence (nodes containing the specified label string) - 2. Followed by nodes directly connected to the matching nodes - 3. Finally, the degree of the nodes - - Args: - node_label: String to match in node labels (will match any node containing this string in its label) - max_depth: Maximum depth of the graph. Defaults to 5. 
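# Illustrative only: a sketch of the node-count prioritisation described above, assuming
# each candidate is a (node_id, matches_label, is_neighbor_of_match, degree) tuple
# (the tuple layout and helper name are hypothetical, not part of the original patch).
def prioritise_nodes_sketch(candidates, limit):
    # label matches first, then direct neighbours of matches, then higher degree
    ranked = sorted(candidates, key=lambda c: (not c[1], not c[2], -c[3]))
    return ranked[:limit]

# prioritise_nodes_sketch([("a", False, True, 3), ("b", True, False, 1)], 1)
# -> [("b", True, False, 1)]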
- - Returns: - KnowledgeGraph: Complete connected subgraph for specified node - """ - max_graph_nodes = int(os.getenv("MAX_GRAPH_NODES", 1000)) - result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - - # Handle special case for "*" label - if node_label == "*": - # Query all nodes and sort by degree - query = """ - MATCH (n) - OPTIONAL MATCH (n)-[r]-() - WITH n, count(r) AS degree - ORDER BY degree DESC - LIMIT {max_nodes} - RETURN n, degree - """ - params = {"max_nodes": max_graph_nodes} - nodes_result = await self._query(query, **params) - - # Add nodes to result - node_ids = [] - for record in nodes_result: - if "n" in record: - node = record["n"] - node_id = str(node.get("id", "")) - if node_id not in seen_nodes: - node_properties = {k: v for k, v in node.items()} - node_label = node.get("label", "") - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_label], - properties=node_properties, - ) - ) - seen_nodes.add(node_id) - node_ids.append(node_id) - - # Query edges between these nodes - if node_ids: - edges_query = """ - MATCH (a)-[r]->(b) - WHERE a.id IN {node_ids} AND b.id IN {node_ids} - RETURN a, r, b - """ - edges_params = {"node_ids": node_ids} - edges_result = await self._query(edges_query, **edges_params) - - # Add edges to result - for record in edges_result: - if "r" in record and "a" in record and "b" in record: - source = record["a"].get("id", "") - target = record["b"].get("id", "") - edge_id = f"{source}-{target}" - if edge_id not in seen_edges: - edge_properties = {k: v for k, v in record["r"].items()} - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=source, - target=target, - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - else: - # For specific label, use partial matching - entity_name_label = node_label.strip('"') - encoded_label = AGEStorage._encode_graph_label(entity_name_label) - - # Find matching start nodes - start_query = """ - MATCH (n:`{label}`) - RETURN n - """ - start_params = {"label": encoded_label} - start_nodes = await self._query(start_query, **start_params) - - if not start_nodes: - logger.warning(f"No nodes found with label '{entity_name_label}'!") - return result - - # Traverse graph from each start node - for start_node_record in start_nodes: - if "n" in start_node_record: - # Use BFS to traverse graph - query = """ - MATCH (start:`{label}`) - CALL { - MATCH path = (start)-[*0..{max_depth}]->(n) - RETURN nodes(path) AS path_nodes, relationships(path) AS path_rels - } - RETURN DISTINCT path_nodes, path_rels - """ - params = {"label": encoded_label, "max_depth": max_depth} - results = await self._query(query, **params) - - # Extract nodes and edges from results - for record in results: - if "path_nodes" in record: - # Process nodes - for node in record["path_nodes"]: - node_id = str(node.get("id", "")) - if ( - node_id not in seen_nodes - and len(seen_nodes) < max_graph_nodes - ): - node_properties = {k: v for k, v in node.items()} - node_label = node.get("label", "") - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_label], - properties=node_properties, - ) - ) - seen_nodes.add(node_id) - - if "path_rels" in record: - # Process edges - for rel in record["path_rels"]: - source = str(rel.get("start_id", "")) - target = str(rel.get("end_id", "")) - edge_id = f"{source}-{target}" - if edge_id not in seen_edges: - edge_properties = {k: v for k, v in rel.items()} - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - 
type=rel.get("label", "DIRECTED"), - source=source, - target=target, - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - return result - - async def index_done_callback(self) -> None: - # AGES handles persistence automatically - pass - - async def drop(self) -> dict[str, str]: - """Drop the storage by removing all nodes and relationships in the graph. - - Returns: - dict[str, str]: Status of the operation with keys 'status' and 'message' - """ - try: - query = """ - MATCH (n) - DETACH DELETE n - """ - await self._query(query) - logger.info(f"Successfully dropped all data from graph {self.graph_name}") - return {"status": "success", "message": "graph data dropped"} - except Exception as e: - logger.error(f"Error dropping graph {self.graph_name}: {e}") - return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/deprecated/gremlin_impl.py b/lightrag/kg/deprecated/gremlin_impl.py deleted file mode 100644 index 32dbcc4e..00000000 --- a/lightrag/kg/deprecated/gremlin_impl.py +++ /dev/null @@ -1,686 +0,0 @@ -import asyncio -import inspect -import json -import os -import pipmaster as pm -from dataclasses import dataclass -from typing import Any, Dict, List, final - -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from lightrag.utils import logger - -from ..base import BaseGraphStorage - -if not pm.is_installed("gremlinpython"): - pm.install("gremlinpython") - -from gremlin_python.driver import client, serializer # type: ignore -from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore -from gremlin_python.driver.protocol import GremlinServerError # type: ignore - - -@final -@dataclass -class GremlinStorage(BaseGraphStorage): - @staticmethod - def load_nx_graph(file_name): - print("no preloading of graph with Gremlin in production") - - def __init__(self, namespace, global_config, embedding_func): - super().__init__( - namespace=namespace, - global_config=global_config, - embedding_func=embedding_func, - ) - - self._driver = None - self._driver_lock = asyncio.Lock() - - USER = os.environ.get("GREMLIN_USER", "") - PASSWORD = os.environ.get("GREMLIN_PASSWORD", "") - HOST = os.environ["GREMLIN_HOST"] - PORT = int(os.environ["GREMLIN_PORT"]) - - # TraversalSource, a custom one has to be created manually, - # default it "g" - SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g") - - # All vertices will have graph={GRAPH} property, so that we can - # have several logical graphs for one source - GRAPH = GremlinStorage._to_value_map( - os.environ.get("GREMLIN_GRAPH", "LightRAG") - ) - - self.graph_name = GRAPH - - self._driver = client.Client( - f"ws://{HOST}:{PORT}/gremlin", - SOURCE, - username=USER, - password=PASSWORD, - message_serializer=serializer.GraphSONSerializersV3d0(), - transport_factory=lambda: AiohttpTransport(call_from_event_loop=True), - ) - - async def close(self): - if self._driver: - self._driver.close() - self._driver = None - - async def __aexit__(self, exc_type, exc, tb): - if self._driver: - self._driver.close() - - async def index_done_callback(self) -> None: - # Gremlin handles persistence automatically - pass - - @staticmethod - def _to_value_map(value: Any) -> str: - """Dump supported Python object as Gremlin valueMap""" - json_str = json.dumps(value, 
ensure_ascii=False, sort_keys=False) - parsed_str = json_str.replace("'", r"\'") - - # walk over the string and replace curly brackets with square brackets - # outside of strings, as well as replace double quotes with single quotes - # and "deescape" double quotes inside of strings - outside_str = True - escaped = False - remove_indices = [] - for i, c in enumerate(parsed_str): - if escaped: - # previous character was an "odd" backslash - escaped = False - if c == '"': - # we want to "deescape" double quotes: store indices to delete - remove_indices.insert(0, i - 1) - elif c == "\\": - escaped = True - elif c == '"': - outside_str = not outside_str - parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :] - elif c == "{" and outside_str: - parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :] - elif c == "}" and outside_str: - parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :] - for idx in remove_indices: - parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :] - return parsed_str - - @staticmethod - def _convert_properties(properties: Dict[str, Any]) -> str: - """Create chained .property() commands from properties dict""" - props = [] - for k, v in properties.items(): - prop_name = GremlinStorage._to_value_map(k) - props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})") - return "".join(props) - - @staticmethod - def _fix_name(name: str) -> str: - """Strip double quotes and format as a proper field name""" - name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'")) - - return name - - async def _query(self, query: str) -> List[Dict[str, Any]]: - """ - Query the Gremlin graph - - Args: - query (str): a query to be executed - - Returns: - List[Dict[str, Any]]: a list of dictionaries containing the result set - """ - - result = list(await asyncio.wrap_future(self._driver.submit_async(query))) - if result: - result = result[0] - - return result - - async def has_node(self, node_id: str) -> bool: - entity_name = GremlinStorage._fix_name(node_id) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .limit(1) - .count() - .project('has_node') - .by(__.choose(__.is(gt(0)), constant(true), constant(false))) - """ - result = await self._query(query) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - result[0]["has_node"], - ) - - return result[0]["has_node"] - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_source = GremlinStorage._fix_name(source_node_id) - entity_name_target = GremlinStorage._fix_name(target_node_id) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_source}) - .outE() - .inV().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_target}) - .limit(1) - .count() - .project('has_edge') - .by(__.choose(__.is(gt(0)), constant(true), constant(false))) - """ - result = await self._query(query) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - result[0]["has_edge"], - ) - - return result[0]["has_edge"] - - async def get_node(self, node_id: str) -> dict[str, str] | None: - entity_name = GremlinStorage._fix_name(node_id) - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .limit(1) - .project('properties') - .by(elementMap()) - """ - result = await self._query(query) - if result: - node = result[0] - node_dict = node["properties"] - logger.debug( - "{%s}: query: {%s}, 
result: {%s}", - inspect.currentframe().f_code.co_name, - query.format, - node_dict, - ) - return node_dict - - async def node_degree(self, node_id: str) -> int: - entity_name = GremlinStorage._fix_name(node_id) - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .outE() - .inV().has('graph', {self.graph_name}) - .count() - .project('total_edge_count') - .by() - """ - result = await self._query(query) - edge_count = result[0]["total_edge_count"] - - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - edge_count, - ) - - return edge_count - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - src_degree = await self.node_degree(src_id) - trg_degree = await self.node_degree(tgt_id) - - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - logger.debug( - "{%s}:query:src_Degree+trg_degree:result:{%s}", - inspect.currentframe().f_code.co_name, - degrees, - ) - return degrees - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - entity_name_source = GremlinStorage._fix_name(source_node_id) - entity_name_target = GremlinStorage._fix_name(target_node_id) - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_source}) - .outE() - .inV().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_target}) - .limit(1) - .project('edge_properties') - .by(__.bothE().elementMap()) - """ - result = await self._query(query) - if result: - edge_properties = result[0]["edge_properties"] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - edge_properties, - ) - return edge_properties - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - node_name = GremlinStorage._fix_name(source_node_id) - query = f"""g - .E() - .filter( - __.or( - __.outV().has('graph', {self.graph_name}) - .has('entity_name', {node_name}), - __.inV().has('graph', {self.graph_name}) - .has('entity_name', {node_name}) - ) - ) - .project('source_name', 'target_name') - .by(__.outV().values('entity_name')) - .by(__.inV().values('entity_name')) - """ - result = await self._query(query) - edges = [(res["source_name"], res["target_name"]) for res in result] - - return edges - - @retry( - stop=stop_after_attempt(10), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((GremlinServerError,)), - ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - """ - Upsert a node in the Gremlin graph. 
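# Illustrative only: roughly what the property conversion described earlier is expected
# to produce for a small dict (keys and values here are example data).
#
#   _convert_properties({"entity_type": "person", "weight": 2})
#     -> ".property('entity_type', 'person').property('weight', 2)"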
- - Args: - node_id: The unique identifier for the node (used as name) - node_data: Dictionary of node properties - """ - name = GremlinStorage._fix_name(node_id) - properties = GremlinStorage._convert_properties(node_data) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {name}) - .fold() - .coalesce( - __.unfold(), - __.addV('ENTITY') - .property('graph', {self.graph_name}) - .property('entity_name', {name}) - ) - {properties} - """ - - try: - await self._query(query) - logger.debug( - "Upserted node with name {%s} and properties: {%s}", - name, - properties, - ) - except Exception as e: - logger.error("Error during upsert: {%s}", e) - raise - - @retry( - stop=stop_after_attempt(10), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((GremlinServerError,)), - ) - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - """ - Upsert an edge and its properties between two nodes identified by their names. - - Args: - source_node_id (str): Name of the source node (used as identifier) - target_node_id (str): Name of the target node (used as identifier) - edge_data (dict): Dictionary of properties to set on the edge - """ - source_node_name = GremlinStorage._fix_name(source_node_id) - target_node_name = GremlinStorage._fix_name(target_node_id) - edge_properties = GremlinStorage._convert_properties(edge_data) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {source_node_name}).as('source') - .V().has('graph', {self.graph_name}) - .has('entity_name', {target_node_name}).as('target') - .coalesce( - __.select('source').outE('DIRECTED').where(__.inV().as('target')), - __.select('source').addE('DIRECTED').to(__.select('target')) - ) - .property('graph', {self.graph_name}) - {edge_properties} - """ - try: - await self._query(query) - logger.debug( - "Upserted edge from {%s} to {%s} with properties: {%s}", - source_node_name, - target_node_name, - edge_properties, - ) - except Exception as e: - logger.error("Error during edge upsert: {%s}", e) - raise - - async def delete_node(self, node_id: str) -> None: - """Delete a node with the specified entity_name - - Args: - node_id: The entity_name of the node to delete - """ - entity_name = GremlinStorage._fix_name(node_id) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .drop() - """ - try: - await self._query(query) - logger.debug( - "{%s}: Deleted node with entity_name '%s'", - inspect.currentframe().f_code.co_name, - entity_name, - ) - except Exception as e: - logger.error(f"Error during node deletion: {str(e)}") - raise - - async def get_all_labels(self) -> list[str]: - """ - Get all node entity_names in the graph - Returns: - [entity_name1, entity_name2, ...] # Alphabetically sorted entity_name list - """ - query = f"""g - .V().has('graph', {self.graph_name}) - .values('entity_name') - .dedup() - .order() - """ - try: - result = await self._query(query) - labels = result if result else [] - logger.debug( - "{%s}: Retrieved %d labels", - inspect.currentframe().f_code.co_name, - len(labels), - ) - return labels - except Exception as e: - logger.error(f"Error retrieving labels: {str(e)}") - return [] - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: - """ - Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`. 
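# Illustrative only: the fold()/coalesce()/unfold() "get or create" idiom used by the
# upserts above, spelled out as the Gremlin text it boils down to (names are example data).
#
#   g.V().has('graph', 'LightRAG').has('entity_name', 'Alice')
#    .fold()
#    .coalesce(__.unfold(),                                        // vertex exists: reuse it
#              __.addV('ENTITY').property('graph', 'LightRAG')
#                               .property('entity_name', 'Alice')) // otherwise create it
#    .property('description', '...')                               // applied in either case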
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). - - Args: - node_label: Entity name of the starting node - max_depth: Maximum depth of the subgraph - - Returns: - KnowledgeGraph object containing nodes and edges - """ - result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - - # Get maximum number of graph nodes from environment variable, default is 1000 - MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - - entity_name = GremlinStorage._fix_name(node_label) - - # Handle special case for "*" label - if node_label == "*": - # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES) - query = f"""g - .V().has('graph', {self.graph_name}) - .limit({MAX_GRAPH_NODES}) - .elementMap() - """ - nodes_result = await self._query(query) - - # Add nodes to result - for node_data in nodes_result: - node_id = node_data.get("entity_name", str(node_data.get("id", ""))) - if str(node_id) in seen_nodes: - continue - - # Create node with properties - node_properties = { - k: v for k, v in node_data.items() if k not in ["id", "label"] - } - - result.nodes.append( - KnowledgeGraphNode( - id=str(node_id), - labels=[str(node_id)], - properties=node_properties, - ) - ) - seen_nodes.add(str(node_id)) - - # Get and add edges - if nodes_result: - query = f"""g - .V().has('graph', {self.graph_name}) - .limit({MAX_GRAPH_NODES}) - .outE() - .inV().has('graph', {self.graph_name}) - .limit({MAX_GRAPH_NODES}) - .path() - .by(elementMap()) - .by(elementMap()) - .by(elementMap()) - """ - edges_result = await self._query(query) - - for path in edges_result: - if len(path) >= 3: # source -> edge -> target - source = path[0] - edge_data = path[1] - target = path[2] - - source_id = source.get("entity_name", str(source.get("id", ""))) - target_id = target.get("entity_name", str(target.get("id", ""))) - - edge_id = f"{source_id}-{target_id}" - if edge_id in seen_edges: - continue - - # Create edge with properties - edge_properties = { - k: v - for k, v in edge_data.items() - if k not in ["id", "label"] - } - - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source_id), - target=str(target_id), - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - else: - # Search for specific node and get its neighborhood - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .repeat(__.both().simplePath().dedup()) - .times({max_depth}) - .emit() - .dedup() - .limit({MAX_GRAPH_NODES}) - .elementMap() - """ - nodes_result = await self._query(query) - - # Add nodes to result - for node_data in nodes_result: - node_id = node_data.get("entity_name", str(node_data.get("id", ""))) - if str(node_id) in seen_nodes: - continue - - # Create node with properties - node_properties = { - k: v for k, v in node_data.items() if k not in ["id", "label"] - } - - result.nodes.append( - KnowledgeGraphNode( - id=str(node_id), - labels=[str(node_id)], - properties=node_properties, - ) - ) - seen_nodes.add(str(node_id)) - - # Get edges between the nodes in the result - if nodes_result: - node_ids = [ - n.get("entity_name", str(n.get("id", ""))) for n in nodes_result - ] - node_ids_query = ", ".join( - [GremlinStorage._to_value_map(nid) for nid in node_ids] - ) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', within({node_ids_query})) - .outE() - .where(inV().has('graph', {self.graph_name}) - .has('entity_name', within({node_ids_query}))) - .path() - 
.by(elementMap()) - .by(elementMap()) - .by(elementMap()) - """ - edges_result = await self._query(query) - - for path in edges_result: - if len(path) >= 3: # source -> edge -> target - source = path[0] - edge_data = path[1] - target = path[2] - - source_id = source.get("entity_name", str(source.get("id", ""))) - target_id = target.get("entity_name", str(target.get("id", ""))) - - edge_id = f"{source_id}-{target_id}" - if edge_id in seen_edges: - continue - - # Create edge with properties - edge_properties = { - k: v - for k, v in edge_data.items() - if k not in ["id", "label"] - } - - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source_id), - target=str(target_id), - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - - logger.info( - "Subgraph query successful | Node count: %d | Edge count: %d", - len(result.nodes), - len(result.edges), - ) - return result - - async def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node entity_names to be deleted - """ - for node in nodes: - await self.delete_node(node) - - async def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to be deleted, each edge is a (source, target) tuple - """ - for source, target in edges: - entity_name_source = GremlinStorage._fix_name(source) - entity_name_target = GremlinStorage._fix_name(target) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_source}) - .outE() - .where(inV().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_target})) - .drop() - """ - try: - await self._query(query) - logger.debug( - "{%s}: Deleted edge from '%s' to '%s'", - inspect.currentframe().f_code.co_name, - entity_name_source, - entity_name_target, - ) - except Exception as e: - logger.error(f"Error during edge deletion: {str(e)}") - raise - - async def drop(self) -> dict[str, str]: - """Drop the storage by removing all nodes and relationships in the graph. - - This function deletes all nodes with the specified graph name property, - which automatically removes all associated edges. 
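# Illustrative only: the bounded neighbourhood walk used by the subgraph query above,
# spelled out as plain Gremlin (label, depth and limit are example values).
#
#   g.V().has('graph', 'LightRAG').has('entity_name', 'Alice')
#    .repeat(__.both().simplePath().dedup()).times(3)   // walk up to 3 hops, no revisits
#    .emit()                                            // also return intermediate vertices
#    .dedup().limit(1000)
#    .elementMap()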
- - Returns: - dict[str, str]: Status of the operation with keys 'status' and 'message' - """ - try: - query = f"""g - .V().has('graph', {self.graph_name}) - .drop() - """ - await self._query(query) - logger.info(f"Successfully dropped all data from graph {self.graph_name}") - return {"status": "success", "message": "graph data dropped"} - except Exception as e: - logger.error(f"Error dropping graph {self.graph_name}: {e}") - return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/deprecated/tidb_impl.py b/lightrag/kg/deprecated/tidb_impl.py deleted file mode 100644 index 0d5dfca3..00000000 --- a/lightrag/kg/deprecated/tidb_impl.py +++ /dev/null @@ -1,1230 +0,0 @@ -import asyncio -import os -from dataclasses import dataclass, field -from typing import Any, Union, final -import time -import numpy as np - -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge - - -from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage -from ..namespace import NameSpace, is_namespace -from ..utils import logger - -import pipmaster as pm -import configparser - -if not pm.is_installed("pymysql"): - pm.install("pymysql") -if not pm.is_installed("sqlalchemy"): - pm.install("sqlalchemy") - -from sqlalchemy import create_engine, text # type: ignore - - -def sanitize_sensitive_info(data: dict) -> dict: - sanitized_data = data.copy() - sensitive_fields = [ - "password", - "user", - "host", - "database", - "port", - "ssl_verify_cert", - "ssl_verify_identity", - ] - for field_name in sensitive_fields: - if field_name in sanitized_data: - sanitized_data[field_name] = "***" - return sanitized_data - - -class TiDB: - def __init__(self, config, **kwargs): - self.host = config.get("host", None) - self.port = config.get("port", None) - self.user = config.get("user", None) - self.password = config.get("password", None) - self.database = config.get("database", None) - self.workspace = config.get("workspace", None) - connection_string = ( - f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" - f"?ssl_verify_cert=true&ssl_verify_identity=true" - ) - - try: - self.engine = create_engine(connection_string) - logger.info("Connected to TiDB database") - except Exception as e: - logger.error("Failed to connect to TiDB database") - logger.error(f"TiDB database error: {e}") - raise - - async def _migrate_timestamp_columns(self): - """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC""" - # Not implemented yet - pass - - async def check_tables(self): - # First create all tables - for k, v in TABLES.items(): - try: - await self.query(f"SELECT 1 FROM {k}".format(k=k)) - except Exception as e: - logger.error("Failed to check table in TiDB database") - logger.error(f"TiDB database error: {e}") - try: - await self.execute(v["ddl"]) - logger.info("Created table in TiDB database") - except Exception as e: - logger.error("Failed to create table in TiDB database") - logger.error(f"TiDB database error: {e}") - - # After all tables are created, try to migrate timestamp fields - try: - await self._migrate_timestamp_columns() - except Exception as e: - logger.error(f"TiDB, Failed to migrate timestamp columns: {e}") - # Don't raise exceptions, allow initialization process to continue - - async def query( - self, sql: str, params: dict = None, multirows: bool = False - ) -> Union[dict, None]: - if params is None: - params = {"workspace": self.workspace} - else: - params.update({"workspace": self.workspace}) - with 
self.engine.connect() as conn, conn.begin(): - try: - result = conn.execute(text(sql), params) - except Exception as e: - sanitized_params = sanitize_sensitive_info(params) - sanitized_error = sanitize_sensitive_info({"error": str(e)}) - logger.error( - f"Tidb database,\nsql:{sql},\nparams:{sanitized_params},\nerror:{sanitized_error}" - ) - raise - if multirows: - rows = result.all() - if rows: - data = [dict(zip(result.keys(), row)) for row in rows] - else: - data = [] - else: - row = result.first() - if row: - data = dict(zip(result.keys(), row)) - else: - data = None - return data - - async def execute(self, sql: str, data: list | dict = None): - # logger.info("go into TiDBDB execute method") - try: - with self.engine.connect() as conn, conn.begin(): - if data is None: - conn.execute(text(sql)) - else: - conn.execute(text(sql), parameters=data) - except Exception as e: - sanitized_data = sanitize_sensitive_info(data) if data else None - sanitized_error = sanitize_sensitive_info({"error": str(e)}) - logger.error( - f"Tidb database,\nsql:{sql},\ndata:{sanitized_data},\nerror:{sanitized_error}" - ) - raise - - -class ClientManager: - _instances: dict[str, Any] = {"db": None, "ref_count": 0} - _lock = asyncio.Lock() - - @staticmethod - def get_config() -> dict[str, Any]: - config = configparser.ConfigParser() - config.read("config.ini", "utf-8") - - return { - "host": os.environ.get( - "TIDB_HOST", - config.get("tidb", "host", fallback="localhost"), - ), - "port": os.environ.get( - "TIDB_PORT", config.get("tidb", "port", fallback=4000) - ), - "user": os.environ.get( - "TIDB_USER", - config.get("tidb", "user", fallback=None), - ), - "password": os.environ.get( - "TIDB_PASSWORD", - config.get("tidb", "password", fallback=None), - ), - "database": os.environ.get( - "TIDB_DATABASE", - config.get("tidb", "database", fallback=None), - ), - "workspace": os.environ.get( - "TIDB_WORKSPACE", - config.get("tidb", "workspace", fallback="default"), - ), - } - - @classmethod - async def get_client(cls) -> TiDB: - async with cls._lock: - if cls._instances["db"] is None: - config = ClientManager.get_config() - db = TiDB(config) - await db.check_tables() - cls._instances["db"] = db - cls._instances["ref_count"] = 0 - cls._instances["ref_count"] += 1 - return cls._instances["db"] - - @classmethod - async def release_client(cls, db: TiDB): - async with cls._lock: - if db is not None: - if db is cls._instances["db"]: - cls._instances["ref_count"] -= 1 - if cls._instances["ref_count"] == 0: - cls._instances["db"] = None - - -@final -@dataclass -class TiDBKVStorage(BaseKVStorage): - db: TiDB = field(default=None) - - def __post_init__(self): - self._data = {} - self._max_batch_size = self.global_config["embedding_batch_num"] - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - ################ QUERY METHODS ################ - async def get_all(self) -> dict[str, Any]: - """Get all data from storage - - Returns: - Dictionary containing all stored data - """ - async with self._storage_lock: - return dict(self._data) - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Fetch doc_full data by id.""" - SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] - params = {"id": id} - response = await self.db.query(SQL, params) - return response if response else None - - # Query by id - async def get_by_ids(self, ids: list[str]) -> 
list[dict[str, Any]]: - """Fetch doc_chunks data by id""" - SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - return await self.db.query(SQL, multirows=True) - - async def filter_keys(self, keys: set[str]) -> set[str]: - SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=namespace_to_table_name(self.namespace), - id_field=namespace_to_id(self.namespace), - ids=",".join([f"'{id}'" for id in keys]), - ) - try: - await self.db.query(SQL) - except Exception as e: - logger.error(f"Tidb database,\nsql:{SQL},\nkeys:{keys},\nerror:{e}") - res = await self.db.query(SQL, multirows=True) - if res: - exist_keys = [key["id"] for key in res] - data = set([s for s in keys if s not in exist_keys]) - else: - exist_keys = [] - data = set([s for s in keys if s not in exist_keys]) - return data - - ################ INSERT full_doc AND chunks ################ - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") - if not data: - return - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - list_data = [ - { - "__id__": k, - **{k1: v1 for k1, v1 in v.items()}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - - # Get current time as UNIX timestamp - current_time = int(time.time()) - - merge_sql = SQL_TEMPLATES["upsert_chunk"] - data = [] - for item in list_data: - data.append( - { - "id": item["__id__"], - "content": item["content"], - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content_vector": f"{item['__vector__'].tolist()}", - "workspace": self.db.workspace, - "timestamp": current_time, - } - ) - await self.db.execute(merge_sql, data) - - if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): - merge_sql = SQL_TEMPLATES["upsert_doc_full"] - data = [] - for k, v in self._data.items(): - data.append( - { - "id": k, - "content": v["content"], - "workspace": self.db.workspace, - } - ) - await self.db.execute(merge_sql, data) - return left_data - - async def index_done_callback(self) -> None: - # Ti handles persistence automatically - pass - - async def delete(self, ids: list[str]) -> None: - """Delete records with specified IDs from the storage. 
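# Illustrative only: the embedding batching pattern used by the upsert above, assuming
# embed(batch) is an async callable returning one vector per input text
# (helper and parameter names are hypothetical, not part of the original patch).
import asyncio
import numpy as np

async def embed_in_batches(contents: list[str], embed, batch_size: int) -> np.ndarray:
    batches = [contents[i : i + batch_size] for i in range(0, len(contents), batch_size)]
    # run one embedding call per batch concurrently, then stack the results
    results = await asyncio.gather(*[embed(batch) for batch in batches])
    return np.concatenate(results)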
- - Args: - ids: List of record IDs to be deleted - """ - if not ids: - return - - try: - table_name = namespace_to_table_name(self.namespace) - id_field = namespace_to_id(self.namespace) - - if not table_name or not id_field: - logger.error(f"Unknown namespace for deletion: {self.namespace}") - return - - ids_list = ",".join([f"'{id}'" for id in ids]) - delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})" - - await self.db.execute(delete_sql, {"workspace": self.db.workspace}) - logger.info( - f"Successfully deleted {len(ids)} records from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error deleting records from {self.namespace}: {e}") - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - -@final -@dataclass -class TiDBVectorDBStorage(BaseVectorStorage): - db: TiDB | None = field(default=None) - - def __post_init__(self): - self._client_file_name = os.path.join( - self.global_config["working_dir"], f"vdb_{self.namespace}.json" - ) - self._max_batch_size = self.global_config["embedding_batch_num"] - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - cosine_threshold = config.get("cosine_better_than_threshold") - if cosine_threshold is None: - raise ValueError( - "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" - ) - self.cosine_better_than_threshold = cosine_threshold - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - async def query( - self, query: str, top_k: int, ids: list[str] | None = None - ) -> list[dict[str, Any]]: - """Search from tidb vector""" - embeddings = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - embedding = embeddings[0] - - embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]" - - params = { - "embedding_string": embedding_string, - "top_k": top_k, - "better_than_threshold": self.cosine_better_than_threshold, - } - - results = await self.db.query( - SQL_TEMPLATES[self.namespace], params=params, multirows=True - ) - if not results: - return [] - return results - - ###### INSERT entities And relationships ###### - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - if not data: - return - logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") - - # Get current time as UNIX timestamp - import time - - current_time = int(time.time()) - - list_data = [ - { - "id": k, - "timestamp": current_time, - **{k1: v1 for k1, v1 in v.items()}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embedding_tasks = [self.embedding_func(batch) for batch in batches] - embeddings_list = await asyncio.gather(*embedding_tasks) - - embeddings = np.concatenate(embeddings_list) - for i, d in 
enumerate(list_data): - d["content_vector"] = embeddings[i] - - if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): - for item in list_data: - param = { - "id": item["id"], - "content": item["content"], - "tokens": item.get("tokens", 0), - "chunk_order_index": item.get("chunk_order_index", 0), - "full_doc_id": item.get("full_doc_id", ""), - "content_vector": f"{item['content_vector'].tolist()}", - "workspace": self.db.workspace, - "timestamp": item["timestamp"], - } - await self.db.execute(SQL_TEMPLATES["upsert_chunk"], param) - elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): - for item in list_data: - param = { - "id": item["id"], - "name": item["entity_name"], - "content": item["content"], - "content_vector": f"{item['content_vector'].tolist()}", - "workspace": self.db.workspace, - "timestamp": item["timestamp"], - } - await self.db.execute(SQL_TEMPLATES["upsert_entity"], param) - elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): - for item in list_data: - param = { - "id": item["id"], - "source_name": item["src_id"], - "target_name": item["tgt_id"], - "content": item["content"], - "content_vector": f"{item['content_vector'].tolist()}", - "workspace": self.db.workspace, - "timestamp": item["timestamp"], - } - await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param) - - async def delete(self, ids: list[str]) -> None: - """Delete vectors with specified IDs from the storage. - - Args: - ids: List of vector IDs to be deleted - """ - if not ids: - return - - table_name = namespace_to_table_name(self.namespace) - id_field = namespace_to_id(self.namespace) - - if not table_name or not id_field: - logger.error(f"Unknown namespace for vector deletion: {self.namespace}") - return - - ids_list = ",".join([f"'{id}'" for id in ids]) - delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})" - - try: - await self.db.execute(delete_sql, {"workspace": self.db.workspace}) - logger.debug( - f"Successfully deleted {len(ids)} vectors from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - - async def delete_entity(self, entity_name: str) -> None: - """Delete an entity by its name from the vector storage. - - Args: - entity_name: The name of the entity to delete - """ - try: - # Construct SQL to delete the entity - delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace AND name = :entity_name""" - - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} - ) - logger.debug(f"Successfully deleted entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") - - async def delete_entity_relation(self, entity_name: str) -> None: - """Delete all relations associated with an entity. 
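# Illustrative only: how the batched deletes above assemble their SQL, assuming the ids
# are internally generated keys (table and column names here are example values).
ids = ["chunk-1", "chunk-2"]
ids_list = ",".join(f"'{i}'" for i in ids)
delete_sql = (
    f"DELETE FROM LIGHTRAG_DOC_CHUNKS "
    f"WHERE workspace = :workspace AND chunk_id IN ({ids_list})"
)
# -> DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace AND chunk_id IN ('chunk-1','chunk-2')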
- - Args: - entity_name: The name of the entity whose relations should be deleted - """ - try: - # Delete relations where the entity is either the source or target - delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)""" - - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} - ) - logger.debug(f"Successfully deleted relations for entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting relations for entity {entity_name}: {e}") - - async def index_done_callback(self) -> None: - # Ti handles persistence automatically - pass - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get vector data by its ID - - Args: - id: The unique identifier of the vector - - Returns: - The vector data if found, or None if not found - """ - try: - # Determine which table to query based on namespace - if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: - sql_template = """ - SELECT entity_id as id, name as entity_name, entity_type, description, content, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_NODES - WHERE entity_id = :entity_id AND workspace = :workspace - """ - params = {"entity_id": id, "workspace": self.db.workspace} - elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: - sql_template = """ - SELECT relation_id as id, source_name as src_id, target_name as tgt_id, - keywords, description, content, UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_EDGES - WHERE relation_id = :relation_id AND workspace = :workspace - """ - params = {"relation_id": id, "workspace": self.db.workspace} - elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: - sql_template = """ - SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_DOC_CHUNKS - WHERE chunk_id = :chunk_id AND workspace = :workspace - """ - params = {"chunk_id": id, "workspace": self.db.workspace} - else: - logger.warning( - f"Namespace {self.namespace} not supported for get_by_id" - ) - return None - - result = await self.db.query(sql_template, params=params) - return result - except Exception as e: - logger.error(f"Error retrieving vector data for ID {id}: {e}") - return None - - async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get multiple vector data by their IDs - - Args: - ids: List of unique identifiers - - Returns: - List of vector data objects that were found - """ - if not ids: - return [] - - try: - # Format IDs for SQL IN clause - ids_str = ", ".join([f"'{id}'" for id in ids]) - - # Determine which table to query based on namespace - if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: - sql_template = f""" - SELECT entity_id as id, name as entity_name, entity_type, description, content, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_NODES - WHERE entity_id IN ({ids_str}) AND workspace 
= :workspace - """ - elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: - sql_template = f""" - SELECT relation_id as id, source_name as src_id, target_name as tgt_id, - keywords, description, content, UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_EDGES - WHERE relation_id IN ({ids_str}) AND workspace = :workspace - """ - elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: - sql_template = f""" - SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_DOC_CHUNKS - WHERE chunk_id IN ({ids_str}) AND workspace = :workspace - """ - else: - logger.warning( - f"Namespace {self.namespace} not supported for get_by_ids" - ) - return [] - - params = {"workspace": self.db.workspace} - results = await self.db.query(sql_template, params=params, multirows=True) - return results if results else [] - except Exception as e: - logger.error(f"Error retrieving vector data for IDs {ids}: {e}") - return [] - - -@final -@dataclass -class TiDBGraphStorage(BaseGraphStorage): - db: TiDB = field(default=None) - - def __post_init__(self): - self._max_batch_size = self.global_config["embedding_batch_num"] - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - #################### upsert method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - entity_name = node_id - entity_type = node_data["entity_type"] - description = node_data["description"] - source_id = node_data["source_id"] - logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") - content = entity_name + description - contents = [content] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - sql = SQL_TEMPLATES["upsert_node"] - data = { - "workspace": self.db.workspace, - "name": entity_name, - "entity_type": entity_type, - "description": description, - "source_chunk_id": source_id, - "content": content, - "content_vector": f"{content_vector.tolist()}", - } - await self.db.execute(sql, data) - - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - source_name = source_node_id - target_name = target_node_id - weight = edge_data["weight"] - keywords = edge_data["keywords"] - description = edge_data["description"] - source_chunk_id = edge_data["source_id"] - logger.debug( - f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" - ) - - content = keywords + source_name + target_name + description - contents = [content] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["upsert_edge"] - data = { - "workspace": self.db.workspace, - "source_name": source_name, - "target_name": target_name, - "weight": weight, - "keywords": keywords, - "description": description, - "source_chunk_id": source_chunk_id, - "content": content, - "content_vector": 
f"{content_vector.tolist()}", - } - await self.db.execute(merge_sql, data) - - # Query - - async def has_node(self, node_id: str) -> bool: - sql = SQL_TEMPLATES["has_entity"] - param = {"name": node_id, "workspace": self.db.workspace} - has = await self.db.query(sql, param) - return has["cnt"] != 0 - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - sql = SQL_TEMPLATES["has_relationship"] - param = { - "source_name": source_node_id, - "target_name": target_node_id, - "workspace": self.db.workspace, - } - has = await self.db.query(sql, param) - return has["cnt"] != 0 - - async def node_degree(self, node_id: str) -> int: - sql = SQL_TEMPLATES["node_degree"] - param = {"name": node_id, "workspace": self.db.workspace} - result = await self.db.query(sql, param) - return result["cnt"] - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) - return degree - - async def get_node(self, node_id: str) -> dict[str, str] | None: - sql = SQL_TEMPLATES["get_node"] - param = {"name": node_id, "workspace": self.db.workspace} - return await self.db.query(sql, param) - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - sql = SQL_TEMPLATES["get_edge"] - param = { - "source_name": source_node_id, - "target_name": target_node_id, - "workspace": self.db.workspace, - } - return await self.db.query(sql, param) - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - sql = SQL_TEMPLATES["get_node_edges"] - param = {"source_name": source_node_id, "workspace": self.db.workspace} - res = await self.db.query(sql, param, multirows=True) - if res: - data = [(i["source_name"], i["target_name"]) for i in res] - return data - else: - return [] - - async def index_done_callback(self) -> None: - # Ti handles persistence automatically - pass - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - drop_sql = """ - DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace; - DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace; - """ - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "graph data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - async def delete_node(self, node_id: str) -> None: - """Delete a node and all its related edges - - Args: - node_id: The ID of the node to delete - """ - # First delete all edges related to this node - await self.db.execute( - SQL_TEMPLATES["delete_node_edges"], - {"name": node_id, "workspace": self.db.workspace}, - ) - - # Then delete the node itself - await self.db.execute( - SQL_TEMPLATES["delete_node"], - {"name": node_id, "workspace": self.db.workspace}, - ) - - logger.debug( - f"Node {node_id} and its related edges have been deleted from the graph" - ) - - async def get_all_labels(self) -> list[str]: - """Get all entity types (labels) in the database - - Returns: - List of labels sorted alphabetically - """ - result = await self.db.query( - SQL_TEMPLATES["get_all_labels"], - {"workspace": self.db.workspace}, - multirows=True, - ) - - if not result: - return [] - - # Extract all labels - return [item["label"] for item in result] - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: - """ - Get a connected subgraph of nodes matching the specified label - Maximum number of nodes is limited by MAX_GRAPH_NODES 
environment variable (default: 1000) - - Args: - node_label: The node label to match - max_depth: Maximum depth of the subgraph - - Returns: - KnowledgeGraph object containing nodes and edges - """ - result = KnowledgeGraph() - MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - - # Get matching nodes - if node_label == "*": - # Handle special case, get all nodes - node_results = await self.db.query( - SQL_TEMPLATES["get_all_nodes"], - {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, - multirows=True, - ) - else: - # Get nodes matching the label - label_pattern = f"%{node_label}%" - node_results = await self.db.query( - SQL_TEMPLATES["get_matching_nodes"], - {"workspace": self.db.workspace, "label_pattern": label_pattern}, - multirows=True, - ) - - if not node_results: - logger.warning(f"No nodes found matching label {node_label}") - return result - - # Limit the number of returned nodes - if len(node_results) > MAX_GRAPH_NODES: - node_results = node_results[:MAX_GRAPH_NODES] - - # Extract node names for edge query - node_names = [node["name"] for node in node_results] - node_names_str = ",".join([f"'{name}'" for name in node_names]) - - # Add nodes to result - for node in node_results: - node_properties = { - k: v for k, v in node.items() if k not in ["id", "name", "entity_type"] - } - result.nodes.append( - KnowledgeGraphNode( - id=node["name"], - labels=[node["entity_type"]] - if node.get("entity_type") - else [node["name"]], - properties=node_properties, - ) - ) - - # Get related edges - edge_results = await self.db.query( - SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), - {"workspace": self.db.workspace}, - multirows=True, - ) - - if edge_results: - # Add edges to result - for edge in edge_results: - # Only include edges related to selected nodes - if ( - edge["source_name"] in node_names - and edge["target_name"] in node_names - ): - edge_id = f"{edge['source_name']}-{edge['target_name']}" - edge_properties = { - k: v - for k, v in edge.items() - if k not in ["id", "source_name", "target_name"] - } - - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="RELATED", - source=edge["source_name"], - target=edge["target_name"], - properties=edge_properties, - ) - ) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - return result - - async def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node IDs to delete - """ - for node_id in nodes: - await self.delete_node(node_id) - - async def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to delete, each edge is a (source, target) tuple - """ - for source, target in edges: - await self.db.execute( - SQL_TEMPLATES["remove_multiple_edges"], - {"source": source, "target": target, "workspace": self.db.workspace}, - ) - - -N_T = { - NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", - NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES", - NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", -} -N_ID = { - NameSpace.KV_STORE_FULL_DOCS: "doc_id", - NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id", - NameSpace.VECTOR_STORE_CHUNKS: "chunk_id", - NameSpace.VECTOR_STORE_ENTITIES: "entity_id", - NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id", -} - - -def namespace_to_table_name(namespace: str) -> str: - 
for k, v in N_T.items(): - if is_namespace(namespace, k): - return v - - -def namespace_to_id(namespace: str) -> str: - for k, v in N_ID.items(): - if is_namespace(namespace, k): - return v - - -TABLES = { - "LIGHTRAG_DOC_FULL": { - "ddl": """ - CREATE TABLE LIGHTRAG_DOC_FULL ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `doc_id` VARCHAR(256) NOT NULL, - `workspace` varchar(1024), - `content` LONGTEXT, - `meta` JSON, - `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - `updatetime` TIMESTAMP DEFAULT NULL, - UNIQUE KEY (`doc_id`) - ); - """ - }, - "LIGHTRAG_DOC_CHUNKS": { - "ddl": """ - CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `chunk_id` VARCHAR(256) NOT NULL, - `full_doc_id` VARCHAR(256) NOT NULL, - `workspace` varchar(1024), - `chunk_order_index` INT, - `tokens` INT, - `content` LONGTEXT, - `content_vector` VECTOR, - `createtime` TIMESTAMP, - `updatetime` TIMESTAMP, - UNIQUE KEY (`chunk_id`) - ); - """ - }, - "LIGHTRAG_GRAPH_NODES": { - "ddl": """ - CREATE TABLE LIGHTRAG_GRAPH_NODES ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `entity_id` VARCHAR(256), - `workspace` varchar(1024), - `name` VARCHAR(2048), - `entity_type` VARCHAR(1024), - `description` LONGTEXT, - `source_chunk_id` VARCHAR(256), - `content` LONGTEXT, - `content_vector` VECTOR, - `createtime` TIMESTAMP, - `updatetime` TIMESTAMP, - KEY (`entity_id`) - ); - """ - }, - "LIGHTRAG_GRAPH_EDGES": { - "ddl": """ - CREATE TABLE LIGHTRAG_GRAPH_EDGES ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `relation_id` VARCHAR(256), - `workspace` varchar(1024), - `source_name` VARCHAR(2048), - `target_name` VARCHAR(2048), - `weight` DECIMAL, - `keywords` TEXT, - `description` LONGTEXT, - `source_chunk_id` varchar(256), - `content` LONGTEXT, - `content_vector` VECTOR, - `createtime` TIMESTAMP, - `updatetime` TIMESTAMP, - KEY (`relation_id`) - ); - """ - }, - "LIGHTRAG_LLM_CACHE": { - "ddl": """ - CREATE TABLE LIGHTRAG_LLM_CACHE ( - id BIGINT PRIMARY KEY AUTO_INCREMENT, - send TEXT, - return TEXT, - model VARCHAR(1024), - createtime DATETIME DEFAULT CURRENT_TIMESTAMP, - updatetime DATETIME DEFAULT NULL - ); - """ - }, -} - - -SQL_TEMPLATES = { - # SQL for KVStorage - "get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace", - "get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace", - "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace", - "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace", - "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace", - # SQL for Merge operations (TiDB version with INSERT ... 
ON DUPLICATE KEY UPDATE) - "upsert_doc_full": """ - INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace) - VALUES (:id, :content, :workspace) - ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP - """, - "upsert_chunk": """ - INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace, createtime, updatetime) - VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp)) - ON DUPLICATE KEY UPDATE - content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index), - full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = FROM_UNIXTIME(:timestamp) - """, - # SQL for VectorStorage - "entities": """SELECT n.name as entity_name, UNIX_TIMESTAMP(n.createtime) as created_at FROM - (SELECT entity_id as id, name, createtime, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance - FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n - WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k - """, - "relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id, UNIX_TIMESTAMP(e.createtime) as created_at FROM - (SELECT source_name, target_name, createtime, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance - FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e - WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k - """, - "chunks": """SELECT c.id, UNIX_TIMESTAMP(c.createtime) as created_at FROM - (SELECT chunk_id as id, createtime, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c - WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k - """, - "has_entity": """ - SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace - """, - "has_relationship": """ - SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace - """, - "upsert_entity": """ - INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace, createtime, updatetime) - VALUES(:id, :name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp)) - ON DUPLICATE KEY UPDATE - content = VALUES(content), - content_vector = VALUES(content_vector), - updatetime = FROM_UNIXTIME(:timestamp) - """, - "upsert_relationship": """ - INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace, createtime, updatetime) - VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp)) - ON DUPLICATE KEY UPDATE - content = VALUES(content), - content_vector = VALUES(content_vector), - updatetime = FROM_UNIXTIME(:timestamp) - """, - # SQL for GraphStorage - "get_node": """ - SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector - FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace - """, - "get_edge": """ - SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector - FROM LIGHTRAG_GRAPH_EDGES WHERE 
source_name = :source_name AND target_name = :target_name AND workspace = :workspace - """, - "get_node_edges": """ - SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector - FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace - """, - "node_degree": """ - SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name) - """, - "upsert_node": """ - INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description) - VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description) - ON DUPLICATE KEY UPDATE - name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector), - workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, - source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description) - """, - "upsert_edge": """ - INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector, - workspace, weight, keywords, description, source_chunk_id) - VALUES(:source_name, :target_name, :content, :content_vector, - :workspace, :weight, :keywords, :description, :source_chunk_id) - ON DUPLICATE KEY UPDATE - source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content), - content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, - weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description), - source_chunk_id = VALUES(source_chunk_id) - """, - "delete_node": """ - DELETE FROM LIGHTRAG_GRAPH_NODES - WHERE name = :name AND workspace = :workspace - """, - "delete_node_edges": """ - DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace - """, - "get_all_labels": """ - SELECT DISTINCT entity_type as label - FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace - ORDER BY entity_type - """, - "get_matching_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES - WHERE name LIKE :label_pattern AND workspace = :workspace - ORDER BY name - """, - "get_all_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace - ORDER BY name - LIMIT :max_nodes - """, - "get_related_edges": """ - SELECT * FROM LIGHTRAG_GRAPH_EDGES - WHERE (source_name IN (:node_names) OR target_name IN (:node_names)) - AND workspace = :workspace - """, - "remove_multiple_edges": """ - DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE (source_name = :source AND target_name = :target) - AND workspace = :workspace - """, - # Drop tables - "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace", -} From 2dab4e321da38eb7c6dffa240fa37e9c568e0129 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 6 Aug 2025 01:03:35 +0800 Subject: [PATCH 8/8] Bump api version to 0199 --- lightrag/api/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py index f3602985..48c39b83 100644 --- a/lightrag/api/__init__.py +++ b/lightrag/api/__init__.py @@ -1 +1 @@ -__api_version__ = "0198" +__api_version__ = "0199"