Merge pull request #1915 from danielaskdd/optimize-llm-cache

Refact: Optimized LLM Cache Hash Key Generation by Including All Query Parameters
This commit is contained in:
Daniel.y 2025-08-06 01:04:02 +08:00 committed by GitHub
commit a6ef29cef6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 298 additions and 3217 deletions

View file

@ -737,7 +737,54 @@ rag.insert(documents, file_paths=file_paths)
### 存储 ### 存储
LightRAG使用到4种类型的存储每一种存储都有多种实现方案。在初始化LightRAG的时候可以通过参数设定这四类存储的实现方案。详情请参看前面的LightRAG初始化参数。 LightRAG 使用 4 种类型的存储用于不同目的:
* KV_STORAGEllm 响应缓存、文本块、文档信息
* VECTOR_STORAGE实体向量、关系向量、块向量
* GRAPH_STORAGE实体关系图
* DOC_STATUS_STORAGE文档索引状态
每种存储类型都有几种实现:
* KV_STORAGE 支持的实现名称
```
JsonKVStorage JsonFile(默认)
PGKVStorage Postgres
RedisKVStorage Redis
MongoKVStorage MogonDB
```
* GRAPH_STORAGE 支持的实现名称
```
NetworkXStorage NetworkX(默认)
Neo4JStorage Neo4J
PGGraphStorage PostgreSQL with AGE plugin
```
> 在测试中Neo4j图形数据库相比PostgreSQL AGE有更好的性能表现。
* VECTOR_STORAGE 支持的实现名称
```
NanoVectorDBStorage NanoVector(默认)
PGVectorStorage Postgres
MilvusVectorDBStorge Milvus
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
MongoVectorDBStorage MongoDB
```
* DOC_STATUS_STORAGE 支持的实现名称
```
JsonDocStatusStorage JsonFile(默认)
PGDocStatusStorage Postgres
MongoDocStatusStorage MongoDB
```
每一种存储类型的链接配置范例可以在 `env.example` 文件中找到。链接字符串中的数据库实例是需要你预先在数据库服务器上创建好的LightRAG 仅负责在数据库实例中创建数据表,不负责创建数据库实例。如果使 Redis 作为存储,记得给 Redis 配置自动持久化数据规则,否则 Redis 服务重启后数据会丢失。如果使用PostgreSQL数据库推荐使用16.6版本或以上。
<details> <details>
<summary> <b>使用Neo4J进行存储</b> </summary> <summary> <b>使用Neo4J进行存储</b> </summary>

View file

@ -747,7 +747,55 @@ rag.insert(documents, file_paths=file_paths)
### Storage ### Storage
LightRAG uses four types of storage, each of which has multiple implementation options. When initializing LightRAG, the implementation schemes for these four types of storage can be set through parameters. For details, please refer to the previous LightRAG initialization parameters. LightRAG uses 4 types of storage for different purposes:
* KV_STORAGE: llm response cache, text chunks, document information
* VECTOR_STORAGE: entities vectors, relation vectors, chunks vectors
* GRAPH_STORAGE: entity relation graph
* DOC_STATUS_STORAGE: document indexing status
Each storage type has several implementations:
* KV_STORAGE supported implementations:
```
JsonKVStorage JsonFile (default)
PGKVStorage Postgres
RedisKVStorage Redis
MongoKVStorage MongoDB
```
* GRAPH_STORAGE supported implementations:
```
NetworkXStorage NetworkX (default)
Neo4JStorage Neo4J
PGGraphStorage PostgreSQL with AGE plugin
MemgraphStorage. Memgraph
```
> Testing has shown that Neo4J delivers superior performance in production environments compared to PostgreSQL with AGE plugin.
* VECTOR_STORAGE supported implementations:
```
NanoVectorDBStorage NanoVector (default)
PGVectorStorage Postgres
MilvusVectorDBStorage Milvus
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
MongoVectorDBStorage MongoDB
```
* DOC_STATUS_STORAGE: supported implementations:
```
JsonDocStatusStorage JsonFile (default)
PGDocStatusStorage Postgres
MongoDocStatusStorage MongoDB
```
Example connection configurations for each storage type can be found in the `env.example` file. The database instance in the connection string needs to be created by you on the database server beforehand. LightRAG is only responsible for creating tables within the database instance, not for creating the database instance itself. If using Redis as storage, remember to configure automatic data persistence rules for Redis, otherwise data will be lost after the Redis service restarts. If using PostgreSQL, it is recommended to use version 16.6 or above.
<details> <details>
<summary> <b>Using Neo4J for Storage</b> </summary> <summary> <b>Using Neo4J for Storage</b> </summary>

View file

@ -409,7 +409,6 @@ PGGraphStorage PostgreSQL with AGE plugin
NanoVectorDBStorage NanoVector(默认) NanoVectorDBStorage NanoVector(默认)
PGVectorStorage Postgres PGVectorStorage Postgres
MilvusVectorDBStorge Milvus MilvusVectorDBStorge Milvus
ChromaVectorDBStorage Chroma
FaissVectorDBStorage Faiss FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant QdrantVectorDBStorage Qdrant
MongoVectorDBStorage MongoDB MongoVectorDBStorage MongoDB

View file

@ -412,7 +412,6 @@ MemgraphStorage. Memgraph
NanoVectorDBStorage NanoVector (default) NanoVectorDBStorage NanoVector (default)
PGVectorStorage Postgres PGVectorStorage Postgres
MilvusVectorDBStorage Milvus MilvusVectorDBStorage Milvus
ChromaVectorDBStorage Chroma
FaissVectorDBStorage Faiss FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant QdrantVectorDBStorage Qdrant
MongoVectorDBStorage MongoDB MongoVectorDBStorage MongoDB

View file

@ -1 +1 @@
__api_version__ = "0198" __api_version__ = "0199"

View file

@ -258,19 +258,12 @@ class ClearDocumentsResponse(BaseModel):
class ClearCacheRequest(BaseModel): class ClearCacheRequest(BaseModel):
"""Request model for clearing cache """Request model for clearing cache
Attributes: This model is kept for API compatibility but no longer accepts any parameters.
modes: Optional list of cache modes to clear All cache will be cleared regardless of the request content.
""" """
modes: Optional[
List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
] = Field(
default=None,
description="Modes of cache to clear. If None, clears all cache.",
)
class Config: class Config:
json_schema_extra = {"example": {"modes": ["default", "naive"]}} json_schema_extra = {"example": {}}
class ClearCacheResponse(BaseModel): class ClearCacheResponse(BaseModel):
@ -1820,47 +1813,28 @@ def create_document_routes(
) )
async def clear_cache(request: ClearCacheRequest): async def clear_cache(request: ClearCacheRequest):
""" """
Clear cache data from the LLM response cache storage. Clear all cache data from the LLM response cache storage.
This endpoint allows clearing specific modes of cache or all cache if no modes are specified. This endpoint clears all cached LLM responses regardless of mode.
Valid modes include: "default", "naive", "local", "global", "hybrid", "mix". The request body is accepted for API compatibility but is ignored.
- "default" represents extraction cache.
- Other modes correspond to different query modes.
Args: Args:
request (ClearCacheRequest): The request body containing optional modes to clear. request (ClearCacheRequest): The request body (ignored for compatibility).
Returns: Returns:
ClearCacheResponse: A response object containing the status and message. ClearCacheResponse: A response object containing the status and message.
Raises: Raises:
HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors). HTTPException: If an error occurs during cache clearing (500).
""" """
try: try:
# Validate modes if provided # Call the aclear_cache method (no modes parameter)
valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"] await rag.aclear_cache()
if request.modes and not all(mode in valid_modes for mode in request.modes):
invalid_modes = [
mode for mode in request.modes if mode not in valid_modes
]
raise HTTPException(
status_code=400,
detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
)
# Call the aclear_cache method
await rag.aclear_cache(request.modes)
# Prepare success message # Prepare success message
if request.modes: message = "Successfully cleared all cache"
message = f"Successfully cleared cache for modes: {request.modes}"
else:
message = "Successfully cleared all cache"
return ClearCacheResponse(status="success", message=message) return ClearCacheResponse(status="success", message=message)
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e: except Exception as e:
logger.error(f"Error clearing cache: {str(e)}") logger.error(f"Error clearing cache: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())

View file

@ -331,21 +331,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
None None
""" """
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache drop successfully
False: if the cache drop failed, or the cache mode is not supported
"""
@dataclass @dataclass
class BaseGraphStorage(StorageNameSpace, ABC): class BaseGraphStorage(StorageNameSpace, ABC):
@ -761,10 +746,6 @@ class DocStatusStorage(BaseKVStorage, ABC):
Dictionary mapping status names to counts Dictionary mapping status names to counts
""" """
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Drop cache is not supported for Doc Status storage"""
return False
class StoragesStatus(str, Enum): class StoragesStatus(str, Enum):
"""Storages status""" """Storages status"""

View file

@ -5,7 +5,6 @@ STORAGE_IMPLEMENTATIONS = {
"RedisKVStorage", "RedisKVStorage",
"PGKVStorage", "PGKVStorage",
"MongoKVStorage", "MongoKVStorage",
# "TiDBKVStorage",
], ],
"required_methods": ["get_by_id", "upsert"], "required_methods": ["get_by_id", "upsert"],
}, },
@ -16,9 +15,6 @@ STORAGE_IMPLEMENTATIONS = {
"PGGraphStorage", "PGGraphStorage",
"MongoGraphStorage", "MongoGraphStorage",
"MemgraphStorage", "MemgraphStorage",
# "AGEStorage",
# "TiDBGraphStorage",
# "GremlinStorage",
], ],
"required_methods": ["upsert_node", "upsert_edge"], "required_methods": ["upsert_node", "upsert_edge"],
}, },
@ -31,7 +27,6 @@ STORAGE_IMPLEMENTATIONS = {
"QdrantVectorDBStorage", "QdrantVectorDBStorage",
"MongoVectorDBStorage", "MongoVectorDBStorage",
# "ChromaVectorDBStorage", # "ChromaVectorDBStorage",
# "TiDBVectorDBStorage",
], ],
"required_methods": ["query", "upsert"], "required_methods": ["query", "upsert"],
}, },
@ -52,20 +47,17 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"JsonKVStorage": [], "JsonKVStorage": [],
"MongoKVStorage": [], "MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"], "RedisKVStorage": ["REDIS_URI"],
# "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
# Graph Storage Implementations # Graph Storage Implementations
"NetworkXStorage": [], "NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [], "MongoGraphStorage": [],
"MemgraphStorage": ["MEMGRAPH_URI"], "MemgraphStorage": ["MEMGRAPH_URI"],
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [ "AGEStorage": [
"AGE_POSTGRES_DB", "AGE_POSTGRES_DB",
"AGE_POSTGRES_USER", "AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD", "AGE_POSTGRES_PASSWORD",
], ],
# "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [ "PGGraphStorage": [
"POSTGRES_USER", "POSTGRES_USER",
"POSTGRES_PASSWORD", "POSTGRES_PASSWORD",
@ -75,7 +67,6 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"NanoVectorDBStorage": [], "NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [], "MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [], "ChromaVectorDBStorage": [],
# "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [], "FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
@ -102,14 +93,10 @@ STORAGES = {
"RedisKVStorage": ".kg.redis_impl", "RedisKVStorage": ".kg.redis_impl",
"RedisDocStatusStorage": ".kg.redis_impl", "RedisDocStatusStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl", "ChromaVectorDBStorage": ".kg.chroma_impl",
# "TiDBKVStorage": ".kg.tidb_impl",
# "TiDBVectorDBStorage": ".kg.tidb_impl",
# "TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl", "PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl", "PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl", "AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl", "PGGraphStorage": ".kg.postgres_impl",
# "GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl", "PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl", "FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl",

View file

@ -1,867 +0,0 @@
import asyncio
import inspect
import json
import os
import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
import pipmaster as pm
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from lightrag.utils import logger
from ..base import BaseGraphStorage
if sys.platform.startswith("win"):
import asyncio.windows_events
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import psycopg # type: ignore
from psycopg.rows import namedtuple_row # type: ignore
from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
class AGEQueryException(Exception):
"""Exception for the AGE queries."""
def __init__(self, exception: Union[str, Dict]) -> None:
if isinstance(exception, dict):
self.message = exception["message"] if "message" in exception else "unknown"
self.details = exception["details"] if "details" in exception else "unknown"
else:
self.message = exception
self.details = "unknown"
def get_message(self) -> str:
return self.message
def get_details(self) -> Any:
return self.details
@final
@dataclass
class AGEStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
self._driver_lock = asyncio.Lock()
DB = os.environ["AGE_POSTGRES_DB"].replace("\\", "\\\\").replace("'", "\\'")
USER = os.environ["AGE_POSTGRES_USER"].replace("\\", "\\\\").replace("'", "\\'")
PASSWORD = (
os.environ["AGE_POSTGRES_PASSWORD"]
.replace("\\", "\\\\")
.replace("'", "\\'")
)
HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'")
PORT = os.environ.get("AGE_POSTGRES_PORT", "8529")
self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
self._driver = AsyncConnectionPool(connection_string, open=False)
return None
async def close(self):
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
await self._driver.close()
@staticmethod
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
"""
Convert a record returned from an age query to a dictionary
Args:
record (): a record from an age query result
Returns:
Dict[str, Any]: a dictionary representation of the record where
the dictionary key is the field name and the value is the
value converted to a python type
"""
# result holder
d = {}
# prebuild a mapping of vertex_id to vertex mappings to be used
# later to build edges
vertices = {}
for k in record._fields:
v = getattr(record, k)
# agtype comes back '{key: value}::type' which must be parsed
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
# iterate returned fields and parse appropriately
for k in record._fields:
v = getattr(record, k)
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
else:
dtype = ""
if dtype == "vertex":
vertex = json.loads(v)
field = json.loads(v).get("properties")
if not field:
field = {}
field["label"] = AGEStorage._decode_graph_label(vertex["label"])
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
else:
d[k] = json.loads(v) if isinstance(v, str) else v
return d
@staticmethod
def _format_properties(
properties: Dict[str, Any], _id: Union[str, None] = None
) -> str:
"""
Convert a dictionary of properties to a string representation that
can be used in a cypher query insert/merge statement.
Args:
properties (Dict[str,str]): a dictionary containing node/edge properties
id (Union[str, None]): the id of the node or None if none exists
Returns:
str: the properties dictionary as a properly formatted string
"""
props = []
# wrap property key in backticks to escape
for k, v in properties.items():
prop = f"`{k}`: {json.dumps(v)}"
props.append(prop)
if _id is not None and "id" not in properties:
props.append(
f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
)
return "{" + ", ".join(props) + "}"
@staticmethod
def _encode_graph_label(label: str) -> str:
"""
Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
Args:
label (str): the original label
Returns:
str: the encoded label
"""
return "x" + label.encode().hex()
@staticmethod
def _decode_graph_label(encoded_label: str) -> str:
"""
Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
Args:
encoded_label (str): the encoded label
Returns:
str: the decoded label
"""
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
@staticmethod
def _get_col_name(field: str, idx: int) -> str:
"""
Convert a cypher return field to a pgsql select field
If possible keep the cypher column name, but create a generic name if necessary
Args:
field (str): a return field from a cypher query to be formatted for pgsql
idx (int): the position of the field in the return statement
Returns:
str: the field to be used in the pgsql select statement
"""
# remove white space
field = field.strip()
# if an alias is provided for the field, use it
if " as " in field:
return field.split(" as ")[-1].strip()
# if the return value is an unnamed primitive, give it a generic name
if field.isnumeric() or field in ("true", "false", "null"):
return f"column_{idx}"
# otherwise return the value stripping out some common special chars
return field.replace("(", "_").replace(")", "")
@staticmethod
def _wrap_query(query: str, graph_name: str, **params: str) -> str:
"""
Convert a cypher query to an Apache Age compatible
sql query by wrapping the cypher query in ag_catalog.cypher,
casting results to agtype and building a select statement
Args:
query (str): a valid cypher query
graph_name (str): the name of the graph to query
params (dict): parameters for the query
Returns:
str: an equivalent pgsql query
"""
# pgsql template
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
{query}
$$) AS ({fields});"""
# if there are any returned fields they must be added to the pgsql query
if "return" in query.lower():
# parse return statement to identify returned fields
fields = (
query.lower()
.split("return")[-1]
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)
# raise exception if RETURN * is found as we can't resolve the fields
if "*" in [x.strip() for x in fields]:
raise ValueError(
"AGE graph does not support 'RETURN *'"
+ " statements in Cypher queries"
)
# get pgsql formatted field names
fields = [
AGEStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
]
# build resulting pgsql relation
fields_str = ", ".join(
[field.split(".")[-1] + " agtype" for field in fields]
)
# if no return statement we still need to return a single field of type agtype
else:
fields_str = "a agtype"
select_str = "*"
return template.format(
graph_name=graph_name,
query=query.format(**params),
fields=fields_str,
projection=select_str,
)
async def _query(self, query: str, **params: str) -> List[Dict[str, Any]]:
"""
Query the graph by taking a cypher query, converting it to an
age compatible query, executing it and converting the result
Args:
query (str): a cypher query to be executed
params (dict): parameters for the query
Returns:
List[Dict[str, Any]]: a list of dictionaries containing the result set
"""
# convert cypher query to pgsql/age query
wrapped_query = self._wrap_query(query, self.graph_name, **params)
await self._driver.open()
# create graph if it doesn't exist
async with self._get_pool_connection() as conn:
async with conn.cursor() as curs:
try:
await curs.execute('SET search_path = ag_catalog, "$user", public')
await curs.execute(f"SELECT create_graph('{self.graph_name}')")
await conn.commit()
except (
psycopg.errors.InvalidSchemaName,
psycopg.errors.UniqueViolation,
):
await conn.rollback()
# execute the query, rolling back on an error
async with self._get_pool_connection() as conn:
async with conn.cursor(row_factory=namedtuple_row) as curs:
try:
await curs.execute('SET search_path = ag_catalog, "$user", public')
await curs.execute(wrapped_query)
await conn.commit()
except psycopg.Error as e:
await conn.rollback()
raise AGEQueryException(
{
"message": f"Error executing graph query: {query.format(**params)}",
"detail": str(e),
}
) from e
data = await curs.fetchall()
if data is None:
result = []
# decode records
else:
result = [AGEStorage._record_to_dict(d) for d in data]
return result
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
single_result = (await self._query(query, **params))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
single_result["node_exists"],
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
query = """
MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
RETURN COUNT(r) > 0 AS edge_exists
"""
params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
}
single_result = (await self._query(query, **params))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
single_result["edge_exists"],
)
return single_result["edge_exists"]
async def get_node(self, node_id: str) -> dict[str, str] | None:
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`) RETURN n
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
record = await self._query(query, **params)
if record:
node = record[0]
node_dict = node["n"]
logger.debug(
"{%s}: query: {%s}, result: {%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
node_dict,
)
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`)-[]->(x)
RETURN count(x) AS total_edge_count
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
record = (await self._query(query, **params))[0]
if record:
edge_count = int(record["total_edge_count"])
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
edge_count,
)
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
"{%s}:query:src_Degree+trg_degree:result:{%s}",
inspect.currentframe().f_code.co_name,
degrees,
)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
query = """
MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
RETURN properties(r) as edge_properties
LIMIT 1
"""
params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
}
record = await self._query(query, **params)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
result,
)
return result
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
"""
node_label = source_node_id.strip('"')
query = """
MATCH (n:`{label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected
"""
params = {"label": AGEStorage._encode_graph_label(node_label)}
results = await self._query(query, **params)
edges = []
for record in results:
source_node = record["n"] if record["n"] else None
connected_node = record["connected"] if record["connected"] else None
source_label = (
source_node["label"] if source_node and source_node["label"] else None
)
target_label = (
connected_node["label"]
if connected_node and connected_node["label"]
else None
)
if source_label and target_label:
edges.append((source_label, target_label))
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the AGE database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('"')
properties = node_data
query = """
MERGE (n:`{label}`)
SET n += {properties}
"""
params = {
"label": AGEStorage._encode_graph_label(label),
"properties": AGEStorage._format_properties(properties),
}
try:
await self._query(query, **params)
logger.debug(
"Upserted node with label '{%s}' and properties: {%s}",
label,
properties,
)
except Exception as e:
logger.error("Error during upsert: {%s}", e)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels.
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
edge_properties = edge_data
query = """
MATCH (source:`{src_label}`)
WITH source
MATCH (target:`{tgt_label}`)
MERGE (source)-[r:DIRECTED]->(target)
SET r += {properties}
RETURN r
"""
params = {
"src_label": AGEStorage._encode_graph_label(source_node_label),
"tgt_label": AGEStorage._encode_graph_label(target_node_label),
"properties": AGEStorage._format_properties(edge_properties),
}
try:
await self._query(query, **params)
logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
source_node_label,
target_node_label,
edge_properties,
)
except Exception as e:
logger.error("Error during edge upsert: {%s}", e)
raise
@asynccontextmanager
async def _get_pool_connection(self, timeout: Optional[float] = None):
"""Workaround for a psycopg_pool bug"""
try:
connection = await self._driver.getconn(timeout=timeout)
except PoolTimeout:
await self._driver._add_connection(None) # workaround...
connection = await self._driver.getconn(timeout=timeout)
try:
async with connection:
yield connection
finally:
await self._driver.putconn(connection)
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified label
Args:
node_id: The label of the node to delete
"""
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`)
DETACH DELETE n
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
try:
await self._query(query, **params)
logger.debug(f"Deleted node with label '{entity_name_label}'")
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
raise
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node labels to be deleted
"""
for node in nodes:
await self.delete_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
entity_name_label_source = source.strip('"')
entity_name_label_target = target.strip('"')
query = """
MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`)
DELETE r
"""
params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
}
try:
await self._query(query, **params)
logger.debug(
f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'"
)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def get_all_labels(self) -> list[str]:
"""Get all node labels in the database
Returns:
["label1", "label2", ...] # Alphabetically sorted label list
"""
query = """
MATCH (n)
RETURN DISTINCT labels(n) AS node_labels
"""
results = await self._query(query)
all_labels = []
for record in results:
if record and "node_labels" in record:
for label in record["node_labels"]:
if label:
# Decode label
decoded_label = AGEStorage._decode_graph_label(label)
all_labels.append(decoded_label)
# Remove duplicates and sort
return sorted(list(set(all_labels)))
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified 'node_label'.
Maximum number of nodes is constrained by the environment variable 'MAX_GRAPH_NODES' (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence (nodes containing the specified label string)
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args:
node_label: String to match in node labels (will match any node containing this string in its label)
max_depth: Maximum depth of the graph. Defaults to 5.
Returns:
KnowledgeGraph: Complete connected subgraph for specified node
"""
max_graph_nodes = int(os.getenv("MAX_GRAPH_NODES", 1000))
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
# Handle special case for "*" label
if node_label == "*":
# Query all nodes and sort by degree
query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, count(r) AS degree
ORDER BY degree DESC
LIMIT {max_nodes}
RETURN n, degree
"""
params = {"max_nodes": max_graph_nodes}
nodes_result = await self._query(query, **params)
# Add nodes to result
node_ids = []
for record in nodes_result:
if "n" in record:
node = record["n"]
node_id = str(node.get("id", ""))
if node_id not in seen_nodes:
node_properties = {k: v for k, v in node.items()}
node_label = node.get("label", "")
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_label],
properties=node_properties,
)
)
seen_nodes.add(node_id)
node_ids.append(node_id)
# Query edges between these nodes
if node_ids:
edges_query = """
MATCH (a)-[r]->(b)
WHERE a.id IN {node_ids} AND b.id IN {node_ids}
RETURN a, r, b
"""
edges_params = {"node_ids": node_ids}
edges_result = await self._query(edges_query, **edges_params)
# Add edges to result
for record in edges_result:
if "r" in record and "a" in record and "b" in record:
source = record["a"].get("id", "")
target = record["b"].get("id", "")
edge_id = f"{source}-{target}"
if edge_id not in seen_edges:
edge_properties = {k: v for k, v in record["r"].items()}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=source,
target=target,
properties=edge_properties,
)
)
seen_edges.add(edge_id)
else:
# For specific label, use partial matching
entity_name_label = node_label.strip('"')
encoded_label = AGEStorage._encode_graph_label(entity_name_label)
# Find matching start nodes
start_query = """
MATCH (n:`{label}`)
RETURN n
"""
start_params = {"label": encoded_label}
start_nodes = await self._query(start_query, **start_params)
if not start_nodes:
logger.warning(f"No nodes found with label '{entity_name_label}'!")
return result
# Traverse graph from each start node
for start_node_record in start_nodes:
if "n" in start_node_record:
# Use BFS to traverse graph
query = """
MATCH (start:`{label}`)
CALL {
MATCH path = (start)-[*0..{max_depth}]->(n)
RETURN nodes(path) AS path_nodes, relationships(path) AS path_rels
}
RETURN DISTINCT path_nodes, path_rels
"""
params = {"label": encoded_label, "max_depth": max_depth}
results = await self._query(query, **params)
# Extract nodes and edges from results
for record in results:
if "path_nodes" in record:
# Process nodes
for node in record["path_nodes"]:
node_id = str(node.get("id", ""))
if (
node_id not in seen_nodes
and len(seen_nodes) < max_graph_nodes
):
node_properties = {k: v for k, v in node.items()}
node_label = node.get("label", "")
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_label],
properties=node_properties,
)
)
seen_nodes.add(node_id)
if "path_rels" in record:
# Process edges
for rel in record["path_rels"]:
source = str(rel.get("start_id", ""))
target = str(rel.get("end_id", ""))
edge_id = f"{source}-{target}"
if edge_id not in seen_edges:
edge_properties = {k: v for k, v in rel.items()}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=rel.get("label", "DIRECTED"),
source=source,
target=target,
properties=edge_properties,
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def index_done_callback(self) -> None:
# AGES handles persistence automatically
pass
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all nodes and relationships in the graph.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
query = """
MATCH (n)
DETACH DELETE n
"""
await self._query(query)
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.graph_name}: {e}")
return {"status": "error", "message": str(e)}

View file

@ -1,686 +0,0 @@
import asyncio
import inspect
import json
import os
import pipmaster as pm
from dataclasses import dataclass
from typing import Any, Dict, List, final
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from ..base import BaseGraphStorage
if not pm.is_installed("gremlinpython"):
pm.install("gremlinpython")
from gremlin_python.driver import client, serializer # type: ignore
from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
from gremlin_python.driver.protocol import GremlinServerError # type: ignore
@final
@dataclass
class GremlinStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with Gremlin in production")
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
self._driver_lock = asyncio.Lock()
USER = os.environ.get("GREMLIN_USER", "")
PASSWORD = os.environ.get("GREMLIN_PASSWORD", "")
HOST = os.environ["GREMLIN_HOST"]
PORT = int(os.environ["GREMLIN_PORT"])
# TraversalSource, a custom one has to be created manually,
# default it "g"
SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g")
# All vertices will have graph={GRAPH} property, so that we can
# have several logical graphs for one source
GRAPH = GremlinStorage._to_value_map(
os.environ.get("GREMLIN_GRAPH", "LightRAG")
)
self.graph_name = GRAPH
self._driver = client.Client(
f"ws://{HOST}:{PORT}/gremlin",
SOURCE,
username=USER,
password=PASSWORD,
message_serializer=serializer.GraphSONSerializersV3d0(),
transport_factory=lambda: AiohttpTransport(call_from_event_loop=True),
)
async def close(self):
if self._driver:
self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
self._driver.close()
async def index_done_callback(self) -> None:
# Gremlin handles persistence automatically
pass
@staticmethod
def _to_value_map(value: Any) -> str:
"""Dump supported Python object as Gremlin valueMap"""
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
parsed_str = json_str.replace("'", r"\'")
# walk over the string and replace curly brackets with square brackets
# outside of strings, as well as replace double quotes with single quotes
# and "deescape" double quotes inside of strings
outside_str = True
escaped = False
remove_indices = []
for i, c in enumerate(parsed_str):
if escaped:
# previous character was an "odd" backslash
escaped = False
if c == '"':
# we want to "deescape" double quotes: store indices to delete
remove_indices.insert(0, i - 1)
elif c == "\\":
escaped = True
elif c == '"':
outside_str = not outside_str
parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :]
elif c == "{" and outside_str:
parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :]
elif c == "}" and outside_str:
parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :]
for idx in remove_indices:
parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :]
return parsed_str
@staticmethod
def _convert_properties(properties: Dict[str, Any]) -> str:
"""Create chained .property() commands from properties dict"""
props = []
for k, v in properties.items():
prop_name = GremlinStorage._to_value_map(k)
props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
return "".join(props)
@staticmethod
def _fix_name(name: str) -> str:
"""Strip double quotes and format as a proper field name"""
name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))
return name
async def _query(self, query: str) -> List[Dict[str, Any]]:
"""
Query the Gremlin graph
Args:
query (str): a query to be executed
Returns:
List[Dict[str, Any]]: a list of dictionaries containing the result set
"""
result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
if result:
result = result[0]
return result
async def has_node(self, node_id: str) -> bool:
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.limit(1)
.count()
.project('has_node')
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
"""
result = await self._query(query)
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
result[0]["has_node"],
)
return result[0]["has_node"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.outE()
.inV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target})
.limit(1)
.count()
.project('has_edge')
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
"""
result = await self._query(query)
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
result[0]["has_edge"],
)
return result[0]["has_edge"]
async def get_node(self, node_id: str) -> dict[str, str] | None:
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.limit(1)
.project('properties')
.by(elementMap())
"""
result = await self._query(query)
if result:
node = result[0]
node_dict = node["properties"]
logger.debug(
"{%s}: query: {%s}, result: {%s}",
inspect.currentframe().f_code.co_name,
query.format,
node_dict,
)
return node_dict
async def node_degree(self, node_id: str) -> int:
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.outE()
.inV().has('graph', {self.graph_name})
.count()
.project('total_edge_count')
.by()
"""
result = await self._query(query)
edge_count = result[0]["total_edge_count"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
edge_count,
)
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
"{%s}:query:src_Degree+trg_degree:result:{%s}",
inspect.currentframe().f_code.co_name,
degrees,
)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.outE()
.inV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target})
.limit(1)
.project('edge_properties')
.by(__.bothE().elementMap())
"""
result = await self._query(query)
if result:
edge_properties = result[0]["edge_properties"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
edge_properties,
)
return edge_properties
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
node_name = GremlinStorage._fix_name(source_node_id)
query = f"""g
.E()
.filter(
__.or(
__.outV().has('graph', {self.graph_name})
.has('entity_name', {node_name}),
__.inV().has('graph', {self.graph_name})
.has('entity_name', {node_name})
)
)
.project('source_name', 'target_name')
.by(__.outV().values('entity_name'))
.by(__.inV().values('entity_name'))
"""
result = await self._query(query)
edges = [(res["source_name"], res["target_name"]) for res in result]
return edges
@retry(
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the Gremlin graph.
Args:
node_id: The unique identifier for the node (used as name)
node_data: Dictionary of node properties
"""
name = GremlinStorage._fix_name(node_id)
properties = GremlinStorage._convert_properties(node_data)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {name})
.fold()
.coalesce(
__.unfold(),
__.addV('ENTITY')
.property('graph', {self.graph_name})
.property('entity_name', {name})
)
{properties}
"""
try:
await self._query(query)
logger.debug(
"Upserted node with name {%s} and properties: {%s}",
name,
properties,
)
except Exception as e:
logger.error("Error during upsert: {%s}", e)
raise
@retry(
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their names.
Args:
source_node_id (str): Name of the source node (used as identifier)
target_node_id (str): Name of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_name = GremlinStorage._fix_name(source_node_id)
target_node_name = GremlinStorage._fix_name(target_node_id)
edge_properties = GremlinStorage._convert_properties(edge_data)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {source_node_name}).as('source')
.V().has('graph', {self.graph_name})
.has('entity_name', {target_node_name}).as('target')
.coalesce(
__.select('source').outE('DIRECTED').where(__.inV().as('target')),
__.select('source').addE('DIRECTED').to(__.select('target'))
)
.property('graph', {self.graph_name})
{edge_properties}
"""
try:
await self._query(query)
logger.debug(
"Upserted edge from {%s} to {%s} with properties: {%s}",
source_node_name,
target_node_name,
edge_properties,
)
except Exception as e:
logger.error("Error during edge upsert: {%s}", e)
raise
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified entity_name
Args:
node_id: The entity_name of the node to delete
"""
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.drop()
"""
try:
await self._query(query)
logger.debug(
"{%s}: Deleted node with entity_name '%s'",
inspect.currentframe().f_code.co_name,
entity_name,
)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
raise
async def get_all_labels(self) -> list[str]:
"""
Get all node entity_names in the graph
Returns:
[entity_name1, entity_name2, ...] # Alphabetically sorted entity_name list
"""
query = f"""g
.V().has('graph', {self.graph_name})
.values('entity_name')
.dedup()
.order()
"""
try:
result = await self._query(query)
labels = result if result else []
logger.debug(
"{%s}: Retrieved %d labels",
inspect.currentframe().f_code.co_name,
len(labels),
)
return labels
except Exception as e:
logger.error(f"Error retrieving labels: {str(e)}")
return []
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
Args:
node_label: Entity name of the starting node
max_depth: Maximum depth of the subgraph
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
entity_name = GremlinStorage._fix_name(node_label)
# Handle special case for "*" label
if node_label == "*":
# For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES)
query = f"""g
.V().has('graph', {self.graph_name})
.limit({MAX_GRAPH_NODES})
.elementMap()
"""
nodes_result = await self._query(query)
# Add nodes to result
for node_data in nodes_result:
node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes:
continue
# Create node with properties
node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append(
KnowledgeGraphNode(
id=str(node_id),
labels=[str(node_id)],
properties=node_properties,
)
)
seen_nodes.add(str(node_id))
# Get and add edges
if nodes_result:
query = f"""g
.V().has('graph', {self.graph_name})
.limit({MAX_GRAPH_NODES})
.outE()
.inV().has('graph', {self.graph_name})
.limit({MAX_GRAPH_NODES})
.path()
.by(elementMap())
.by(elementMap())
.by(elementMap())
"""
edges_result = await self._query(query)
for path in edges_result:
if len(path) >= 3: # source -> edge -> target
source = path[0]
edge_data = path[1]
target = path[2]
source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges:
continue
# Create edge with properties
edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(source_id),
target=str(target_id),
properties=edge_properties,
)
)
seen_edges.add(edge_id)
else:
# Search for specific node and get its neighborhood
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.repeat(__.both().simplePath().dedup())
.times({max_depth})
.emit()
.dedup()
.limit({MAX_GRAPH_NODES})
.elementMap()
"""
nodes_result = await self._query(query)
# Add nodes to result
for node_data in nodes_result:
node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes:
continue
# Create node with properties
node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append(
KnowledgeGraphNode(
id=str(node_id),
labels=[str(node_id)],
properties=node_properties,
)
)
seen_nodes.add(str(node_id))
# Get edges between the nodes in the result
if nodes_result:
node_ids = [
n.get("entity_name", str(n.get("id", ""))) for n in nodes_result
]
node_ids_query = ", ".join(
[GremlinStorage._to_value_map(nid) for nid in node_ids]
)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', within({node_ids_query}))
.outE()
.where(inV().has('graph', {self.graph_name})
.has('entity_name', within({node_ids_query})))
.path()
.by(elementMap())
.by(elementMap())
.by(elementMap())
"""
edges_result = await self._query(query)
for path in edges_result:
if len(path) >= 3: # source -> edge -> target
source = path[0]
edge_data = path[1]
target = path[2]
source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges:
continue
# Create edge with properties
edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(source_id),
target=str(target_id),
properties=edge_properties,
)
)
seen_edges.add(edge_id)
logger.info(
"Subgraph query successful | Node count: %d | Edge count: %d",
len(result.nodes),
len(result.edges),
)
return result
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node entity_names to be deleted
"""
for node in nodes:
await self.delete_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
entity_name_source = GremlinStorage._fix_name(source)
entity_name_target = GremlinStorage._fix_name(target)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.outE()
.where(inV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target}))
.drop()
"""
try:
await self._query(query)
logger.debug(
"{%s}: Deleted edge from '%s' to '%s'",
inspect.currentframe().f_code.co_name,
entity_name_source,
entity_name_target,
)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all nodes and relationships in the graph.
This function deletes all nodes with the specified graph name property,
which automatically removes all associated edges.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
query = f"""g
.V().has('graph', {self.graph_name})
.drop()
"""
await self._query(query)
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.graph_name}: {e}")
return {"status": "error", "message": str(e)}

File diff suppressed because it is too large Load diff

View file

@ -195,96 +195,6 @@ class JsonKVStorage(BaseKVStorage):
if any_deleted: if any_deleted:
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.final_namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache drop successfully
False: if the cache drop failed
"""
if not modes:
return False
try:
async with self._storage_lock:
keys_to_delete = []
modes_set = set(modes) # Convert to set for efficient lookup
for key in list(self._data.keys()):
# Parse flattened cache key: mode:cache_type:hash
parts = key.split(":", 2)
if len(parts) == 3 and parts[0] in modes_set:
keys_to_delete.append(key)
# Batch delete
for key in keys_to_delete:
self._data.pop(key, None)
if keys_to_delete:
await set_all_update_flags(self.final_namespace)
logger.info(
f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
)
return True
except Exception as e:
logger.error(f"Error dropping cache by modes: {e}")
return False
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
# """Delete specific cache records from storage by chunk IDs
# Importance notes for in-memory storage:
# 1. Changes will be persisted to disk during the next index_done_callback
# 2. update flags to notify other processes that data persistence is needed
# Args:
# chunk_ids (list[str]): List of chunk IDs to be dropped from storage
# Returns:
# True: if the cache drop successfully
# False: if the cache drop failed
# """
# if not chunk_ids:
# return False
# try:
# async with self._storage_lock:
# # Iterate through all cache modes to find entries with matching chunk_ids
# for mode_key, mode_data in list(self._data.items()):
# if isinstance(mode_data, dict):
# # Check each cached entry in this mode
# for cache_key, cache_entry in list(mode_data.items()):
# if (
# isinstance(cache_entry, dict)
# and cache_entry.get("chunk_id") in chunk_ids
# ):
# # Remove this cache entry
# del mode_data[cache_key]
# logger.debug(
# f"Removed cache entry {cache_key} for chunk {cache_entry.get('chunk_id')}"
# )
# # If the mode is now empty, remove it entirely
# if not mode_data:
# del self._data[mode_key]
# # Set update flags to notify persistence is needed
# await set_all_update_flags(self.final_namespace)
# logger.info(f"Cleared cache for {len(chunk_ids)} chunk IDs")
# return True
# except Exception as e:
# logger.error(f"Error clearing cache by chunk IDs: {e}")
# return False
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources """Drop all data from storage and clean up resources
This action will persistent the data to disk immediately. This action will persistent the data to disk immediately.

View file

@ -232,28 +232,6 @@ class MongoKVStorage(BaseKVStorage):
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error deleting documents from {self.namespace}: {e}") logger.error(f"Error deleting documents from {self.namespace}: {e}")
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
bool: True if successful, False otherwise
"""
if not modes:
return False
try:
# Build regex pattern to match flattened key format: mode:cache_type:hash
pattern = f"^({'|'.join(modes)}):"
result = await self._data.delete_many({"_id": {"$regex": pattern}})
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
return True
except Exception as e:
logger.error(f"Error deleting cache by modes {modes}: {e}")
return False
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection. """Drop the storage by removing all documents in the collection.

View file

@ -224,15 +224,15 @@ class PostgreSQLDB:
): ):
pass pass
async def _migrate_llm_cache_add_columns(self): async def _migrate_llm_cache_schema(self):
"""Add chunk_id and cache_type columns to LIGHTRAG_LLM_CACHE table if they don't exist""" """Migrate LLM cache schema: add new columns and remove deprecated mode field"""
try: try:
# Check if both columns exist # Check if all columns exist
check_columns_sql = """ check_columns_sql = """
SELECT column_name SELECT column_name
FROM information_schema.columns FROM information_schema.columns
WHERE table_name = 'lightrag_llm_cache' WHERE table_name = 'lightrag_llm_cache'
AND column_name IN ('chunk_id', 'cache_type') AND column_name IN ('chunk_id', 'cache_type', 'queryparam', 'mode')
""" """
existing_columns = await self.query(check_columns_sql, multirows=True) existing_columns = await self.query(check_columns_sql, multirows=True)
@ -289,8 +289,58 @@ class PostgreSQLDB:
"cache_type column already exists in LIGHTRAG_LLM_CACHE table" "cache_type column already exists in LIGHTRAG_LLM_CACHE table"
) )
# Add missing queryparam column
if "queryparam" not in existing_column_names:
logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table")
add_queryparam_sql = """
ALTER TABLE LIGHTRAG_LLM_CACHE
ADD COLUMN queryparam JSONB NULL
"""
await self.execute(add_queryparam_sql)
logger.info(
"Successfully added queryparam column to LIGHTRAG_LLM_CACHE table"
)
else:
logger.info(
"queryparam column already exists in LIGHTRAG_LLM_CACHE table"
)
# Remove deprecated mode field if it exists
if "mode" in existing_column_names:
logger.info(
"Removing deprecated mode column from LIGHTRAG_LLM_CACHE table"
)
# First, drop the primary key constraint that includes mode
drop_pk_sql = """
ALTER TABLE LIGHTRAG_LLM_CACHE
DROP CONSTRAINT IF EXISTS LIGHTRAG_LLM_CACHE_PK
"""
await self.execute(drop_pk_sql)
logger.info("Dropped old primary key constraint")
# Drop the mode column
drop_mode_sql = """
ALTER TABLE LIGHTRAG_LLM_CACHE
DROP COLUMN mode
"""
await self.execute(drop_mode_sql)
logger.info(
"Successfully removed mode column from LIGHTRAG_LLM_CACHE table"
)
# Create new primary key constraint without mode
add_pk_sql = """
ALTER TABLE LIGHTRAG_LLM_CACHE
ADD CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
"""
await self.execute(add_pk_sql)
logger.info("Created new primary key constraint (workspace, id)")
else:
logger.info("mode column does not exist in LIGHTRAG_LLM_CACHE table")
except Exception as e: except Exception as e:
logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}") logger.warning(f"Failed to migrate LLM cache schema: {e}")
async def _migrate_timestamp_columns(self): async def _migrate_timestamp_columns(self):
"""Migrate timestamp columns in tables to witimezone-free types, assuming original data is in UTC time""" """Migrate timestamp columns in tables to witimezone-free types, assuming original data is in UTC time"""
@ -856,11 +906,11 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL, Failed to migrate timestamp columns: {e}") logger.error(f"PostgreSQL, Failed to migrate timestamp columns: {e}")
# Don't throw an exception, allow the initialization process to continue # Don't throw an exception, allow the initialization process to continue
# Migrate LLM cache table to add chunk_id and cache_type columns if needed # Migrate LLM cache schema: add new columns and remove deprecated mode field
try: try:
await self._migrate_llm_cache_add_columns() await self._migrate_llm_cache_schema()
except Exception as e: except Exception as e:
logger.error(f"PostgreSQL, Failed to migrate LLM cache columns: {e}") logger.error(f"PostgreSQL, Failed to migrate LLM cache schema: {e}")
# Don't throw an exception, allow the initialization process to continue # Don't throw an exception, allow the initialization process to continue
# Finally, attempt to migrate old doc chunks data if needed # Finally, attempt to migrate old doc chunks data if needed
@ -1379,14 +1429,21 @@ class PGKVStorage(BaseKVStorage):
): ):
create_time = response.get("create_time", 0) create_time = response.get("create_time", 0)
update_time = response.get("update_time", 0) update_time = response.get("update_time", 0)
# Map field names and add cache_type for compatibility # Parse queryparam JSON string back to dict
queryparam = response.get("queryparam")
if isinstance(queryparam, str):
try:
queryparam = json.loads(queryparam)
except json.JSONDecodeError:
queryparam = None
# Map field names for compatibility (mode field removed)
response = { response = {
**response, **response,
"return": response.get("return_value", ""), "return": response.get("return_value", ""),
"cache_type": response.get("cache_type"), "cache_type": response.get("cache_type"),
"original_prompt": response.get("original_prompt", ""), "original_prompt": response.get("original_prompt", ""),
"chunk_id": response.get("chunk_id"), "chunk_id": response.get("chunk_id"),
"mode": response.get("mode", "default"), "queryparam": queryparam,
"create_time": create_time, "create_time": create_time,
"update_time": create_time if update_time == 0 else update_time, "update_time": create_time if update_time == 0 else update_time,
} }
@ -1455,14 +1512,21 @@ class PGKVStorage(BaseKVStorage):
for row in results: for row in results:
create_time = row.get("create_time", 0) create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0) update_time = row.get("update_time", 0)
# Map field names and add cache_type for compatibility # Parse queryparam JSON string back to dict
queryparam = row.get("queryparam")
if isinstance(queryparam, str):
try:
queryparam = json.loads(queryparam)
except json.JSONDecodeError:
queryparam = None
# Map field names for compatibility (mode field removed)
processed_row = { processed_row = {
**row, **row,
"return": row.get("return_value", ""), "return": row.get("return_value", ""),
"cache_type": row.get("cache_type"), "cache_type": row.get("cache_type"),
"original_prompt": row.get("original_prompt", ""), "original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"), "chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"), "queryparam": queryparam,
"create_time": create_time, "create_time": create_time,
"update_time": create_time if update_time == 0 else update_time, "update_time": create_time if update_time == 0 else update_time,
} }
@ -1565,11 +1629,13 @@ class PGKVStorage(BaseKVStorage):
"id": k, # Use flattened key as id "id": k, # Use flattened key as id
"original_prompt": v["original_prompt"], "original_prompt": v["original_prompt"],
"return_value": v["return"], "return_value": v["return"],
"mode": v.get("mode", "default"), # Get mode from data
"chunk_id": v.get("chunk_id"), "chunk_id": v.get("chunk_id"),
"cache_type": v.get( "cache_type": v.get(
"cache_type", "extract" "cache_type", "extract"
), # Get cache_type from data ), # Get cache_type from data
"queryparam": json.dumps(v.get("queryparam"))
if v.get("queryparam")
else None,
} }
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
@ -1635,39 +1701,6 @@ class PGKVStorage(BaseKVStorage):
except Exception as e: except Exception as e:
logger.error(f"Error while deleting records from {self.namespace}: {e}") logger.error(f"Error while deleting records from {self.namespace}: {e}")
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
bool: True if successful, False otherwise
"""
if not modes:
return False
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return False
if table_name != "LIGHTRAG_LLM_CACHE":
return False
sql = f"""
DELETE FROM {table_name}
WHERE workspace = $1 AND mode = ANY($2)
"""
params = {"workspace": self.db.workspace, "modes": modes}
logger.info(f"Deleting cache by modes: {modes}")
await self.db.execute(sql, params)
return True
except Exception as e:
logger.error(f"Error deleting cache by modes {modes}: {e}")
return False
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
try: try:
@ -4049,14 +4082,14 @@ TABLES = {
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
workspace varchar(255) NOT NULL, workspace varchar(255) NOT NULL,
id varchar(255) NOT NULL, id varchar(255) NOT NULL,
mode varchar(32) NOT NULL,
original_prompt TEXT, original_prompt TEXT,
return_value TEXT, return_value TEXT,
chunk_id VARCHAR(255) NULL, chunk_id VARCHAR(255) NULL,
cache_type VARCHAR(32), cache_type VARCHAR(32),
queryparam JSONB NULL,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
)""" )"""
}, },
"LIGHTRAG_DOC_STATUS": { "LIGHTRAG_DOC_STATUS": {
@ -4114,14 +4147,11 @@ SQL_TEMPLATES = {
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""", """,
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
""", """,
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
""",
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
""", """,
@ -4132,7 +4162,7 @@ SQL_TEMPLATES = {
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""", """,
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
@ -4163,14 +4193,14 @@ SQL_TEMPLATES = {
ON CONFLICT (workspace,id) DO UPDATE ON CONFLICT (workspace,id) DO UPDATE
SET content = $2, update_time = CURRENT_TIMESTAMP SET content = $2, update_time = CURRENT_TIMESTAMP
""", """,
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type) "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,chunk_id,cache_type,queryparam)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (workspace,mode,id) DO UPDATE ON CONFLICT (workspace,id) DO UPDATE
SET original_prompt = EXCLUDED.original_prompt, SET original_prompt = EXCLUDED.original_prompt,
return_value=EXCLUDED.return_value, return_value=EXCLUDED.return_value,
mode=EXCLUDED.mode,
chunk_id=EXCLUDED.chunk_id, chunk_id=EXCLUDED.chunk_id,
cache_type=EXCLUDED.cache_type, cache_type=EXCLUDED.cache_type,
queryparam=EXCLUDED.queryparam,
update_time = CURRENT_TIMESTAMP update_time = CURRENT_TIMESTAMP
""", """,
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,

View file

@ -397,66 +397,6 @@ class RedisKVStorage(BaseKVStorage):
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
) )
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Importance notes for Redis storage:
1. This will immediately delete the specified cache modes from Redis
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache drop successfully
False: if the cache drop failed
"""
if not modes:
return False
try:
async with self._get_redis_connection() as redis:
keys_to_delete = []
# Find matching keys for each mode using SCAN
for mode in modes:
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
pattern = f"{self.namespace}:{mode}:*"
cursor = 0
mode_keys = []
while True:
cursor, keys = await redis.scan(
cursor, match=pattern, count=1000
)
if keys:
mode_keys.extend(keys)
if cursor == 0:
break
keys_to_delete.extend(mode_keys)
logger.info(
f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
)
if keys_to_delete:
# Batch delete
pipe = redis.pipeline()
for key in keys_to_delete:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Dropped {deleted_count} cache entries for modes: {modes}"
)
else:
logger.warning(f"No cache entries found for modes: {modes}")
return True
except Exception as e:
logger.error(f"Error dropping cache by modes in Redis: {e}")
return False
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all keys under the current namespace. """Drop the storage by removing all keys under the current namespace.

View file

@ -1915,58 +1915,35 @@ class LightRAG:
async def _query_done(self): async def _query_done(self):
await self.llm_response_cache.index_done_callback() await self.llm_response_cache.index_done_callback()
async def aclear_cache(self, modes: list[str] | None = None) -> None: async def aclear_cache(self) -> None:
"""Clear cache data from the LLM response cache storage. """Clear all cache data from the LLM response cache storage.
Args: This method clears all cached LLM responses regardless of mode.
modes (list[str] | None): Modes of cache to clear. Options: ["default", "naive", "local", "global", "hybrid", "mix"].
"default" represents extraction cache.
If None, clears all cache.
Example: Example:
# Clear all cache # Clear all cache
await rag.aclear_cache() await rag.aclear_cache()
# Clear local mode cache
await rag.aclear_cache(modes=["local"])
# Clear extraction cache
await rag.aclear_cache(modes=["default"])
""" """
if not self.llm_response_cache: if not self.llm_response_cache:
logger.warning("No cache storage configured") logger.warning("No cache storage configured")
return return
valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
# Validate input
if modes and not all(mode in valid_modes for mode in modes):
raise ValueError(f"Invalid mode. Valid modes are: {valid_modes}")
try: try:
# Reset the cache storage for specified mode # Clear all cache using drop method
if modes: success = await self.llm_response_cache.drop()
success = await self.llm_response_cache.drop_cache_by_modes(modes) if success:
if success: logger.info("Cleared all cache")
logger.info(f"Cleared cache for modes: {modes}")
else:
logger.warning(f"Failed to clear cache for modes: {modes}")
else: else:
# Clear all modes logger.warning("Failed to clear all cache")
success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
if success:
logger.info("Cleared all cache")
else:
logger.warning("Failed to clear all cache")
await self.llm_response_cache.index_done_callback() await self.llm_response_cache.index_done_callback()
except Exception as e: except Exception as e:
logger.error(f"Error while clearing cache: {e}") logger.error(f"Error while clearing cache: {e}")
def clear_cache(self, modes: list[str] | None = None) -> None: def clear_cache(self) -> None:
"""Synchronous version of aclear_cache.""" """Synchronous version of aclear_cache."""
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes)) return always_get_an_event_loop().run_until_complete(self.aclear_cache())
async def get_docs_by_status( async def get_docs_by_status(
self, status: DocStatus self, status: DocStatus

View file

@ -1727,7 +1727,20 @@ async def kg_query(
use_model_func = partial(use_model_func, _priority=5) use_model_func = partial(use_model_func, _priority=5)
# Handle cache # Handle cache
args_hash = compute_args_hash(query_param.mode, query) args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
) )
@ -1826,18 +1839,29 @@ async def kg_query(
) )
if hashing_kv.global_config.get("enable_llm_cache"): if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache # Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache( await save_to_cache(
hashing_kv, hashing_kv,
CacheData( CacheData(
args_hash=args_hash, args_hash=args_hash,
content=response, content=response,
prompt=query, prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode, mode=query_param.mode,
cache_type="query", cache_type="query",
queryparam=queryparam_dict,
), ),
) )
@ -1889,7 +1913,20 @@ async def extract_keywords_only(
""" """
# 1. Handle cache if needed - add cache type for keywords # 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text) args_hash = compute_args_hash(
param.mode,
text,
param.response_type,
param.top_k,
param.chunk_top_k,
param.max_entity_tokens,
param.max_relation_tokens,
param.max_total_tokens,
param.hl_keywords or [],
param.ll_keywords or [],
param.user_prompt or "",
param.enable_rerank,
)
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode, cache_type="keywords" hashing_kv, args_hash, text, param.mode, cache_type="keywords"
) )
@ -1966,17 +2003,29 @@ async def extract_keywords_only(
"low_level_keywords": ll_keywords, "low_level_keywords": ll_keywords,
} }
if hashing_kv.global_config.get("enable_llm_cache"): if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache with query parameters
queryparam_dict = {
"mode": param.mode,
"response_type": param.response_type,
"top_k": param.top_k,
"chunk_top_k": param.chunk_top_k,
"max_entity_tokens": param.max_entity_tokens,
"max_relation_tokens": param.max_relation_tokens,
"max_total_tokens": param.max_total_tokens,
"hl_keywords": param.hl_keywords or [],
"ll_keywords": param.ll_keywords or [],
"user_prompt": param.user_prompt or "",
"enable_rerank": param.enable_rerank,
}
await save_to_cache( await save_to_cache(
hashing_kv, hashing_kv,
CacheData( CacheData(
args_hash=args_hash, args_hash=args_hash,
content=json.dumps(cache_data), content=json.dumps(cache_data),
prompt=text, prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode, mode=param.mode,
cache_type="keywords", cache_type="keywords",
queryparam=queryparam_dict,
), ),
) )
@ -2951,7 +3000,20 @@ async def naive_query(
use_model_func = partial(use_model_func, _priority=5) use_model_func = partial(use_model_func, _priority=5)
# Handle cache # Handle cache
args_hash = compute_args_hash(query_param.mode, query) args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
) )
@ -3098,18 +3160,29 @@ async def naive_query(
) )
if hashing_kv.global_config.get("enable_llm_cache"): if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache # Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache( await save_to_cache(
hashing_kv, hashing_kv,
CacheData( CacheData(
args_hash=args_hash, args_hash=args_hash,
content=response, content=response,
prompt=query, prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode, mode=query_param.mode,
cache_type="query", cache_type="query",
queryparam=queryparam_dict,
), ),
) )
@ -3231,9 +3304,6 @@ async def kg_query_with_keywords(
args_hash=args_hash, args_hash=args_hash,
content=response, content=response,
prompt=query, prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode, mode=query_param.mode,
cache_type="query", cache_type="query",
), ),

View file

@ -756,40 +756,6 @@ def cosine_similarity(v1, v2):
return dot_product / (norm1 * norm2) return dot_product / (norm1 * norm2)
def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
"""Quantize embedding to specified bits"""
# Convert list to numpy array if needed
if isinstance(embedding, list):
embedding = np.array(embedding)
# Calculate min/max values for reconstruction
min_val = embedding.min()
max_val = embedding.max()
if min_val == max_val:
# handle constant vector
quantized = np.zeros_like(embedding, dtype=np.uint8)
return quantized, min_val, max_val
# Quantize to 0-255 range
scale = (2**bits - 1) / (max_val - min_val)
quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
return quantized, min_val, max_val
def dequantize_embedding(
quantized: np.ndarray, min_val: float, max_val: float, bits=8
) -> np.ndarray:
"""Restore quantized embedding"""
if min_val == max_val:
# handle constant vector
return np.full_like(quantized, min_val, dtype=np.float32)
scale = (max_val - min_val) / (2**bits - 1)
return (quantized * scale + min_val).astype(np.float32)
async def handle_cache( async def handle_cache(
hashing_kv, hashing_kv,
args_hash, args_hash,
@ -824,12 +790,10 @@ class CacheData:
args_hash: str args_hash: str
content: str content: str
prompt: str prompt: str
quantized: np.ndarray | None = None
min_val: float | None = None
max_val: float | None = None
mode: str = "default" mode: str = "default"
cache_type: str = "query" cache_type: str = "query"
chunk_id: str | None = None chunk_id: str | None = None
queryparam: dict | None = None
async def save_to_cache(hashing_kv, cache_data: CacheData): async def save_to_cache(hashing_kv, cache_data: CacheData):
@ -866,15 +830,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
"return": cache_data.content, "return": cache_data.content,
"cache_type": cache_data.cache_type, "cache_type": cache_data.cache_type,
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,
"embedding_shape": cache_data.quantized.shape
if cache_data.quantized is not None
else None,
"embedding_min": cache_data.min_val,
"embedding_max": cache_data.max_val,
"original_prompt": cache_data.prompt, "original_prompt": cache_data.prompt,
"queryparam": cache_data.queryparam
if cache_data.queryparam is not None
else None,
} }
logger.info(f" == LLM cache == saving: {flattened_key}") logger.info(f" == LLM cache == saving: {flattened_key}")

View file

@ -586,11 +586,11 @@ export const clearDocuments = async (): Promise<DocActionResponse> => {
return response.data return response.data
} }
export const clearCache = async (modes?: string[]): Promise<{ export const clearCache = async (): Promise<{
status: 'success' | 'fail' status: 'success' | 'fail'
message: string message: string
}> => { }> => {
const response = await axiosInstance.post('/documents/clear_cache', { modes }) const response = await axiosInstance.post('/documents/clear_cache', {})
return response.data return response.data
} }