fix(postgres): allow vchordrq.epsilon config when probes is empty

Previously, configure_vchordrq would fail silently when probes was empty
(the default), preventing epsilon from being configured. Now each parameter
is handled independently with conditional execution, and configuration
errors fail fast instead of being swallowed.

This fixes the documented epsilon setting, which was impossible to use with
the default configuration.

(cherry picked from commit 3096f844fb)
parent 5bd1320a1d
commit 0ac858d3e2
1 changed file with 381 additions and 328 deletions
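In short, each vchordrq setting is now applied on its own and errors are allowed to propagate. A minimal sketch of that pattern, condensed from configure_vchordrq in the diff below (the standalone function name and signature here are illustrative, not part of the codebase; assumes an open asyncpg connection):

import asyncpg


async def apply_vchordrq_settings(
    connection: asyncpg.Connection,
    probes: str | None,
    epsilon: float | None,
) -> None:
    # probes: only set when a non-empty value is configured (empty is the default)
    if probes and str(probes).strip():
        await connection.execute(f"SET vchordrq.probes TO '{probes}'")

    # epsilon: checked against None independently, so 0.0 stays a valid value
    if epsilon is not None:
        await connection.execute(f"SET vchordrq.epsilon TO {epsilon}")

    # No try/except here: configuration errors propagate to the caller (fail fast).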
@@ -33,7 +33,6 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock

import pipmaster as pm
@@ -78,6 +77,9 @@ class PostgreSQLDB:
        self.hnsw_m = config.get("hnsw_m")
        self.hnsw_ef = config.get("hnsw_ef")
        self.ivfflat_lists = config.get("ivfflat_lists")
        self.vchordrq_build_options = config.get("vchordrq_build_options")
        self.vchordrq_probes = config.get("vchordrq_probes")
        self.vchordrq_epsilon = config.get("vchordrq_epsilon")

        # Server settings
        self.server_settings = config.get("server_settings")
@@ -85,24 +87,11 @@ class PostgreSQLDB:
        # Statement LRU cache size (keep as-is, allow None for optional configuration)
        self.statement_cache_size = config.get("statement_cache_size")

        # Connection retry configuration
        self.connection_retry_attempts = max(
            1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3)))
        )
        self.connection_retry_backoff = max(
            0.1,
            min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))),
        )
        self.connection_retry_backoff_max = max(
            self.connection_retry_backoff,
            min(
                60.0,
                float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)),
            ),
        )
        self.pool_close_timeout = max(
            1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0)))
        )
        if self.user is None or self.password is None or self.database is None:
            raise ValueError("Missing database user, password, or database")

        # Guard concurrent pool resets
        self._pool_reconnect_lock = asyncio.Lock()

        self._transient_exceptions = (
            asyncio.TimeoutError,
@@ -117,12 +106,14 @@ class PostgreSQLDB:
            asyncpg.exceptions.ConnectionFailureError,
        )

        # Guard concurrent pool resets
        self._pool_reconnect_lock = asyncio.Lock()

        if self.user is None or self.password is None or self.database is None:
            raise ValueError("Missing database user, password, or database")

        # Connection retry configuration
        self.connection_retry_attempts = config["connection_retry_attempts"]
        self.connection_retry_backoff = config["connection_retry_backoff"]
        self.connection_retry_backoff_max = max(
            self.connection_retry_backoff,
            config["connection_retry_backoff_max"],
        )
        self.pool_close_timeout = config["pool_close_timeout"]
        logger.info(
            "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
            self.connection_retry_attempts,
@@ -215,9 +206,7 @@ class PostgreSQLDB:

        # Only add statement_cache_size if it's configured
        if self.statement_cache_size is not None:
            connection_params["statement_cache_size"] = int(
                self.statement_cache_size
            )
            connection_params["statement_cache_size"] = int(self.statement_cache_size)
            logger.info(
                f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
            )
@@ -376,7 +365,8 @@ class PostgreSQLDB:
            await self.configure_age(connection, graph_name)
        elif with_age and not graph_name:
            raise ValueError("Graph name is required when with_age is True")

        if self.vector_index_type == "VCHORDRQ":
            await self.configure_vchordrq(connection)
        return await operation(connection)

    @staticmethod
@@ -422,6 +412,29 @@ class PostgreSQLDB:
        ):
            pass

    async def configure_vchordrq(self, connection: asyncpg.Connection) -> None:
        """Configure VCHORDRQ extension for vector similarity search.

        Raises:
            asyncpg.exceptions.UndefinedObjectError: If VCHORDRQ extension is not installed
            asyncpg.exceptions.InvalidParameterValueError: If parameter value is invalid

        Note:
            This method does not catch exceptions. Configuration errors will fail-fast,
            while transient connection errors will be retried by _run_with_retry.
        """
        # Handle probes parameter - only set if non-empty value is provided
        if self.vchordrq_probes and str(self.vchordrq_probes).strip():
            await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'")
            logger.debug(f"PostgreSQL, VCHORDRQ probes set to: {self.vchordrq_probes}")

        # Handle epsilon parameter independently - check for None to allow 0.0 as valid value
        if self.vchordrq_epsilon is not None:
            await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}")
            logger.debug(
                f"PostgreSQL, VCHORDRQ epsilon set to: {self.vchordrq_epsilon}"
            )

    async def _migrate_llm_cache_schema(self):
        """Migrate LLM cache schema: add new columns and remove deprecated mode field"""
        try:
@@ -1156,19 +1169,12 @@ class PostgreSQLDB:
            f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}"
        )
        try:
            if self.vector_index_type == "HNSW":
                await self._create_hnsw_vector_indexes()
            elif self.vector_index_type == "IVFFLAT":
                await self._create_ivfflat_vector_indexes()
            elif self.vector_index_type == "FLAT":
                logger.warning(
                    "FLAT index type is not supported by pgvector. Skipping vector index creation. "
                    "Please use 'HNSW' or 'IVFFLAT' instead."
                )
            if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]:
                await self._create_vector_indexes()
            else:
                logger.warning(
                    "Doesn't support this vector index type: {self.vector_index_type}. "
                    "Supported types: HNSW, IVFFLAT"
                    "Supported types: HNSW, IVFFLAT, VCHORDRQ"
                )
        except Exception as e:
            logger.error(
@@ -1375,21 +1381,39 @@ class PostgreSQLDB:
        except Exception as e:
            logger.warning(f"Failed to create index {index['name']}: {e}")

    async def _create_hnsw_vector_indexes(self):
    async def _create_vector_indexes(self):
        vdb_tables = [
            "LIGHTRAG_VDB_CHUNKS",
            "LIGHTRAG_VDB_ENTITY",
            "LIGHTRAG_VDB_RELATION",
        ]

        embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
        create_sql = {
            "HNSW": f"""
                CREATE INDEX {{vector_index_name}}
                ON {{k}} USING hnsw (content_vector vector_cosine_ops)
                WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
            """,
            "IVFFLAT": f"""
                CREATE INDEX {{vector_index_name}}
                ON {{k}} USING ivfflat (content_vector vector_cosine_ops)
                WITH (lists = {self.ivfflat_lists})
            """,
            "VCHORDRQ": f"""
                CREATE INDEX {{vector_index_name}}
                ON {{k}} USING vchordrq (content_vector vector_cosine_ops)
                {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''}
            """,
        }

        embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
        for k in vdb_tables:
            vector_index_name = f"idx_{k.lower()}_hnsw_cosine"
            vector_index_name = (
                f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
            )
            check_vector_index_sql = f"""
                SELECT 1 FROM pg_indexes
                WHERE indexname = '{vector_index_name}'
                AND tablename = '{k.lower()}'
                WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
            """
            try:
                vector_index_exists = await self.query(check_vector_index_sql)
@@ -1398,64 +1422,24 @@ class PostgreSQLDB:
                    alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
                    await self.execute(alter_sql)
                    logger.debug(f"Ensured vector dimension for {k}")

                    create_vector_index_sql = f"""
                        CREATE INDEX {vector_index_name}
                        ON {k} USING hnsw (content_vector vector_cosine_ops)
                        WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
                    """
                    logger.info(f"Creating hnsw index {vector_index_name} on table {k}")
                    await self.execute(create_vector_index_sql)
                    logger.info(
                        f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
                    )
                    await self.execute(
                        create_sql[self.vector_index_type].format(
                            vector_index_name=vector_index_name, k=k
                        )
                    )
                    logger.info(
                        f"Successfully created vector index {vector_index_name} on table {k}"
                    )
                else:
                    logger.info(
                        f"HNSW vector index {vector_index_name} already exists on table {k}"
                        f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
                    )
            except Exception as e:
                logger.error(f"Failed to create vector index on table {k}, Got: {e}")

    async def _create_ivfflat_vector_indexes(self):
        vdb_tables = [
            "LIGHTRAG_VDB_CHUNKS",
            "LIGHTRAG_VDB_ENTITY",
            "LIGHTRAG_VDB_RELATION",
        ]

        embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))

        for k in vdb_tables:
            index_name = f"idx_{k.lower()}_ivfflat_cosine"
            check_index_sql = f"""
                SELECT 1 FROM pg_indexes
                WHERE indexname = '{index_name}' AND tablename = '{k.lower()}'
            """
            try:
                exists = await self.query(check_index_sql)
                if not exists:
                    # Only set vector dimension when index doesn't exist
                    alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
                    await self.execute(alter_sql)
                    logger.debug(f"Ensured vector dimension for {k}")

                    create_sql = f"""
                        CREATE INDEX {index_name}
                        ON {k} USING ivfflat (content_vector vector_cosine_ops)
                        WITH (lists = {self.ivfflat_lists})
                    """
                    logger.info(f"Creating ivfflat index {index_name} on table {k}")
                    await self.execute(create_sql)
                    logger.info(
                        f"Successfully created ivfflat index {index_name} on table {k}"
                    )
                else:
                    logger.info(
                        f"Ivfflat vector index {index_name} already exists on table {k}"
                    )
            except Exception as e:
                logger.error(f"Failed to create ivfflat index on {k}: {e}")

    async def query(
        self,
        sql: str,
@@ -1610,6 +1594,20 @@ class ClientManager:
                    config.get("postgres", "ivfflat_lists", fallback="100"),
                )
            ),
            "vchordrq_build_options": os.environ.get(
                "POSTGRES_VCHORDRQ_BUILD_OPTIONS",
                config.get("postgres", "vchordrq_build_options", fallback=""),
            ),
            "vchordrq_probes": os.environ.get(
                "POSTGRES_VCHORDRQ_PROBES",
                config.get("postgres", "vchordrq_probes", fallback=""),
            ),
            "vchordrq_epsilon": float(
                os.environ.get(
                    "POSTGRES_VCHORDRQ_EPSILON",
                    config.get("postgres", "vchordrq_epsilon", fallback="1.9"),
                )
            ),
            # Server settings for Supabase
            "server_settings": os.environ.get(
                "POSTGRES_SERVER_SETTINGS",
@ -1619,6 +1617,49 @@ class ClientManager:
|
|||
"POSTGRES_STATEMENT_CACHE_SIZE",
|
||||
config.get("postgres", "statement_cache_size", fallback=None),
|
||||
),
|
||||
# Connection retry configuration
|
||||
"connection_retry_attempts": min(
|
||||
10,
|
||||
int(
|
||||
os.environ.get(
|
||||
"POSTGRES_CONNECTION_RETRIES",
|
||||
config.get("postgres", "connection_retries", fallback=3),
|
||||
)
|
||||
),
|
||||
),
|
||||
"connection_retry_backoff": min(
|
||||
5.0,
|
||||
float(
|
||||
os.environ.get(
|
||||
"POSTGRES_CONNECTION_RETRY_BACKOFF",
|
||||
config.get(
|
||||
"postgres", "connection_retry_backoff", fallback=0.5
|
||||
),
|
||||
)
|
||||
),
|
||||
),
|
||||
"connection_retry_backoff_max": min(
|
||||
60.0,
|
||||
float(
|
||||
os.environ.get(
|
||||
"POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
|
||||
config.get(
|
||||
"postgres",
|
||||
"connection_retry_backoff_max",
|
||||
fallback=5.0,
|
||||
),
|
||||
)
|
||||
),
|
||||
),
|
||||
"pool_close_timeout": min(
|
||||
30.0,
|
||||
float(
|
||||
os.environ.get(
|
||||
"POSTGRES_POOL_CLOSE_TIMEOUT",
|
||||
config.get("postgres", "pool_close_timeout", fallback=5.0),
|
||||
)
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -1679,113 +1720,6 @@ class PGKVStorage(BaseKVStorage):
|
|||
self.db = None
|
||||
|
||||
################ QUERY METHODS ################
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""Get all data from storage
|
||||
|
||||
Returns:
|
||||
Dictionary containing all stored data
|
||||
"""
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Unknown namespace for get_all: {self.namespace}"
|
||||
)
|
||||
return {}
|
||||
|
||||
sql = f"SELECT * FROM {table_name} WHERE workspace=$1"
|
||||
params = {"workspace": self.workspace}
|
||||
|
||||
try:
|
||||
results = await self.db.query(sql, list(params.values()), multirows=True)
|
||||
|
||||
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
processed_results = {}
|
||||
for row in results:
|
||||
create_time = row.get("create_time", 0)
|
||||
update_time = row.get("update_time", 0)
|
||||
# Map field names and add cache_type for compatibility
|
||||
processed_row = {
|
||||
**row,
|
||||
"return": row.get("return_value", ""),
|
||||
"cache_type": row.get("original_prompt", "unknow"),
|
||||
"original_prompt": row.get("original_prompt", ""),
|
||||
"chunk_id": row.get("chunk_id"),
|
||||
"mode": row.get("mode", "default"),
|
||||
"create_time": create_time,
|
||||
"update_time": create_time if update_time == 0 else update_time,
|
||||
}
|
||||
processed_results[row["id"]] = processed_row
|
||||
return processed_results
|
||||
|
||||
# For text_chunks namespace, parse llm_cache_list JSON string back to list
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||
processed_results = {}
|
||||
for row in results:
|
||||
llm_cache_list = row.get("llm_cache_list", [])
|
||||
if isinstance(llm_cache_list, str):
|
||||
try:
|
||||
llm_cache_list = json.loads(llm_cache_list)
|
||||
except json.JSONDecodeError:
|
||||
llm_cache_list = []
|
||||
row["llm_cache_list"] = llm_cache_list
|
||||
create_time = row.get("create_time", 0)
|
||||
update_time = row.get("update_time", 0)
|
||||
row["create_time"] = create_time
|
||||
row["update_time"] = (
|
||||
create_time if update_time == 0 else update_time
|
||||
)
|
||||
processed_results[row["id"]] = row
|
||||
return processed_results
|
||||
|
||||
# For FULL_ENTITIES namespace, parse entity_names JSON string back to list
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
||||
processed_results = {}
|
||||
for row in results:
|
||||
entity_names = row.get("entity_names", [])
|
||||
if isinstance(entity_names, str):
|
||||
try:
|
||||
entity_names = json.loads(entity_names)
|
||||
except json.JSONDecodeError:
|
||||
entity_names = []
|
||||
row["entity_names"] = entity_names
|
||||
create_time = row.get("create_time", 0)
|
||||
update_time = row.get("update_time", 0)
|
||||
row["create_time"] = create_time
|
||||
row["update_time"] = (
|
||||
create_time if update_time == 0 else update_time
|
||||
)
|
||||
processed_results[row["id"]] = row
|
||||
return processed_results
|
||||
|
||||
# For FULL_RELATIONS namespace, parse relation_pairs JSON string back to list
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
|
||||
processed_results = {}
|
||||
for row in results:
|
||||
relation_pairs = row.get("relation_pairs", [])
|
||||
if isinstance(relation_pairs, str):
|
||||
try:
|
||||
relation_pairs = json.loads(relation_pairs)
|
||||
except json.JSONDecodeError:
|
||||
relation_pairs = []
|
||||
row["relation_pairs"] = relation_pairs
|
||||
create_time = row.get("create_time", 0)
|
||||
update_time = row.get("update_time", 0)
|
||||
row["create_time"] = create_time
|
||||
row["update_time"] = (
|
||||
create_time if update_time == 0 else update_time
|
||||
)
|
||||
processed_results[row["id"]] = row
|
||||
return processed_results
|
||||
|
||||
# For other namespaces, return as-is
|
||||
return {row["id"]: row for row in results}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error retrieving all data from {self.namespace}: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get data by id."""
|
||||
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
|
|
@ -1861,6 +1795,38 @@ class PGKVStorage(BaseKVStorage):
|
|||
response["create_time"] = create_time
|
||||
response["update_time"] = create_time if update_time == 0 else update_time
|
||||
|
||||
# Special handling for ENTITY_CHUNKS namespace
|
||||
if response and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
|
||||
# Parse chunk_ids JSON string back to list
|
||||
chunk_ids = response.get("chunk_ids", [])
|
||||
if isinstance(chunk_ids, str):
|
||||
try:
|
||||
chunk_ids = json.loads(chunk_ids)
|
||||
except json.JSONDecodeError:
|
||||
chunk_ids = []
|
||||
response["chunk_ids"] = chunk_ids
|
||||
create_time = response.get("create_time", 0)
|
||||
update_time = response.get("update_time", 0)
|
||||
response["create_time"] = create_time
|
||||
response["update_time"] = create_time if update_time == 0 else update_time
|
||||
|
||||
# Special handling for RELATION_CHUNKS namespace
|
||||
if response and is_namespace(
|
||||
self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS
|
||||
):
|
||||
# Parse chunk_ids JSON string back to list
|
||||
chunk_ids = response.get("chunk_ids", [])
|
||||
if isinstance(chunk_ids, str):
|
||||
try:
|
||||
chunk_ids = json.loads(chunk_ids)
|
||||
except json.JSONDecodeError:
|
||||
chunk_ids = []
|
||||
response["chunk_ids"] = chunk_ids
|
||||
create_time = response.get("create_time", 0)
|
||||
update_time = response.get("update_time", 0)
|
||||
response["create_time"] = create_time
|
||||
response["update_time"] = create_time if update_time == 0 else update_time
|
||||
|
||||
return response if response else None
|
||||
|
||||
# Query by id
|
||||
|
|
@ -1868,7 +1834,7 @@ class PGKVStorage(BaseKVStorage):
|
|||
"""Get data by ids"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
|
||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
|
||||
params = {"workspace": self.workspace, "ids": ids}
|
||||
results = await self.db.query(sql, list(params.values()), multirows=True)
|
||||
|
|
@ -1969,13 +1935,45 @@ class PGKVStorage(BaseKVStorage):
|
|||
result["create_time"] = create_time
|
||||
result["update_time"] = create_time if update_time == 0 else update_time
|
||||
|
||||
# Special handling for ENTITY_CHUNKS namespace
|
||||
if results and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
|
||||
for result in results:
|
||||
# Parse chunk_ids JSON string back to list
|
||||
chunk_ids = result.get("chunk_ids", [])
|
||||
if isinstance(chunk_ids, str):
|
||||
try:
|
||||
chunk_ids = json.loads(chunk_ids)
|
||||
except json.JSONDecodeError:
|
||||
chunk_ids = []
|
||||
result["chunk_ids"] = chunk_ids
|
||||
create_time = result.get("create_time", 0)
|
||||
update_time = result.get("update_time", 0)
|
||||
result["create_time"] = create_time
|
||||
result["update_time"] = create_time if update_time == 0 else update_time
|
||||
|
||||
# Special handling for RELATION_CHUNKS namespace
|
||||
if results and is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS):
|
||||
for result in results:
|
||||
# Parse chunk_ids JSON string back to list
|
||||
chunk_ids = result.get("chunk_ids", [])
|
||||
if isinstance(chunk_ids, str):
|
||||
try:
|
||||
chunk_ids = json.loads(chunk_ids)
|
||||
except json.JSONDecodeError:
|
||||
chunk_ids = []
|
||||
result["chunk_ids"] = chunk_ids
|
||||
create_time = result.get("create_time", 0)
|
||||
update_time = result.get("update_time", 0)
|
||||
result["create_time"] = create_time
|
||||
result["update_time"] = create_time if update_time == 0 else update_time
|
||||
|
||||
return _order_results(results)
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Filter out duplicated content"""
|
||||
if not keys:
|
||||
return set()
|
||||
|
||||
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
||||
params = {"workspace": self.workspace, "ids": list(keys)}
|
||||
|
|
@ -2073,11 +2071,61 @@ class PGKVStorage(BaseKVStorage):
|
|||
"update_time": current_time,
|
||||
}
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
|
||||
# Get current UTC time and convert to naive datetime for database storage
|
||||
current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
for k, v in data.items():
|
||||
upsert_sql = SQL_TEMPLATES["upsert_entity_chunks"]
|
||||
_data = {
|
||||
"workspace": self.workspace,
|
||||
"id": k,
|
||||
"chunk_ids": json.dumps(v["chunk_ids"]),
|
||||
"count": v["count"],
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
}
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS):
|
||||
# Get current UTC time and convert to naive datetime for database storage
|
||||
current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
for k, v in data.items():
|
||||
upsert_sql = SQL_TEMPLATES["upsert_relation_chunks"]
|
||||
_data = {
|
||||
"workspace": self.workspace,
|
||||
"id": k,
|
||||
"chunk_ids": json.dumps(v["chunk_ids"]),
|
||||
"count": v["count"],
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
}
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# PG handles persistence automatically
|
||||
pass
|
||||
|
||||
async def is_empty(self) -> bool:
|
||||
"""Check if the storage is empty for the current workspace and namespace
|
||||
|
||||
Returns:
|
||||
bool: True if storage is empty, False otherwise
|
||||
"""
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
|
||||
)
|
||||
return True
|
||||
|
||||
sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
|
||||
|
||||
try:
|
||||
result = await self.db.query(sql, [self.workspace])
|
||||
return not result.get("has_data", False) if result else True
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
|
||||
return True
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
|
|
@ -2559,7 +2607,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
"""Filter out duplicated content"""
|
||||
if not keys:
|
||||
return set()
|
||||
|
||||
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
||||
params = {"workspace": self.workspace, "ids": list(keys)}
|
||||
|
|
@ -2993,6 +3041,28 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
# PG handles persistence automatically
|
||||
pass
|
||||
|
||||
async def is_empty(self) -> bool:
|
||||
"""Check if the storage is empty for the current workspace and namespace
|
||||
|
||||
Returns:
|
||||
bool: True if storage is empty, False otherwise
|
||||
"""
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
|
||||
)
|
||||
return True
|
||||
|
||||
sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
|
||||
|
||||
try:
|
||||
result = await self.db.query(sql, [self.workspace])
|
||||
return not result.get("has_data", False) if result else True
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
|
||||
return True
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
|
|
@ -3510,17 +3580,13 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get node by its label identifier, return only node properties"""
|
||||
|
||||
label = self._normalize_node_id(node_id)
|
||||
|
||||
result = await self.get_nodes_batch(node_ids=[label])
|
||||
result = await self.get_nodes_batch(node_ids=[node_id])
|
||||
if result and node_id in result:
|
||||
return result[node_id]
|
||||
return None
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
label = self._normalize_node_id(node_id)
|
||||
|
||||
result = await self.node_degrees_batch(node_ids=[label])
|
||||
result = await self.node_degrees_batch(node_ids=[node_id])
|
||||
if result and node_id in result:
|
||||
return result[node_id]
|
||||
|
||||
|
|
@ -3533,12 +3599,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
"""Get edge properties between two nodes"""
|
||||
src_label = self._normalize_node_id(source_node_id)
|
||||
tgt_label = self._normalize_node_id(target_node_id)
|
||||
|
||||
result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
|
||||
if result and (src_label, tgt_label) in result:
|
||||
return result[(src_label, tgt_label)]
|
||||
result = await self.get_edges_batch(
|
||||
[{"src": source_node_id, "tgt": target_node_id}]
|
||||
)
|
||||
if result and (source_node_id, target_node_id) in result:
|
||||
return result[(source_node_id, target_node_id)]
|
||||
return None
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
|
|
@ -3736,13 +3801,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
if not node_ids:
|
||||
return {}
|
||||
|
||||
seen = set()
|
||||
unique_ids = []
|
||||
seen: set[str] = set()
|
||||
unique_ids: list[str] = []
|
||||
lookup: dict[str, str] = {}
|
||||
requested: set[str] = set()
|
||||
for nid in node_ids:
|
||||
nid_norm = self._normalize_node_id(nid)
|
||||
if nid_norm not in seen:
|
||||
seen.add(nid_norm)
|
||||
unique_ids.append(nid_norm)
|
||||
if nid not in seen:
|
||||
seen.add(nid)
|
||||
unique_ids.append(nid)
|
||||
requested.add(nid)
|
||||
lookup[nid] = nid
|
||||
lookup[self._normalize_node_id(nid)] = nid
|
||||
|
||||
# Build result dictionary
|
||||
nodes_dict = {}
|
||||
|
|
@ -3781,10 +3850,18 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse node string in batch: {node_dict}"
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
|
||||
)
|
||||
|
||||
nodes_dict[result["node_id"]] = node_dict
|
||||
node_key = result["node_id"]
|
||||
original_key = lookup.get(node_key)
|
||||
if original_key is None:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Node {node_key} not found in lookup map"
|
||||
)
|
||||
original_key = node_key
|
||||
if original_key in requested:
|
||||
nodes_dict[original_key] = node_dict
|
||||
|
||||
return nodes_dict
|
||||
|
||||
|
|
@ -3807,13 +3884,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
if not node_ids:
|
||||
return {}
|
||||
|
||||
seen = set()
|
||||
seen: set[str] = set()
|
||||
unique_ids: list[str] = []
|
||||
lookup: dict[str, str] = {}
|
||||
requested: set[str] = set()
|
||||
for nid in node_ids:
|
||||
n = self._normalize_node_id(nid)
|
||||
if n not in seen:
|
||||
seen.add(n)
|
||||
unique_ids.append(n)
|
||||
if nid not in seen:
|
||||
seen.add(nid)
|
||||
unique_ids.append(nid)
|
||||
requested.add(nid)
|
||||
lookup[nid] = nid
|
||||
lookup[self._normalize_node_id(nid)] = nid
|
||||
|
||||
out_degrees = {}
|
||||
in_degrees = {}
|
||||
|
|
@ -3865,8 +3946,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
node_id = row["node_id"]
|
||||
if not node_id:
|
||||
continue
|
||||
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
|
||||
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
|
||||
node_key = node_id
|
||||
original_key = lookup.get(node_key)
|
||||
if original_key is None:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Node {node_key} not found in lookup map"
|
||||
)
|
||||
original_key = node_key
|
||||
if original_key in requested:
|
||||
out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
|
||||
in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
|
||||
|
||||
degrees_dict = {}
|
||||
for node_id in node_ids:
|
||||
|
|
@ -3995,7 +4084,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
edge_props = json.loads(edge_props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge properties string: {edge_props}"
|
||||
f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -4011,7 +4100,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
edge_props = json.loads(edge_props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge properties string: {edge_props}"
|
||||
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -4116,102 +4205,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
labels.append(result["label"])
|
||||
return labels
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""
|
||||
Retrieves nodes from the graph that are associated with a given list of chunk IDs.
|
||||
This method uses a Cypher query with UNWIND to efficiently find all nodes
|
||||
where the `source_id` property contains any of the specified chunk IDs.
|
||||
"""
|
||||
# The string representation of the list for the cypher query
|
||||
chunk_ids_str = json.dumps(chunk_ids)
|
||||
|
||||
query = f"""
|
||||
SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND {chunk_ids_str} AS chunk_id
|
||||
MATCH (n:base)
|
||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
|
||||
RETURN n
|
||||
$$) AS (n agtype);
|
||||
"""
|
||||
results = await self._query(query)
|
||||
|
||||
# Build result list
|
||||
nodes = []
|
||||
for result in results:
|
||||
if result["n"]:
|
||||
node_dict = result["n"]["properties"]
|
||||
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(node_dict, str):
|
||||
try:
|
||||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
|
||||
)
|
||||
|
||||
node_dict["id"] = node_dict["entity_id"]
|
||||
nodes.append(node_dict)
|
||||
|
||||
return nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""
|
||||
Retrieves edges from the graph that are associated with a given list of chunk IDs.
|
||||
This method uses a Cypher query with UNWIND to efficiently find all edges
|
||||
where the `source_id` property contains any of the specified chunk IDs.
|
||||
"""
|
||||
chunk_ids_str = json.dumps(chunk_ids)
|
||||
|
||||
query = f"""
|
||||
SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND {chunk_ids_str} AS chunk_id
|
||||
MATCH ()-[r]-()
|
||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
|
||||
RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
|
||||
$$) AS (edge agtype, source agtype, target agtype);
|
||||
"""
|
||||
results = await self._query(query)
|
||||
edges = []
|
||||
if results:
|
||||
for item in results:
|
||||
edge_agtype = item["edge"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(edge_agtype, str):
|
||||
try:
|
||||
edge_agtype = json.loads(edge_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
|
||||
)
|
||||
|
||||
source_agtype = item["source"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(source_agtype, str):
|
||||
try:
|
||||
source_agtype = json.loads(source_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
|
||||
)
|
||||
|
||||
target_agtype = item["target"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(target_agtype, str):
|
||||
try:
|
||||
target_agtype = json.loads(target_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
|
||||
)
|
||||
|
||||
if edge_agtype and source_agtype and target_agtype:
|
||||
edge_properties = edge_agtype
|
||||
edge_properties["source"] = source_agtype["entity_id"]
|
||||
edge_properties["target"] = target_agtype["entity_id"]
|
||||
edges.append(edge_properties)
|
||||
return edges
|
||||
|
||||
async def _bfs_subgraph(
|
||||
self, node_label: str, max_depth: int, max_nodes: int
|
||||
) -> KnowledgeGraph:
|
||||
|
|
@ -4757,6 +4750,8 @@ NAMESPACE_TABLE_MAP = {
|
|||
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
||||
NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES",
|
||||
NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS",
|
||||
NameSpace.KV_STORE_ENTITY_CHUNKS: "LIGHTRAG_ENTITY_CHUNKS",
|
||||
NameSpace.KV_STORE_RELATION_CHUNKS: "LIGHTRAG_RELATION_CHUNKS",
|
||||
NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
|
||||
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
|
||||
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
|
||||
|
|
@ -4897,6 +4892,28 @@ TABLES = {
|
|||
CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY KEY (workspace, id)
|
||||
)"""
|
||||
},
|
||||
"LIGHTRAG_ENTITY_CHUNKS": {
|
||||
"ddl": """CREATE TABLE LIGHTRAG_ENTITY_CHUNKS (
|
||||
id VARCHAR(512),
|
||||
workspace VARCHAR(255),
|
||||
chunk_ids JSONB,
|
||||
count INTEGER,
|
||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT LIGHTRAG_ENTITY_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||
)"""
|
||||
},
|
||||
"LIGHTRAG_RELATION_CHUNKS": {
|
||||
"ddl": """CREATE TABLE LIGHTRAG_RELATION_CHUNKS (
|
||||
id VARCHAR(512),
|
||||
workspace VARCHAR(255),
|
||||
chunk_ids JSONB,
|
||||
count INTEGER,
|
||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT LIGHTRAG_RELATION_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||
)"""
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -4954,6 +4971,26 @@ SQL_TEMPLATES = {
|
|||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
|
||||
""",
|
||||
"get_by_id_entity_chunks": """SELECT id, chunk_ids, count,
|
||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||
FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id=$2
|
||||
""",
|
||||
"get_by_id_relation_chunks": """SELECT id, chunk_ids, count,
|
||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||
FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id=$2
|
||||
""",
|
||||
"get_by_ids_entity_chunks": """SELECT id, chunk_ids, count,
|
||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||
FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id = ANY($2)
|
||||
""",
|
||||
"get_by_ids_relation_chunks": """SELECT id, chunk_ids, count,
|
||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||
FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2)
|
||||
""",
|
||||
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
||||
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
|
|
@ -5001,6 +5038,22 @@ SQL_TEMPLATES = {
|
|||
count=EXCLUDED.count,
|
||||
update_time = EXCLUDED.update_time
|
||||
""",
|
||||
"upsert_entity_chunks": """INSERT INTO LIGHTRAG_ENTITY_CHUNKS (workspace, id, chunk_ids, count,
|
||||
create_time, update_time)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (workspace,id) DO UPDATE
|
||||
SET chunk_ids=EXCLUDED.chunk_ids,
|
||||
count=EXCLUDED.count,
|
||||
update_time = EXCLUDED.update_time
|
||||
""",
|
||||
"upsert_relation_chunks": """INSERT INTO LIGHTRAG_RELATION_CHUNKS (workspace, id, chunk_ids, count,
|
||||
create_time, update_time)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (workspace,id) DO UPDATE
|
||||
SET chunk_ids=EXCLUDED.chunk_ids,
|
||||
count=EXCLUDED.count,
|
||||
update_time = EXCLUDED.update_time
|
||||
""",
|
||||
# SQL for VectorStorage
|
||||
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
||||
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
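For reference, a hedged example of exercising the fixed default path through the environment overrides that ClientManager reads in the diff above; the variable names come from the diff, while the values are purely illustrative:

import os

# Probes stays at its empty default; before this fix that silently skipped
# the rest of configure_vchordrq, so epsilon was never applied.
os.environ.setdefault("POSTGRES_VCHORDRQ_PROBES", "")

# Epsilon is now applied independently of probes ("1.9" matches the config fallback).
os.environ.setdefault("POSTGRES_VCHORDRQ_EPSILON", "1.9")

# Optional build options, interpolated into CREATE INDEX ... WITH (options = $$...$$).
os.environ.setdefault("POSTGRES_VCHORDRQ_BUILD_OPTIONS", "")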