Fix: Resolve workspace isolation problem for PostgreSQL with multiple LightRAG instances

yangdx 2025-08-12 01:27:05 +08:00
parent d9c1f935f5
commit ca00b9c8ee

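The core of this change is the workspace resolution that each PostgreSQL storage class (PGKVStorage, PGVectorStorage, PGDocStatusStorage, PGGraphStorage) now performs in initialize(): the resolved workspace is kept on the storage object itself instead of being written back into the shared PostgreSQLDB client, and SQL parameters and log prefixes switch from self.db.workspace to self.workspace. A minimal standalone sketch of that priority order (the helper function and the tenant names below are illustrative, not part of the codebase):

def resolve_workspace(db_workspace: str | None, storage_workspace: str | None) -> str:
    # Priority: PostgreSQLDB.workspace > storage.workspace > "default"
    if db_workspace:
        # Server-level workspace wins (highest priority)
        return db_workspace
    if storage_workspace:
        # Keep the workspace configured on this storage instance (medium priority)
        return storage_workspace
    # Fall back to "default" for backward compatibility (lowest priority)
    return "default"

# Two LightRAG instances sharing one PostgreSQLDB client no longer overwrite
# each other's workspace:
assert resolve_workspace(None, "tenant_a") == "tenant_a"
assert resolve_workspace(None, "tenant_b") == "tenant_b"
assert resolve_workspace("server_ws", "tenant_a") == "server_ws"
assert resolve_workspace(None, None) == "default"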

@@ -30,6 +30,7 @@ from ..base import (
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_graph_db_lock
import pipmaster as pm
@@ -1220,9 +1221,6 @@ class PostgreSQLDB:
with_age: bool = False,
graph_name: str | None = None,
) -> dict[str, Any] | None | list[dict[str, Any]]:
# start_time = time.time()
# logger.info(f"PostgreSQL, Querying:\n{sql}")
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name) # type: ignore
@@ -1248,10 +1246,6 @@ class PostgreSQLDB:
else:
data = None
# query_time = time.time() - start_time
# logger.info(f"PostgreSQL, Query result len: {len(data)}")
# logger.info(f"PostgreSQL, Query execution time: {query_time:.4f}s")
return data
except Exception as e:
logger.error(f"PostgreSQL database, error:{e}")
@@ -1414,18 +1408,16 @@ class PGKVStorage(BaseKVStorage):
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
final_workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
final_workspace = self.workspace
self.db.workspace = final_workspace
else:
# Use "default" for compatibility (lowest priority)
final_workspace = "default"
self.db.workspace = final_workspace
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
pass
else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
async def finalize(self):
if self.db is not None:
@@ -1441,11 +1433,13 @@ class PGKVStorage(BaseKVStorage):
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for get_all: {self.namespace}")
logger.error(
f"[{self.workspace}] Unknown namespace for get_all: {self.namespace}"
)
return {}
sql = f"SELECT * FROM {table_name} WHERE workspace=$1"
params = {"workspace": self.db.workspace}
params = {"workspace": self.workspace}
try:
results = await self.db.query(sql, params, multirows=True)
@@ -1533,13 +1527,15 @@ class PGKVStorage(BaseKVStorage):
# For other namespaces, return as-is
return {row["id"]: row for row in results}
except Exception as e:
logger.error(f"Error retrieving all data from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving all data from {self.namespace}: {e}"
)
return {}
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
params = {"workspace": self.workspace, "id": id}
response = await self.db.query(sql, params)
if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -1619,7 +1615,7 @@ class PGKVStorage(BaseKVStorage):
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
params = {"workspace": self.workspace}
results = await self.db.query(sql, params, multirows=True)
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -1706,7 +1702,7 @@ class PGKVStorage(BaseKVStorage):
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.db.workspace}
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, params, multirows=True)
if res:
@@ -1717,13 +1713,13 @@ class PGKVStorage(BaseKVStorage):
return new_keys
except Exception as e:
logger.error(
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
)
raise
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -1733,7 +1729,7 @@ class PGKVStorage(BaseKVStorage):
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_text_chunk"]
_data = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": k,
"tokens": v["tokens"],
"chunk_order_index": v["chunk_order_index"],
@@ -1751,14 +1747,14 @@ class PGKVStorage(BaseKVStorage):
_data = {
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
"workspace": self.workspace,
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": k, # Use flattened key as id
"original_prompt": v["original_prompt"],
"return_value": v["return"],
@@ -1778,7 +1774,7 @@ class PGKVStorage(BaseKVStorage):
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_full_entities"]
_data = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": k,
"entity_names": json.dumps(v["entity_names"]),
"count": v["count"],
@@ -1792,7 +1788,7 @@ class PGKVStorage(BaseKVStorage):
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_full_relations"]
_data = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": k,
"relation_pairs": json.dumps(v["relation_pairs"]),
"count": v["count"],
@@ -1819,20 +1815,22 @@ class PGKVStorage(BaseKVStorage):
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for deletion: {self.namespace}")
logger.error(
f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
)
return
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
try:
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "ids": ids}
)
await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
logger.debug(
f"Successfully deleted {len(ids)} records from {self.namespace}"
f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting records from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error while deleting records from {self.namespace}: {e}"
)
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
@@ -1847,7 +1845,7 @@ class PGKVStorage(BaseKVStorage):
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@@ -1871,18 +1869,16 @@ class PGVectorStorage(BaseVectorStorage):
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
final_workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
final_workspace = self.workspace
self.db.workspace = final_workspace
else:
# Use "default" for compatibility (lowest priority)
final_workspace = "default"
self.db.workspace = final_workspace
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
pass
else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
async def finalize(self):
if self.db is not None:
@@ -1895,7 +1891,7 @@ class PGVectorStorage(BaseVectorStorage):
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
data: dict[str, Any] = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": item["__id__"],
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
@@ -1907,7 +1903,9 @@ class PGVectorStorage(BaseVectorStorage):
"update_time": current_time,
}
except Exception as e:
logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}")
logger.error(
f"[{self.workspace}] Error to prepare upsert,\nsql: {e}\nitem: {item}"
)
raise
return upsert_sql, data
@@ -1923,7 +1921,7 @@ class PGVectorStorage(BaseVectorStorage):
chunk_ids = [source_id]
data: dict[str, Any] = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": item["__id__"],
"entity_name": item["entity_name"],
"content": item["content"],
@@ -1946,7 +1944,7 @@ class PGVectorStorage(BaseVectorStorage):
chunk_ids = [source_id]
data: dict[str, Any] = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": item["__id__"],
"source_id": item["src_id"],
"target_id": item["tgt_id"],
@@ -1960,7 +1958,7 @@ class PGVectorStorage(BaseVectorStorage):
return upsert_sql, data
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -2009,7 +2007,7 @@ class PGVectorStorage(BaseVectorStorage):
# Use parameterized document IDs (None means search across all documents)
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
params = {
"workspace": self.db.workspace,
"workspace": self.workspace,
"doc_ids": ids,
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
"top_k": top_k,
@@ -2032,20 +2030,22 @@ class PGVectorStorage(BaseVectorStorage):
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
logger.error(
f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}"
)
return
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
try:
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "ids": ids}
)
await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
)
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its name from the vector storage.
@@ -2059,11 +2059,13 @@ class PGVectorStorage(BaseVectorStorage):
WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
delete_sql, {"workspace": self.workspace, "entity_name": entity_name}
)
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity.
@@ -2077,11 +2079,15 @@ class PGVectorStorage(BaseVectorStorage):
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
delete_sql, {"workspace": self.workspace, "entity_name": entity_name}
)
logger.debug(
f"[{self.workspace}] Successfully deleted relations for entity {entity_name}"
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
logger.error(
f"[{self.workspace}] Error deleting relations for entity {entity_name}: {e}"
)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
@@ -2094,11 +2100,13 @@ class PGVectorStorage(BaseVectorStorage):
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
logger.error(
f"[{self.workspace}] Unknown namespace for ID lookup: {self.namespace}"
)
return None
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2"
params = {"workspace": self.db.workspace, "id": id}
params = {"workspace": self.workspace, "id": id}
try:
result = await self.db.query(query, params)
@ -2106,7 +2114,9 @@ class PGVectorStorage(BaseVectorStorage):
return dict(result)
return None
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -2123,18 +2133,22 @@ class PGVectorStorage(BaseVectorStorage):
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
logger.error(
f"[{self.workspace}] Unknown namespace for IDs lookup: {self.namespace}"
)
return []
ids_str = ",".join([f"'{id}'" for id in ids])
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
params = {"workspace": self.db.workspace}
params = {"workspace": self.workspace}
try:
results = await self.db.query(query, params, multirows=True)
return [dict(record) for record in results]
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return []
async def drop(self) -> dict[str, str]:
@@ -2150,7 +2164,7 @@ class PGVectorStorage(BaseVectorStorage):
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@@ -2174,18 +2188,16 @@ class PGDocStatusStorage(DocStatusStorage):
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
final_workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
final_workspace = self.workspace
self.db.workspace = final_workspace
else:
# Use "default" for compatibility (lowest priority)
final_workspace = "default"
self.db.workspace = final_workspace
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
pass
else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
async def finalize(self):
if self.db is not None:
@@ -2198,7 +2210,7 @@ class PGDocStatusStorage(DocStatusStorage):
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.db.workspace}
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, params, multirows=True)
if res:
@@ -2211,13 +2223,13 @@ class PGDocStatusStorage(DocStatusStorage):
return new_keys
except Exception as e:
logger.error(
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
)
raise
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
params = {"workspace": self.db.workspace, "id": id}
params = {"workspace": self.workspace, "id": id}
result = await self.db.query(sql, params, True)
if result is None or result == []:
return None
@@ -2262,7 +2274,7 @@ class PGDocStatusStorage(DocStatusStorage):
return []
sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.db.workspace, "ids": ids}
params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, params, True)
@@ -2315,7 +2327,7 @@ class PGDocStatusStorage(DocStatusStorage):
FROM LIGHTRAG_DOC_STATUS
where workspace=$1 GROUP BY STATUS
"""
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
result = await self.db.query(sql, {"workspace": self.workspace}, True)
counts = {}
for doc in result:
counts[doc["status"]] = doc["count"]
@@ -2326,7 +2338,7 @@ class PGDocStatusStorage(DocStatusStorage):
) -> dict[str, DocProcessingStatus]:
"""all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status.value}
params = {"workspace": self.workspace, "status": status.value}
result = await self.db.query(sql, params, True)
docs_by_status = {}
@@ -2380,7 +2392,7 @@ class PGDocStatusStorage(DocStatusStorage):
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific track_id"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
params = {"workspace": self.db.workspace, "track_id": track_id}
params = {"workspace": self.workspace, "track_id": track_id}
result = await self.db.query(sql, params, True)
docs_by_track_id = {}
@@ -2468,7 +2480,7 @@ class PGDocStatusStorage(DocStatusStorage):
# Build WHERE clause
where_clause = "WHERE workspace=$1"
params = {"workspace": self.db.workspace}
params = {"workspace": self.workspace}
param_count = 1
if status_filter is not None:
@@ -2550,7 +2562,7 @@ class PGDocStatusStorage(DocStatusStorage):
WHERE workspace=$1
GROUP BY status
"""
params = {"workspace": self.db.workspace}
params = {"workspace": self.workspace}
result = await self.db.query(sql, params, True)
counts = {}
@@ -2582,20 +2594,22 @@ class PGDocStatusStorage(DocStatusStorage):
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for deletion: {self.namespace}")
logger.error(
f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
)
return
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
try:
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "ids": ids}
)
await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
logger.debug(
f"Successfully deleted {len(ids)} records from {self.namespace}"
f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting records from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error while deleting records from {self.namespace}: {e}"
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Update or insert document status
@@ -2603,7 +2617,7 @@ class PGDocStatusStorage(DocStatusStorage):
Args:
data: dictionary of document IDs and their status data
"""
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -2629,7 +2643,9 @@ class PGDocStatusStorage(DocStatusStorage):
# Convert to UTC and remove timezone info for storage
return dt.astimezone(timezone.utc).replace(tzinfo=None)
except (ValueError, TypeError):
logger.warning(f"Unable to parse datetime string: {dt_str}")
logger.warning(
f"[{self.workspace}] Unable to parse datetime string: {dt_str}"
)
return None
# Modified SQL to include created_at, updated_at, chunks_list, track_id, metadata, and error_msg in both INSERT and UPDATE operations
@@ -2657,7 +2673,7 @@ class PGDocStatusStorage(DocStatusStorage):
await self.db.execute(
sql,
{
"workspace": self.db.workspace,
"workspace": self.workspace,
"id": k,
"content_summary": v["content_summary"],
"content_length": v["content_length"],
@@ -2688,7 +2704,7 @@ class PGDocStatusStorage(DocStatusStorage):
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@@ -2732,7 +2748,7 @@ class PGGraphStorage(BaseGraphStorage):
Returns:
str: The graph name for the current workspace
"""
workspace = getattr(self, "workspace", None)
workspace = self.workspace
namespace = self.namespace
if workspace and workspace.strip() and workspace.strip().lower() != "default":
@@ -2741,7 +2757,7 @@ class PGGraphStorage(BaseGraphStorage):
safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
return f"{safe_workspace}_{safe_namespace}"
else:
# When workspace is empty or "default", use namespace directly
# When the workspace is "default", use the namespace directly (for backward compatibility with legacy implementations)
return re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
@staticmethod
@@ -2764,64 +2780,64 @@ class PGGraphStorage(BaseGraphStorage):
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > None
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
final_workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
final_workspace = self.workspace
self.db.workspace = final_workspace
else:
# Use None for compatibility (lowest priority)
final_workspace = None
self.db.workspace = final_workspace
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
pass
else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
# Dynamically generate graph name based on workspace
self.workspace = self.db.workspace
self.graph_name = self._get_workspace_graph_name()
# Log the graph initialization for debugging
logger.info(
f"PostgreSQL Graph initialized: workspace='{self.workspace}', graph_name='{self.graph_name}'"
f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
)
# Create AGE extension and configure graph environment once at initialization
async with self.db.pool.acquire() as connection:
# First ensure AGE extension is created
await PostgreSQLDB.configure_age_extension(connection)
# Use graph database lock to ensure atomic operations and prevent deadlocks
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
# Create AGE extension and configure graph environment once at initialization
async with self.db.pool.acquire() as connection:
# First ensure AGE extension is created
await PostgreSQLDB.configure_age_extension(connection)
# Execute each statement separately and ignore errors
queries = [
f"SELECT create_graph('{self.graph_name}')",
f"SELECT create_vlabel('{self.graph_name}', 'base');",
f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');",
# f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)',
f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
# f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)',
f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)',
f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)',
f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)',
f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)',
f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)',
f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)',
f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)',
f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
]
# Execute each statement separately and ignore errors
queries = [
f"SELECT create_graph('{self.graph_name}')",
f"SELECT create_vlabel('{self.graph_name}', 'base');",
f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');",
# f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)',
f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
# f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)',
f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)',
f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)',
f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)',
f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)',
f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)',
f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)',
f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)',
f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
]
for query in queries:
# Use the new flag to silently ignore "already exists" errors
# at the source, preventing log spam.
await self.db.execute(
query,
upsert=True,
ignore_if_exists=True, # Pass the new flag
with_age=True,
graph_name=self.graph_name,
)
for query in queries:
# Use the new flag to silently ignore "already exists" errors
# at the source, preventing log spam.
await self.db.execute(
query,
upsert=True,
ignore_if_exists=True, # Pass the new flag
with_age=True,
graph_name=self.graph_name,
)
async def finalize(self):
if self.db is not None:
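The AGE graph setup above is now wrapped in the shared graph-database lock imported at the top of the file (get_graph_db_lock), so that, per the comment in the diff, the create_graph()/CREATE INDEX DDL runs atomically and concurrent initializations do not deadlock. A rough standalone analogue of that serialization, using a plain asyncio.Lock as a stand-in for get_graph_db_lock() and a no-op placeholder for the actual DDL:

import asyncio

_graph_db_lock = asyncio.Lock()  # stand-in for get_graph_db_lock(enable_logging=False)

async def init_graph(graph_name: str) -> None:
    async with _graph_db_lock:
        # The real code runs SELECT create_graph(...), create_vlabel/create_elabel
        # and the CREATE INDEX statements here, ignoring "already exists" errors.
        await asyncio.sleep(0)  # placeholder for the DDL round-trips

async def main() -> None:
    # Concurrent initializations are serialized by the lock instead of racing.
    await asyncio.gather(init_graph("tenant_a_graph"), init_graph("tenant_b_graph"))

asyncio.run(main())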
@@ -3067,7 +3083,9 @@ class PGGraphStorage(BaseGraphStorage):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(f"Failed to parse node string: {node_dict}")
logger.warning(
f"[{self.workspace}] Failed to parse node string: {node_dict}"
)
return node_dict
return None
@@ -3122,7 +3140,9 @@ class PGGraphStorage(BaseGraphStorage):
try:
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(f"Failed to parse edge string: {result}")
logger.warning(
f"[{self.workspace}] Failed to parse edge string: {result}"
)
return result
@@ -3188,7 +3208,9 @@ class PGGraphStorage(BaseGraphStorage):
await self._query(query, readonly=False, upsert=True)
except Exception:
logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
logger.error(
f"[{self.workspace}] POSTGRES, upsert_node error on node_id: `{node_id}`"
)
raise
@retry(
@@ -3232,7 +3254,7 @@ class PGGraphStorage(BaseGraphStorage):
except Exception:
logger.error(
f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
f"[{self.workspace}] POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
)
raise
@@ -3253,7 +3275,7 @@ class PGGraphStorage(BaseGraphStorage):
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during node deletion: {%s}", e)
logger.error(f"[{self.workspace}] Error during node deletion: {e}")
raise
async def remove_nodes(self, node_ids: list[str]) -> None:
@@ -3275,7 +3297,7 @@ class PGGraphStorage(BaseGraphStorage):
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during node removal: {%s}", e)
logger.error(f"[{self.workspace}] Error during node removal: {e}")
raise
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
@@ -3296,9 +3318,11 @@ class PGGraphStorage(BaseGraphStorage):
try:
await self._query(query, readonly=False)
logger.debug(f"Deleted edge from '{source}' to '{target}'")
logger.debug(
f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
@@ -3339,7 +3363,7 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
# Remove the 'base' label if present in a 'labels' property
@@ -3502,7 +3526,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
@@ -3518,7 +3542,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
@@ -3640,7 +3664,7 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
node_dict["id"] = node_dict["entity_id"]
@@ -3675,7 +3699,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_agtype = json.loads(edge_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge string in batch: {edge_agtype}"
f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
)
source_agtype = item["source"]["properties"]
@@ -3685,7 +3709,7 @@ class PGGraphStorage(BaseGraphStorage):
source_agtype = json.loads(source_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {source_agtype}"
f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
)
target_agtype = item["target"]["properties"]
@@ -3695,7 +3719,7 @@ class PGGraphStorage(BaseGraphStorage):
target_agtype = json.loads(target_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {target_agtype}"
f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
)
if edge_agtype and source_agtype and target_agtype:
@@ -3964,7 +3988,9 @@ class PGGraphStorage(BaseGraphStorage):
node_ids = [str(result["node_id"]) for result in node_results]
logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
logger.info(
f"[{self.workspace}] Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}"
)
if node_ids:
formatted_ids = ", ".join(node_ids)
@@ -4028,13 +4054,13 @@ class PGGraphStorage(BaseGraphStorage):
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
f"[{self.workspace}] Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
else:
# For non-wildcard queries, use the BFS algorithm
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
f"[{self.workspace}] Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
return kg
@@ -4061,7 +4087,9 @@ class PGGraphStorage(BaseGraphStorage):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(f"Failed to parse node string: {node_dict}")
logger.warning(
f"[{self.workspace}] Failed to parse node string: {node_dict}"
)
# Add node id (entity_id) to the dictionary for easier access
node_dict["id"] = node_dict.get("entity_id")
@@ -4091,7 +4119,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_properties = json.loads(edge_properties)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_properties}"
f"[{self.workspace}] Failed to parse edge properties string: {edge_properties}"
)
edge_properties = {}
@@ -4114,7 +4142,7 @@ class PGGraphStorage(BaseGraphStorage):
"message": f"workspace '{self.workspace}' graph data dropped",
}
except Exception as e:
logger.error(f"Error dropping graph: {e}")
logger.error(f"[{self.workspace}] Error dropping graph: {e}")
return {"status": "error", "message": str(e)}
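For reference, the workspace-to-graph-name mapping used by PGGraphStorage (_get_workspace_graph_name above) can be sketched as a standalone function. The line that sanitizes the workspace itself falls outside the diff hunks, so it is assumed here to mirror the namespace sanitization, and the sample workspace/namespace values are made up:

import re

def workspace_graph_name(workspace: str | None, namespace: str) -> str:
    if workspace and workspace.strip() and workspace.strip().lower() != "default":
        # Assumed: the workspace is sanitized the same way as the namespace
        safe_workspace = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip())
        safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
        return f"{safe_workspace}_{safe_namespace}"
    # An empty or "default" workspace keeps the legacy graph name (namespace only)
    return re.sub(r"[^a-zA-Z0-9_]", "_", namespace)

assert workspace_graph_name("tenant-a", "chunk_entity_relation") == "tenant_a_chunk_entity_relation"
assert workspace_graph_name("default", "chunk_entity_relation") == "chunk_entity_relation"
assert workspace_graph_name(None, "chunk_entity_relation") == "chunk_entity_relation"

Each non-default workspace therefore gets its own AGE graph (and its own indexes), which is what isolates graph data between LightRAG instances sharing one PostgreSQL server.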