fix(postgres): allow vchordrq.epsilon config when probes is empty

Previously, configure_vchordrq failed silently when probes was empty
(the default), which also prevented epsilon from being configured. Each
parameter is now handled independently with its own conditional, and
configuration errors fail fast instead of being swallowed.

This fixes the documented epsilon setting, which was impossible to use under
the default configuration.
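
As a sketch of the behavior change (illustrative only; the "before" shape is
inferred from this message, not copied from the removed code):

    # Before (approximate): one guarded block, so an empty probes value
    # skipped epsilon too, and a broad except hid real configuration errors.
    if self.vchordrq_probes:
        try:
            await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'")
            await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}")
        except Exception:
            pass  # error swallowed

    # After: each parameter is checked and applied on its own, and any error
    # raised by SET propagates to the caller.
    if self.vchordrq_probes and str(self.vchordrq_probes).strip():
        await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'")
    if self.vchordrq_epsilon is not None:
        await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}")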

(cherry picked from commit 3096f844fb)
Author:    yangdx
Date:      2025-11-18 21:58:36 +08:00
Committer: Raphaël MANSUY
Parent:    5bd1320a1d
Commit:    0ac858d3e2

@@ -33,7 +33,6 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
import pipmaster as pm
@@ -78,6 +77,9 @@ class PostgreSQLDB:
self.hnsw_m = config.get("hnsw_m")
self.hnsw_ef = config.get("hnsw_ef")
self.ivfflat_lists = config.get("ivfflat_lists")
self.vchordrq_build_options = config.get("vchordrq_build_options")
self.vchordrq_probes = config.get("vchordrq_probes")
self.vchordrq_epsilon = config.get("vchordrq_epsilon")
# Server settings
self.server_settings = config.get("server_settings")
@@ -85,24 +87,11 @@ class PostgreSQLDB:
# Statement LRU cache size (keep as-is, allow None for optional configuration)
self.statement_cache_size = config.get("statement_cache_size")
# Connection retry configuration
self.connection_retry_attempts = max(
1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3)))
)
self.connection_retry_backoff = max(
0.1,
min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))),
)
self.connection_retry_backoff_max = max(
self.connection_retry_backoff,
min(
60.0,
float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)),
),
)
self.pool_close_timeout = max(
1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0)))
)
if self.user is None or self.password is None or self.database is None:
raise ValueError("Missing database user, password, or database")
# Guard concurrent pool resets
self._pool_reconnect_lock = asyncio.Lock()
self._transient_exceptions = (
asyncio.TimeoutError,
@@ -117,12 +106,14 @@ class PostgreSQLDB:
asyncpg.exceptions.ConnectionFailureError,
)
# Guard concurrent pool resets
self._pool_reconnect_lock = asyncio.Lock()
if self.user is None or self.password is None or self.database is None:
raise ValueError("Missing database user, password, or database")
# Connection retry configuration
self.connection_retry_attempts = config["connection_retry_attempts"]
self.connection_retry_backoff = config["connection_retry_backoff"]
self.connection_retry_backoff_max = max(
self.connection_retry_backoff,
config["connection_retry_backoff_max"],
)
self.pool_close_timeout = config["pool_close_timeout"]
logger.info(
"PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
self.connection_retry_attempts,
@@ -215,9 +206,7 @@ class PostgreSQLDB:
# Only add statement_cache_size if it's configured
if self.statement_cache_size is not None:
connection_params["statement_cache_size"] = int(
self.statement_cache_size
)
connection_params["statement_cache_size"] = int(self.statement_cache_size)
logger.info(
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
)
@@ -376,7 +365,8 @@ class PostgreSQLDB:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
if self.vector_index_type == "VCHORDRQ":
await self.configure_vchordrq(connection)
return await operation(connection)
@staticmethod
@@ -422,6 +412,29 @@ class PostgreSQLDB:
):
pass
async def configure_vchordrq(self, connection: asyncpg.Connection) -> None:
"""Configure VCHORDRQ extension for vector similarity search.
Raises:
asyncpg.exceptions.UndefinedObjectError: If VCHORDRQ extension is not installed
asyncpg.exceptions.InvalidParameterValueError: If parameter value is invalid
Note:
This method does not catch exceptions. Configuration errors will fail-fast,
while transient connection errors will be retried by _run_with_retry.
"""
# Handle probes parameter - only set if non-empty value is provided
if self.vchordrq_probes and str(self.vchordrq_probes).strip():
await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'")
logger.debug(f"PostgreSQL, VCHORDRQ probes set to: {self.vchordrq_probes}")
# Handle epsilon parameter independently - check for None to allow 0.0 as valid value
if self.vchordrq_epsilon is not None:
await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}")
logger.debug(
f"PostgreSQL, VCHORDRQ epsilon set to: {self.vchordrq_epsilon}"
)
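
PostgreSQL's SET statement cannot take bind parameters, which is why the
values above are interpolated into the SQL text; they come from server
configuration rather than user input. If parameterization is preferred, the
built-in set_config(name, value, is_local) function accepts arguments (a
hedged alternative sketch, not what this commit does):

    # Equivalent per-session settings via set_config(); asyncpg passes the
    # value as a real query parameter instead of interpolating it.
    await connection.execute(
        "SELECT set_config('vchordrq.probes', $1, false)",
        str(self.vchordrq_probes),
    )
    await connection.execute(
        "SELECT set_config('vchordrq.epsilon', $1, false)",
        str(self.vchordrq_epsilon),
    )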
async def _migrate_llm_cache_schema(self):
"""Migrate LLM cache schema: add new columns and remove deprecated mode field"""
try:
@@ -1156,19 +1169,12 @@ class PostgreSQLDB:
f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}"
)
try:
if self.vector_index_type == "HNSW":
await self._create_hnsw_vector_indexes()
elif self.vector_index_type == "IVFFLAT":
await self._create_ivfflat_vector_indexes()
elif self.vector_index_type == "FLAT":
logger.warning(
"FLAT index type is not supported by pgvector. Skipping vector index creation. "
"Please use 'HNSW' or 'IVFFLAT' instead."
)
if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]:
await self._create_vector_indexes()
else:
logger.warning(
"Doesn't support this vector index type: {self.vector_index_type}. "
"Supported types: HNSW, IVFFLAT"
"Supported types: HNSW, IVFFLAT, VCHORDRQ"
)
except Exception as e:
logger.error(
@@ -1375,21 +1381,39 @@ class PostgreSQLDB:
except Exception as e:
logger.warning(f"Failed to create index {index['name']}: {e}")
async def _create_hnsw_vector_indexes(self):
async def _create_vector_indexes(self):
vdb_tables = [
"LIGHTRAG_VDB_CHUNKS",
"LIGHTRAG_VDB_ENTITY",
"LIGHTRAG_VDB_RELATION",
]
embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
create_sql = {
"HNSW": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING hnsw (content_vector vector_cosine_ops)
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
""",
"IVFFLAT": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING ivfflat (content_vector vector_cosine_ops)
WITH (lists = {self.ivfflat_lists})
""",
"VCHORDRQ": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING vchordrq (content_vector vector_cosine_ops)
{f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''}
""",
}
embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
for k in vdb_tables:
vector_index_name = f"idx_{k.lower()}_hnsw_cosine"
vector_index_name = (
f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
)
check_vector_index_sql = f"""
SELECT 1 FROM pg_indexes
WHERE indexname = '{vector_index_name}'
AND tablename = '{k.lower()}'
WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
"""
try:
vector_index_exists = await self.query(check_vector_index_sql)
@@ -1398,64 +1422,24 @@ class PostgreSQLDB:
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
await self.execute(alter_sql)
logger.debug(f"Ensured vector dimension for {k}")
create_vector_index_sql = f"""
CREATE INDEX {vector_index_name}
ON {k} USING hnsw (content_vector vector_cosine_ops)
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
"""
logger.info(f"Creating hnsw index {vector_index_name} on table {k}")
await self.execute(create_vector_index_sql)
logger.info(
f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
)
await self.execute(
create_sql[self.vector_index_type].format(
vector_index_name=vector_index_name, k=k
)
)
logger.info(
f"Successfully created vector index {vector_index_name} on table {k}"
)
else:
logger.info(
f"HNSW vector index {vector_index_name} already exists on table {k}"
f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
)
except Exception as e:
logger.error(f"Failed to create vector index on table {k}, Got: {e}")
async def _create_ivfflat_vector_indexes(self):
vdb_tables = [
"LIGHTRAG_VDB_CHUNKS",
"LIGHTRAG_VDB_ENTITY",
"LIGHTRAG_VDB_RELATION",
]
embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
for k in vdb_tables:
index_name = f"idx_{k.lower()}_ivfflat_cosine"
check_index_sql = f"""
SELECT 1 FROM pg_indexes
WHERE indexname = '{index_name}' AND tablename = '{k.lower()}'
"""
try:
exists = await self.query(check_index_sql)
if not exists:
# Only set vector dimension when index doesn't exist
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
await self.execute(alter_sql)
logger.debug(f"Ensured vector dimension for {k}")
create_sql = f"""
CREATE INDEX {index_name}
ON {k} USING ivfflat (content_vector vector_cosine_ops)
WITH (lists = {self.ivfflat_lists})
"""
logger.info(f"Creating ivfflat index {index_name} on table {k}")
await self.execute(create_sql)
logger.info(
f"Successfully created ivfflat index {index_name} on table {k}"
)
else:
logger.info(
f"Ivfflat vector index {index_name} already exists on table {k}"
)
except Exception as e:
logger.error(f"Failed to create ivfflat index on {k}: {e}")
async def query(
self,
sql: str,
@@ -1610,6 +1594,20 @@ class ClientManager:
config.get("postgres", "ivfflat_lists", fallback="100"),
)
),
"vchordrq_build_options": os.environ.get(
"POSTGRES_VCHORDRQ_BUILD_OPTIONS",
config.get("postgres", "vchordrq_build_options", fallback=""),
),
"vchordrq_probes": os.environ.get(
"POSTGRES_VCHORDRQ_PROBES",
config.get("postgres", "vchordrq_probes", fallback=""),
),
"vchordrq_epsilon": float(
os.environ.get(
"POSTGRES_VCHORDRQ_EPSILON",
config.get("postgres", "vchordrq_epsilon", fallback="1.9"),
)
),
# Server settings for Supabase
"server_settings": os.environ.get(
"POSTGRES_SERVER_SETTINGS",
@@ -1619,6 +1617,49 @@ class ClientManager:
"POSTGRES_STATEMENT_CACHE_SIZE",
config.get("postgres", "statement_cache_size", fallback=None),
),
# Connection retry configuration
"connection_retry_attempts": min(
10,
int(
os.environ.get(
"POSTGRES_CONNECTION_RETRIES",
config.get("postgres", "connection_retries", fallback=3),
)
),
),
"connection_retry_backoff": min(
5.0,
float(
os.environ.get(
"POSTGRES_CONNECTION_RETRY_BACKOFF",
config.get(
"postgres", "connection_retry_backoff", fallback=0.5
),
)
),
),
"connection_retry_backoff_max": min(
60.0,
float(
os.environ.get(
"POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
config.get(
"postgres",
"connection_retry_backoff_max",
fallback=5.0,
),
)
),
),
"pool_close_timeout": min(
30.0,
float(
os.environ.get(
"POSTGRES_POOL_CLOSE_TIMEOUT",
config.get("postgres", "pool_close_timeout", fallback=5.0),
)
),
),
}
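
All of the new keys can be driven from the environment (names as read above)
or from the [postgres] section of the config file; a hedged example with
illustrative values only:

    import os

    os.environ["POSTGRES_VCHORDRQ_PROBES"] = ""            # empty default; epsilon now applies anyway
    os.environ["POSTGRES_VCHORDRQ_EPSILON"] = "1.9"
    os.environ["POSTGRES_VCHORDRQ_BUILD_OPTIONS"] = ""     # optional TOML, used at index build time
    os.environ["POSTGRES_CONNECTION_RETRIES"] = "3"        # capped at 10
    os.environ["POSTGRES_CONNECTION_RETRY_BACKOFF"] = "0.5"      # seconds, capped at 5.0
    os.environ["POSTGRES_CONNECTION_RETRY_BACKOFF_MAX"] = "5.0"  # seconds, capped at 60.0
    os.environ["POSTGRES_POOL_CLOSE_TIMEOUT"] = "5.0"      # seconds, capped at 30.0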
@classmethod
@@ -1679,113 +1720,6 @@ class PGKVStorage(BaseKVStorage):
self.db = None
################ QUERY METHODS ################
async def get_all(self) -> dict[str, Any]:
"""Get all data from storage
Returns:
Dictionary containing all stored data
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for get_all: {self.namespace}"
)
return {}
sql = f"SELECT * FROM {table_name} WHERE workspace=$1"
params = {"workspace": self.workspace}
try:
results = await self.db.query(sql, list(params.values()), multirows=True)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
processed_results = {}
for row in results:
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
# Map field names and add cache_type for compatibility
processed_row = {
**row,
"return": row.get("return_value", ""),
"cache_type": row.get("original_prompt", "unknow"),
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
processed_results[row["id"]] = processed_row
return processed_results
# For text_chunks namespace, parse llm_cache_list JSON string back to list
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
processed_results = {}
for row in results:
llm_cache_list = row.get("llm_cache_list", [])
if isinstance(llm_cache_list, str):
try:
llm_cache_list = json.loads(llm_cache_list)
except json.JSONDecodeError:
llm_cache_list = []
row["llm_cache_list"] = llm_cache_list
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
row["create_time"] = create_time
row["update_time"] = (
create_time if update_time == 0 else update_time
)
processed_results[row["id"]] = row
return processed_results
# For FULL_ENTITIES namespace, parse entity_names JSON string back to list
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
processed_results = {}
for row in results:
entity_names = row.get("entity_names", [])
if isinstance(entity_names, str):
try:
entity_names = json.loads(entity_names)
except json.JSONDecodeError:
entity_names = []
row["entity_names"] = entity_names
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
row["create_time"] = create_time
row["update_time"] = (
create_time if update_time == 0 else update_time
)
processed_results[row["id"]] = row
return processed_results
# For FULL_RELATIONS namespace, parse relation_pairs JSON string back to list
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
processed_results = {}
for row in results:
relation_pairs = row.get("relation_pairs", [])
if isinstance(relation_pairs, str):
try:
relation_pairs = json.loads(relation_pairs)
except json.JSONDecodeError:
relation_pairs = []
row["relation_pairs"] = relation_pairs
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
row["create_time"] = create_time
row["update_time"] = (
create_time if update_time == 0 else update_time
)
processed_results[row["id"]] = row
return processed_results
# For other namespaces, return as-is
return {row["id"]: row for row in results}
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving all data from {self.namespace}: {e}"
)
return {}
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
@@ -1861,6 +1795,38 @@ class PGKVStorage(BaseKVStorage):
response["create_time"] = create_time
response["update_time"] = create_time if update_time == 0 else update_time
# Special handling for ENTITY_CHUNKS namespace
if response and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
# Parse chunk_ids JSON string back to list
chunk_ids = response.get("chunk_ids", [])
if isinstance(chunk_ids, str):
try:
chunk_ids = json.loads(chunk_ids)
except json.JSONDecodeError:
chunk_ids = []
response["chunk_ids"] = chunk_ids
create_time = response.get("create_time", 0)
update_time = response.get("update_time", 0)
response["create_time"] = create_time
response["update_time"] = create_time if update_time == 0 else update_time
# Special handling for RELATION_CHUNKS namespace
if response and is_namespace(
self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS
):
# Parse chunk_ids JSON string back to list
chunk_ids = response.get("chunk_ids", [])
if isinstance(chunk_ids, str):
try:
chunk_ids = json.loads(chunk_ids)
except json.JSONDecodeError:
chunk_ids = []
response["chunk_ids"] = chunk_ids
create_time = response.get("create_time", 0)
update_time = response.get("update_time", 0)
response["create_time"] = create_time
response["update_time"] = create_time if update_time == 0 else update_time
return response if response else None
# Query by id
@@ -1868,7 +1834,7 @@ class PGKVStorage(BaseKVStorage):
"""Get data by ids"""
if not ids:
return []
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, list(params.values()), multirows=True)
@@ -1969,13 +1935,45 @@ class PGKVStorage(BaseKVStorage):
result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time
# Special handling for ENTITY_CHUNKS namespace
if results and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
for result in results:
# Parse chunk_ids JSON string back to list
chunk_ids = result.get("chunk_ids", [])
if isinstance(chunk_ids, str):
try:
chunk_ids = json.loads(chunk_ids)
except json.JSONDecodeError:
chunk_ids = []
result["chunk_ids"] = chunk_ids
create_time = result.get("create_time", 0)
update_time = result.get("update_time", 0)
result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time
# Special handling for RELATION_CHUNKS namespace
if results and is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS):
for result in results:
# Parse chunk_ids JSON string back to list
chunk_ids = result.get("chunk_ids", [])
if isinstance(chunk_ids, str):
try:
chunk_ids = json.loads(chunk_ids)
except json.JSONDecodeError:
chunk_ids = []
result["chunk_ids"] = chunk_ids
create_time = result.get("create_time", 0)
update_time = result.get("update_time", 0)
result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time
return _order_results(results)
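
The ENTITY_CHUNKS and RELATION_CHUNKS branches above repeat the same
normalize-a-JSON-list-column pattern already used for llm_cache_list,
entity_names, and relation_pairs. A hedged refactor sketch (hypothetical
helper, not part of this commit):

    def _parse_json_list(row: dict, field: str) -> None:
        # Coerce a column that may arrive as a JSON-encoded string into a list.
        value = row.get(field, [])
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                value = []
        row[field] = value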
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
@@ -2073,11 +2071,61 @@ class PGKVStorage(BaseKVStorage):
"update_time": current_time,
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
# Get current UTC time and convert to naive datetime for database storage
current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_entity_chunks"]
_data = {
"workspace": self.workspace,
"id": k,
"chunk_ids": json.dumps(v["chunk_ids"]),
"count": v["count"],
"create_time": current_time,
"update_time": current_time,
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS):
# Get current UTC time and convert to naive datetime for database storage
current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_relation_chunks"]
_data = {
"workspace": self.workspace,
"id": k,
"chunk_ids": json.dumps(v["chunk_ids"]),
"count": v["count"],
"create_time": current_time,
"update_time": current_time,
}
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass
async def is_empty(self) -> bool:
"""Check if the storage is empty for the current workspace and namespace
Returns:
bool: True if storage is empty, False otherwise
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
)
return True
sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
try:
result = await self.db.query(sql, [self.workspace])
return not result.get("has_data", False) if result else True
except Exception as e:
logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
return True
async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs
@@ -2559,7 +2607,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""Filter out duplicated content"""
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
@@ -2993,6 +3041,28 @@ class PGDocStatusStorage(DocStatusStorage):
# PG handles persistence automatically
pass
async def is_empty(self) -> bool:
"""Check if the storage is empty for the current workspace and namespace
Returns:
bool: True if storage is empty, False otherwise
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
)
return True
sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
try:
result = await self.db.query(sql, [self.workspace])
return not result.get("has_data", False) if result else True
except Exception as e:
logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
return True
async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs
@@ -3510,17 +3580,13 @@ class PGGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties"""
label = self._normalize_node_id(node_id)
result = await self.get_nodes_batch(node_ids=[label])
result = await self.get_nodes_batch(node_ids=[node_id])
if result and node_id in result:
return result[node_id]
return None
async def node_degree(self, node_id: str) -> int:
label = self._normalize_node_id(node_id)
result = await self.node_degrees_batch(node_ids=[label])
result = await self.node_degrees_batch(node_ids=[node_id])
if result and node_id in result:
return result[node_id]
@@ -3533,12 +3599,11 @@ class PGGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes"""
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
if result and (src_label, tgt_label) in result:
return result[(src_label, tgt_label)]
result = await self.get_edges_batch(
[{"src": source_node_id, "tgt": target_node_id}]
)
if result and (source_node_id, target_node_id) in result:
return result[(source_node_id, target_node_id)]
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -3736,13 +3801,17 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
seen = set()
unique_ids = []
seen: set[str] = set()
unique_ids: list[str] = []
lookup: dict[str, str] = {}
requested: set[str] = set()
for nid in node_ids:
nid_norm = self._normalize_node_id(nid)
if nid_norm not in seen:
seen.add(nid_norm)
unique_ids.append(nid_norm)
if nid not in seen:
seen.add(nid)
unique_ids.append(nid)
requested.add(nid)
lookup[nid] = nid
lookup[self._normalize_node_id(nid)] = nid
# Build result dictionary
nodes_dict = {}
@@ -3781,10 +3850,18 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
nodes_dict[result["node_id"]] = node_dict
node_key = result["node_id"]
original_key = lookup.get(node_key)
if original_key is None:
logger.warning(
f"[{self.workspace}] Node {node_key} not found in lookup map"
)
original_key = node_key
if original_key in requested:
nodes_dict[original_key] = node_dict
return nodes_dict
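
The lookup map exists because the query operates on normalized ids while
callers expect results keyed by exactly the ids they passed in. A minimal
illustration of the round trip (the quote-stripping below is an assumption
standing in for whatever _normalize_node_id actually does):

    nid = 'Alice "the Great"'
    norm = nid.replace('"', "")      # stand-in for self._normalize_node_id(nid)
    lookup = {nid: nid, norm: nid}   # both spellings resolve to the caller's id
    row_key = norm                   # the form the database hands back
    assert lookup.get(row_key, row_key) == nid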
@@ -3807,13 +3884,17 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
seen = set()
seen: set[str] = set()
unique_ids: list[str] = []
lookup: dict[str, str] = {}
requested: set[str] = set()
for nid in node_ids:
n = self._normalize_node_id(nid)
if n not in seen:
seen.add(n)
unique_ids.append(n)
if nid not in seen:
seen.add(nid)
unique_ids.append(nid)
requested.add(nid)
lookup[nid] = nid
lookup[self._normalize_node_id(nid)] = nid
out_degrees = {}
in_degrees = {}
@@ -3865,8 +3946,16 @@ class PGGraphStorage(BaseGraphStorage):
node_id = row["node_id"]
if not node_id:
continue
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
node_key = node_id
original_key = lookup.get(node_key)
if original_key is None:
logger.warning(
f"[{self.workspace}] Node {node_key} not found in lookup map"
)
original_key = node_key
if original_key in requested:
out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
degrees_dict = {}
for node_id in node_ids:
@@ -3995,7 +4084,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
)
continue
@@ -4011,7 +4100,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
@@ -4116,102 +4205,6 @@ class PGGraphStorage(BaseGraphStorage):
labels.append(result["label"])
return labels
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves nodes from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all nodes
where the `source_id` property contains any of the specified chunk IDs.
"""
# The string representation of the list for the cypher query
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
RETURN n
$$) AS (n agtype);
"""
results = await self._query(query)
# Build result list
nodes = []
for result in results:
if result["n"]:
node_dict = result["n"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
node_dict["id"] = node_dict["entity_id"]
nodes.append(node_dict)
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves edges from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all edges
where the `source_id` property contains any of the specified chunk IDs.
"""
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH ()-[r]-()
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
$$) AS (edge agtype, source agtype, target agtype);
"""
results = await self._query(query)
edges = []
if results:
for item in results:
edge_agtype = item["edge"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(edge_agtype, str):
try:
edge_agtype = json.loads(edge_agtype)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
)
source_agtype = item["source"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(source_agtype, str):
try:
source_agtype = json.loads(source_agtype)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
)
target_agtype = item["target"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(target_agtype, str):
try:
target_agtype = json.loads(target_agtype)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
)
if edge_agtype and source_agtype and target_agtype:
edge_properties = edge_agtype
edge_properties["source"] = source_agtype["entity_id"]
edge_properties["target"] = target_agtype["entity_id"]
edges.append(edge_properties)
return edges
async def _bfs_subgraph(
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
@@ -4757,6 +4750,8 @@ NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES",
NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS",
NameSpace.KV_STORE_ENTITY_CHUNKS: "LIGHTRAG_ENTITY_CHUNKS",
NameSpace.KV_STORE_RELATION_CHUNKS: "LIGHTRAG_RELATION_CHUNKS",
NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
@@ -4897,6 +4892,28 @@ TABLES = {
CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY KEY (workspace, id)
)"""
},
"LIGHTRAG_ENTITY_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_ENTITY_CHUNKS (
id VARCHAR(512),
workspace VARCHAR(255),
chunk_ids JSONB,
count INTEGER,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_ENTITY_CHUNKS_PK PRIMARY KEY (workspace, id)
)"""
},
"LIGHTRAG_RELATION_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_RELATION_CHUNKS (
id VARCHAR(512),
workspace VARCHAR(255),
chunk_ids JSONB,
count INTEGER,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_RELATION_CHUNKS_PK PRIMARY KEY (workspace, id)
)"""
},
}
@@ -4954,6 +4971,26 @@ SQL_TEMPLATES = {
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_id_entity_chunks": """SELECT id, chunk_ids, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_id_relation_chunks": """SELECT id, chunk_ids, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_ids_entity_chunks": """SELECT id, chunk_ids, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_relation_chunks": """SELECT id, chunk_ids, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2)
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
VALUES ($1, $2, $3, $4)
@@ -5001,6 +5038,22 @@ SQL_TEMPLATES = {
count=EXCLUDED.count,
update_time = EXCLUDED.update_time
""",
"upsert_entity_chunks": """INSERT INTO LIGHTRAG_ENTITY_CHUNKS (workspace, id, chunk_ids, count,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (workspace,id) DO UPDATE
SET chunk_ids=EXCLUDED.chunk_ids,
count=EXCLUDED.count,
update_time = EXCLUDED.update_time
""",
"upsert_relation_chunks": """INSERT INTO LIGHTRAG_RELATION_CHUNKS (workspace, id, chunk_ids, count,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (workspace,id) DO UPDATE
SET chunk_ids=EXCLUDED.chunk_ids,
count=EXCLUDED.count,
update_time = EXCLUDED.update_time
""",
# SQL for VectorStorage
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector, file_path,