diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index c974f385..e7119df0 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -30,6 +30,7 @@ from ..base import (
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
 from ..constants import GRAPH_FIELD_SEP
+from ..kg.shared_storage import get_graph_db_lock

 import pipmaster as pm

@@ -1220,9 +1221,6 @@ class PostgreSQLDB:
         with_age: bool = False,
         graph_name: str | None = None,
     ) -> dict[str, Any] | None | list[dict[str, Any]]:
-        # start_time = time.time()
-        # logger.info(f"PostgreSQL, Querying:\n{sql}")
-
         async with self.pool.acquire() as connection:  # type: ignore
             if with_age and graph_name:
                 await self.configure_age(connection, graph_name)  # type: ignore
@@ -1248,10 +1246,6 @@ class PostgreSQLDB:
                 else:
                     data = None

-            # query_time = time.time() - start_time
-            # logger.info(f"PostgreSQL, Query result len: {len(data)}")
-            # logger.info(f"PostgreSQL, Query execution time: {query_time:.4f}s")
-
             return data
         except Exception as e:
             logger.error(f"PostgreSQL database, error:{e}")
@@ -1414,18 +1408,16 @@ class PGKVStorage(BaseKVStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
-            # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
-            if self.db.workspace:
-                # Use PostgreSQLDB's workspace (highest priority)
-                final_workspace = self.db.workspace
-            elif hasattr(self, "workspace") and self.workspace:
-                # Use storage class's workspace (medium priority)
-                final_workspace = self.workspace
-                self.db.workspace = final_workspace
-            else:
-                # Use "default" for compatibility (lowest priority)
-                final_workspace = "default"
-                self.db.workspace = final_workspace
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            self.workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            pass
+        else:
+            # Use "default" for compatibility (lowest priority)
+            self.workspace = "default"

     async def finalize(self):
         if self.db is not None:
@@ -1441,11 +1433,13 @@ class PGKVStorage(BaseKVStorage):
         """
         table_name = namespace_to_table_name(self.namespace)
         if not table_name:
-            logger.error(f"Unknown namespace for get_all: {self.namespace}")
+            logger.error(
+                f"[{self.workspace}] Unknown namespace for get_all: {self.namespace}"
+            )
             return {}

         sql = f"SELECT * FROM {table_name} WHERE workspace=$1"
-        params = {"workspace": self.db.workspace}
+        params = {"workspace": self.workspace}

         try:
             results = await self.db.query(sql, params, multirows=True)
@@ -1533,13 +1527,15 @@ class PGKVStorage(BaseKVStorage):
                 # For other namespaces, return as-is
                 return {row["id"]: row for row in results}
         except Exception as e:
-            logger.error(f"Error retrieving all data from {self.namespace}: {e}")
+            logger.error(
+                f"[{self.workspace}] Error retrieving all data from {self.namespace}: {e}"
+            )
             return {}

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
-        params = {"workspace": self.db.workspace, "id": id}
+        params = {"workspace": self.workspace, "id": id}
         response = await self.db.query(sql, params)

         if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -1619,7 +1615,7 @@ class PGKVStorage(BaseKVStorage):
         sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
-        params = {"workspace": self.db.workspace}
+        params = {"workspace": self.workspace}
         results = await self.db.query(sql, params, multirows=True)

         if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -1706,7 +1702,7 @@ class PGKVStorage(BaseKVStorage):
             table_name=namespace_to_table_name(self.namespace),
             ids=",".join([f"'{id}'" for id in keys]),
         )
-        params = {"workspace": self.db.workspace}
+        params = {"workspace": self.workspace}
         try:
             res = await self.db.query(sql, params, multirows=True)
             if res:
@@ -1717,13 +1713,13 @@ class PGKVStorage(BaseKVStorage):
             return new_keys
         except Exception as e:
             logger.error(
-                f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
+                f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
             )
             raise

     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.debug(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
         if not data:
             return

@@ -1733,7 +1729,7 @@ class PGKVStorage(BaseKVStorage):
             for k, v in data.items():
                 upsert_sql = SQL_TEMPLATES["upsert_text_chunk"]
                 _data = {
-                    "workspace": self.db.workspace,
+                    "workspace": self.workspace,
                     "id": k,
                     "tokens": v["tokens"],
                     "chunk_order_index": v["chunk_order_index"],
@@ -1751,14 +1747,14 @@ class PGKVStorage(BaseKVStorage):
                 _data = {
                     "id": k,
                     "content": v["content"],
-                    "workspace": self.db.workspace,
+                    "workspace": self.workspace,
                 }
                 await self.db.execute(upsert_sql, _data)
         elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             for k, v in data.items():
                 upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
                 _data = {
-                    "workspace": self.db.workspace,
+                    "workspace": self.workspace,
                     "id": k,  # Use flattened key as id
                     "original_prompt": v["original_prompt"],
                     "return_value": v["return"],
@@ -1778,7 +1774,7 @@ class PGKVStorage(BaseKVStorage):
             for k, v in data.items():
                 upsert_sql = SQL_TEMPLATES["upsert_full_entities"]
                 _data = {
-                    "workspace": self.db.workspace,
+                    "workspace": self.workspace,
                     "id": k,
                     "entity_names": json.dumps(v["entity_names"]),
                     "count": v["count"],
@@ -1792,7 +1788,7 @@ class PGKVStorage(BaseKVStorage):
             for k, v in data.items():
                 upsert_sql = SQL_TEMPLATES["upsert_full_relations"]
                 _data = {
-                    "workspace": self.db.workspace,
+                    "workspace": self.workspace,
                     "id": k,
                     "relation_pairs": json.dumps(v["relation_pairs"]),
                     "count": v["count"],
@@ -1819,20 +1815,22 @@ class PGKVStorage(BaseKVStorage):
         table_name = namespace_to_table_name(self.namespace)
         if not table_name:
-            logger.error(f"Unknown namespace for deletion: {self.namespace}")
+            logger.error(
+                f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
+            )
             return

         delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

         try:
-            await self.db.execute(
-                delete_sql, {"workspace": self.db.workspace, "ids": ids}
-            )
+            await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
             logger.debug(
-                f"Successfully deleted {len(ids)} records from {self.namespace}"
+                f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
             )
         except Exception as e:
-            logger.error(f"Error while deleting records from {self.namespace}: {e}")
+            logger.error(
+                f"[{self.workspace}] Error while deleting records from {self.namespace}: {e}"
+            )

     async def drop(self) -> dict[str, str]:
         """Drop the storage"""
@@ -1847,7 +1845,7 @@ class PGKVStorage(BaseKVStorage):
             drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                 table_name=table_name
             )
-            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
+            await self.db.execute(drop_sql, {"workspace": self.workspace})
             return {"status": "success", "message": "data dropped"}
         except Exception as e:
             return {"status": "error", "message": str(e)}
@@ -1871,18 +1869,16 @@ class PGVectorStorage(BaseVectorStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
-            # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
-            if self.db.workspace:
-                # Use PostgreSQLDB's workspace (highest priority)
-                final_workspace = self.db.workspace
-            elif hasattr(self, "workspace") and self.workspace:
-                # Use storage class's workspace (medium priority)
-                final_workspace = self.workspace
-                self.db.workspace = final_workspace
-            else:
-                # Use "default" for compatibility (lowest priority)
-                final_workspace = "default"
-                self.db.workspace = final_workspace
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            self.workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            pass
+        else:
+            # Use "default" for compatibility (lowest priority)
+            self.workspace = "default"

     async def finalize(self):
         if self.db is not None:
@@ -1895,7 +1891,7 @@ class PGVectorStorage(BaseVectorStorage):
         try:
             upsert_sql = SQL_TEMPLATES["upsert_chunk"]
             data: dict[str, Any] = {
-                "workspace": self.db.workspace,
+                "workspace": self.workspace,
                 "id": item["__id__"],
                 "tokens": item["tokens"],
                 "chunk_order_index": item["chunk_order_index"],
@@ -1907,7 +1903,9 @@ class PGVectorStorage(BaseVectorStorage):
                 "update_time": current_time,
             }
         except Exception as e:
-            logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}")
+            logger.error(
+                f"[{self.workspace}] Error to prepare upsert,\nsql: {e}\nitem: {item}"
+            )
             raise

         return upsert_sql, data
@@ -1923,7 +1921,7 @@ class PGVectorStorage(BaseVectorStorage):
             chunk_ids = [source_id]

         data: dict[str, Any] = {
-            "workspace": self.db.workspace,
+            "workspace": self.workspace,
             "id": item["__id__"],
             "entity_name": item["entity_name"],
             "content": item["content"],
@@ -1946,7 +1944,7 @@ class PGVectorStorage(BaseVectorStorage):
             chunk_ids = [source_id]

         data: dict[str, Any] = {
-            "workspace": self.db.workspace,
+            "workspace": self.workspace,
             "id": item["__id__"],
             "source_id": item["src_id"],
             "target_id": item["tgt_id"],
@@ -1960,7 +1958,7 @@ class PGVectorStorage(BaseVectorStorage):
         return upsert_sql, data

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.debug(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
         if not data:
             return

@@ -2009,7 +2007,7 @@ class PGVectorStorage(BaseVectorStorage):
         # Use parameterized document IDs (None means search across all documents)
         sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
         params = {
-            "workspace": self.db.workspace,
+            "workspace": self.workspace,
             "doc_ids": ids,
             "closer_than_threshold": 1 - self.cosine_better_than_threshold,
             "top_k": top_k,
@@ -2032,20 +2030,22 @@ class PGVectorStorage(BaseVectorStorage):
         table_name = namespace_to_table_name(self.namespace)
         if not table_name:
-            logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
+            logger.error(
+                f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}"
+            )
             return

         delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

         try:
-            await self.db.execute(
-                delete_sql, {"workspace": self.db.workspace, "ids": ids}
-            )
+            await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
             logger.debug(
-                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
+                f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
         except Exception as e:
-            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
+            logger.error(
+                f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
+            )

     async def delete_entity(self, entity_name: str) -> None:
         """Delete an entity by its name from the vector storage.
@@ -2059,11 +2059,13 @@ class PGVectorStorage(BaseVectorStorage):
                 WHERE workspace=$1 AND entity_name=$2"""

             await self.db.execute(
-                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
+                delete_sql, {"workspace": self.workspace, "entity_name": entity_name}
+            )
+            logger.debug(
+                f"[{self.workspace}] Successfully deleted entity {entity_name}"
             )
-            logger.debug(f"Successfully deleted entity {entity_name}")
         except Exception as e:
-            logger.error(f"Error deleting entity {entity_name}: {e}")
+            logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
         """Delete all relations associated with an entity.
@@ -2077,11 +2079,15 @@ class PGVectorStorage(BaseVectorStorage):
                 WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""

             await self.db.execute(
-                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
+                delete_sql, {"workspace": self.workspace, "entity_name": entity_name}
+            )
+            logger.debug(
+                f"[{self.workspace}] Successfully deleted relations for entity {entity_name}"
             )
-            logger.debug(f"Successfully deleted relations for entity {entity_name}")
         except Exception as e:
-            logger.error(f"Error deleting relations for entity {entity_name}: {e}")
+            logger.error(
+                f"[{self.workspace}] Error deleting relations for entity {entity_name}: {e}"
+            )

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get vector data by its ID
@@ -2094,11 +2100,13 @@ class PGVectorStorage(BaseVectorStorage):
         """
         table_name = namespace_to_table_name(self.namespace)
         if not table_name:
-            logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
+            logger.error(
+                f"[{self.workspace}] Unknown namespace for ID lookup: {self.namespace}"
+            )
             return None

         query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2"
-        params = {"workspace": self.db.workspace, "id": id}
+        params = {"workspace": self.workspace, "id": id}

         try:
             result = await self.db.query(query, params)
@@ -2106,7 +2114,9 @@ class PGVectorStorage(BaseVectorStorage):
                 return dict(result)
             return None
         except Exception as e:
-            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            logger.error(
+                f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
+            )
             return None

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -2123,18 +2133,22 @@ class PGVectorStorage(BaseVectorStorage):
         table_name = namespace_to_table_name(self.namespace)
         if not table_name:
-            logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
+            logger.error(
+                f"[{self.workspace}] Unknown namespace for IDs lookup: {self.namespace}"
+            )
             return []

         ids_str = ",".join([f"'{id}'" for id in ids])
         query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
-        params = {"workspace": self.db.workspace}
+        params = {"workspace": self.workspace}

         try:
             results = await self.db.query(query, params, multirows=True)
             return [dict(record) for record in results]
         except Exception as e:
-            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            logger.error(
+                f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
+            )
             return []

     async def drop(self) -> dict[str, str]:
@@ -2150,7 +2164,7 @@
             drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                 table_name=table_name
             )
-            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
+            await self.db.execute(drop_sql, {"workspace": self.workspace})
             return {"status": "success", "message": "data dropped"}
         except Exception as e:
             return {"status": "error", "message": str(e)}
@@ -2174,18 +2188,16 @@ class PGDocStatusStorage(DocStatusStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
-            # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
-            if self.db.workspace:
-                # Use PostgreSQLDB's workspace (highest priority)
-                final_workspace = self.db.workspace
-            elif hasattr(self, "workspace") and self.workspace:
-                # Use storage class's workspace (medium priority)
-                final_workspace = self.workspace
-                self.db.workspace = final_workspace
-            else:
-                # Use "default" for compatibility (lowest priority)
-                final_workspace = "default"
-                self.db.workspace = final_workspace
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            self.workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            pass
+        else:
+            # Use "default" for compatibility (lowest priority)
+            self.workspace = "default"

     async def finalize(self):
         if self.db is not None:
@@ -2198,7 +2210,7 @@ class PGDocStatusStorage(DocStatusStorage):
             table_name=namespace_to_table_name(self.namespace),
             ids=",".join([f"'{id}'" for id in keys]),
         )
-        params = {"workspace": self.db.workspace}
+        params = {"workspace": self.workspace}
        try:
             res = await self.db.query(sql, params, multirows=True)
             if res:
@@ -2211,13 +2223,13 @@ class PGDocStatusStorage(DocStatusStorage):
             return new_keys
         except Exception as e:
             logger.error(
-                f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
+                f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
             )
             raise

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
-        params = {"workspace": self.db.workspace, "id": id}
+        params = {"workspace": self.workspace, "id": id}
         result = await self.db.query(sql, params, True)
         if result is None or result == []:
             return None
@@ -2262,7 +2274,7 @@ class PGDocStatusStorage(DocStatusStorage):
             return []

         sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
-        params = {"workspace": self.db.workspace, "ids": ids}
+        params = {"workspace": self.workspace, "ids": ids}

         results = await self.db.query(sql, params, True)

@@ -2315,7 +2327,7 @@ class PGDocStatusStorage(DocStatusStorage):
                    FROM LIGHTRAG_DOC_STATUS
                    where workspace=$1 GROUP BY STATUS
                    """
-        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
+        result = await self.db.query(sql, {"workspace": self.workspace}, True)
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
@@ -2326,7 +2338,7 @@ class PGDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """all documents with a specific status"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
-        params = {"workspace": self.db.workspace, "status": status.value}
+        params = {"workspace": self.workspace, "status": status.value}
         result = await self.db.query(sql, params, True)

         docs_by_status = {}
@@ -2380,7 +2392,7 @@ class PGDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific track_id"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
-        params = {"workspace": self.db.workspace, "track_id": track_id}
+        params = {"workspace": self.workspace, "track_id": track_id}
         result = await self.db.query(sql, params, True)

         docs_by_track_id = {}
@@ -2468,7 +2480,7 @@ class PGDocStatusStorage(DocStatusStorage):

         # Build WHERE clause
         where_clause = "WHERE workspace=$1"
-        params = {"workspace": self.db.workspace}
+        params = {"workspace": self.workspace}
         param_count = 1

         if status_filter is not None:
@@ -2550,7 +2562,7 @@ class PGDocStatusStorage(DocStatusStorage):
             WHERE workspace=$1
             GROUP BY status
         """
-        params = {"workspace": self.db.workspace}
+        params = {"workspace": self.workspace}
         result = await self.db.query(sql, params, True)

         counts = {}
@@ -2582,20 +2594,22 @@
         table_name = namespace_to_table_name(self.namespace)
         if not table_name:
-            logger.error(f"Unknown namespace for deletion: {self.namespace}")
+            logger.error(
+                f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
+            )
             return

         delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

         try:
-            await self.db.execute(
-                delete_sql, {"workspace": self.db.workspace, "ids": ids}
-            )
+            await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
             logger.debug(
-                f"Successfully deleted {len(ids)} records from {self.namespace}"
+                f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
             )
         except Exception as e:
-            logger.error(f"Error while deleting records from {self.namespace}: {e}")
+            logger.error(
+                f"[{self.workspace}] Error while deleting records from {self.namespace}: {e}"
+            )

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """Update or insert document status
@@ -2603,7 +2617,7 @@ class PGDocStatusStorage(DocStatusStorage):
         Args:
             data: dictionary of document IDs and their status data
         """
-        logger.debug(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
         if not data:
             return

@@ -2629,7 +2643,9 @@ class PGDocStatusStorage(DocStatusStorage):
                 # Convert to UTC and remove timezone info for storage
                 return dt.astimezone(timezone.utc).replace(tzinfo=None)
             except (ValueError, TypeError):
-                logger.warning(f"Unable to parse datetime string: {dt_str}")
+                logger.warning(
+                    f"[{self.workspace}] Unable to parse datetime string: {dt_str}"
+                )
                 return None

         # Modified SQL to include created_at, updated_at, chunks_list, track_id, metadata, and error_msg in both INSERT and UPDATE operations
@@ -2657,7 +2673,7 @@ class PGDocStatusStorage(DocStatusStorage):
             await self.db.execute(
                 sql,
                 {
-                    "workspace": self.db.workspace,
+                    "workspace": self.workspace,
                     "id": k,
                     "content_summary": v["content_summary"],
                     "content_length": v["content_length"],
@@ -2688,7 +2704,7 @@ class PGDocStatusStorage(DocStatusStorage):
             drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                 table_name=table_name
             )
-            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
+            await self.db.execute(drop_sql, {"workspace": self.workspace})
             return {"status": "success", "message": "data dropped"}
         except Exception as e:
             return {"status": "error", "message": str(e)}
@@ -2732,7 +2748,7 @@ class PGGraphStorage(BaseGraphStorage):
         Returns:
             str: The graph name for the current workspace
         """
-        workspace = getattr(self, "workspace", None)
+        workspace = self.workspace
         namespace = self.namespace

         if workspace and workspace.strip() and workspace.strip().lower() != "default":
@@ -2741,7 +2757,7 @@ class PGGraphStorage(BaseGraphStorage):
             safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
             return f"{safe_workspace}_{safe_namespace}"
         else:
-            # When workspace is empty or "default", use namespace directly
+            # When the workspace is "default", use the namespace directly (for backward compatibility with legacy implementations)
             return re.sub(r"[^a-zA-Z0-9_]", "_", namespace)

     @staticmethod
@@ -2764,64 +2780,64 @@ class PGGraphStorage(BaseGraphStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
-            # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > None
-            if self.db.workspace:
-                # Use PostgreSQLDB's workspace (highest priority)
-                final_workspace = self.db.workspace
-            elif hasattr(self, "workspace") and self.workspace:
-                # Use storage class's workspace (medium priority)
-                final_workspace = self.workspace
-                self.db.workspace = final_workspace
-            else:
-                # Use None for compatibility (lowest priority)
-                final_workspace = None
-                self.db.workspace = final_workspace
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            self.workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            pass
+        else:
+            # Use "default" for compatibility (lowest priority)
+            self.workspace = "default"

         # Dynamically generate graph name based on workspace
-        self.workspace = self.db.workspace
         self.graph_name = self._get_workspace_graph_name()

         # Log the graph initialization for debugging
         logger.info(
-            f"PostgreSQL Graph initialized: workspace='{self.workspace}', graph_name='{self.graph_name}'"
+            f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
         )

-        # Create AGE extension and configure graph environment once at initialization
-        async with self.db.pool.acquire() as connection:
-            # First ensure AGE extension is created
-            await PostgreSQLDB.configure_age_extension(connection)
+        # Use graph database lock to ensure atomic operations and prevent deadlocks
+        graph_db_lock = get_graph_db_lock(enable_logging=False)
+        async with graph_db_lock:
+            # Create AGE extension and configure graph environment once at initialization
+            async with self.db.pool.acquire() as connection:
+                # First ensure AGE extension is created
+                await PostgreSQLDB.configure_age_extension(connection)

-        # Execute each statement separately and ignore errors
-        queries = [
-            f"SELECT create_graph('{self.graph_name}')",
-            f"SELECT create_vlabel('{self.graph_name}', 'base');",
-            f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');",
-            # f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)',
-            f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
-            # f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)',
-            f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)',
-            f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)',
-            f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)',
-            f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)',
-            f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)',
-            f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)',
-            f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)',
-            f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
-            f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
-            f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
-            f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
-        ]
+            # Execute each statement separately and ignore errors
+            queries = [
+                f"SELECT create_graph('{self.graph_name}')",
+                f"SELECT create_vlabel('{self.graph_name}', 'base');",
+                f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');",
+                # f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)',
+                f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
+                # f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)',
+                f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)',
+                f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)',
+                f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)',
+                f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)',
+                f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)',
+                f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)',
+                f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)',
+                f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
+                f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
+                f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
+                f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
+            ]

-        for query in queries:
-            # Use the new flag to silently ignore "already exists" errors
-            # at the source, preventing log spam.
-            await self.db.execute(
-                query,
-                upsert=True,
-                ignore_if_exists=True,  # Pass the new flag
-                with_age=True,
-                graph_name=self.graph_name,
-            )
+            for query in queries:
+                # Use the new flag to silently ignore "already exists" errors
+                # at the source, preventing log spam.
+                await self.db.execute(
+                    query,
+                    upsert=True,
+                    ignore_if_exists=True,  # Pass the new flag
+                    with_age=True,
+                    graph_name=self.graph_name,
+                )

     async def finalize(self):
         if self.db is not None:
@@ -3067,7 +3083,9 @@ class PGGraphStorage(BaseGraphStorage):
             try:
                 node_dict = json.loads(node_dict)
             except json.JSONDecodeError:
-                logger.warning(f"Failed to parse node string: {node_dict}")
+                logger.warning(
+                    f"[{self.workspace}] Failed to parse node string: {node_dict}"
+                )

             return node_dict
         return None
@@ -3122,7 +3140,9 @@ class PGGraphStorage(BaseGraphStorage):
                 try:
                     result = json.loads(result)
                 except json.JSONDecodeError:
-                    logger.warning(f"Failed to parse edge string: {result}")
+                    logger.warning(
+                        f"[{self.workspace}] Failed to parse edge string: {result}"
+                    )

                 return result

@@ -3188,7 +3208,9 @@ class PGGraphStorage(BaseGraphStorage):
             await self._query(query, readonly=False, upsert=True)

         except Exception:
-            logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
+            logger.error(
+                f"[{self.workspace}] POSTGRES, upsert_node error on node_id: `{node_id}`"
+            )
             raise

     @retry(
@@ -3232,7 +3254,7 @@ class PGGraphStorage(BaseGraphStorage):

         except Exception:
             logger.error(
-                f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
+                f"[{self.workspace}] POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
             )
             raise

@@ -3253,7 +3275,7 @@ class PGGraphStorage(BaseGraphStorage):
         try:
             await self._query(query, readonly=False)
         except Exception as e:
-            logger.error("Error during node deletion: {%s}", e)
+            logger.error(f"[{self.workspace}] Error during node deletion: {e}")
             raise

     async def remove_nodes(self, node_ids: list[str]) -> None:
@@ -3275,7 +3297,7 @@ class PGGraphStorage(BaseGraphStorage):
         try:
             await self._query(query, readonly=False)
         except Exception as e:
-            logger.error("Error during node removal: {%s}", e)
+            logger.error(f"[{self.workspace}] Error during node removal: {e}")
             raise

     async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
@@ -3296,9 +3318,11 @@ class PGGraphStorage(BaseGraphStorage):

             try:
                 await self._query(query, readonly=False)
-                logger.debug(f"Deleted edge from '{source}' to '{target}'")
+                logger.debug(
+                    f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
+                )
             except Exception as e:
-                logger.error(f"Error during edge deletion: {str(e)}")
+                logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
                 raise

     async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
@@ -3339,7 +3363,7 @@ class PGGraphStorage(BaseGraphStorage):
                         node_dict = json.loads(node_dict)
                     except json.JSONDecodeError:
                         logger.warning(
-                            f"Failed to parse node string in batch: {node_dict}"
+                            f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
                         )

                 # Remove the 'base' label if present in a 'labels' property
@@ -3502,7 +3526,7 @@ class PGGraphStorage(BaseGraphStorage):
                         edge_props = json.loads(edge_props)
                     except json.JSONDecodeError:
                         logger.warning(
-                            f"Failed to parse edge properties string: {edge_props}"
+                            f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
                         )
                         continue

@@ -3518,7 +3542,7 @@ class PGGraphStorage(BaseGraphStorage):
                         edge_props = json.loads(edge_props)
                     except json.JSONDecodeError:
                         logger.warning(
-                            f"Failed to parse edge properties string: {edge_props}"
+                            f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
                        )
                         continue

@@ -3640,7 +3664,7 @@ class PGGraphStorage(BaseGraphStorage):
                         node_dict = json.loads(node_dict)
                     except json.JSONDecodeError:
                         logger.warning(
-                            f"Failed to parse node string in batch: {node_dict}"
+                            f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
                         )

                 node_dict["id"] = node_dict["entity_id"]
@@ -3675,7 +3699,7 @@ class PGGraphStorage(BaseGraphStorage):
                     edge_agtype = json.loads(edge_agtype)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"Failed to parse edge string in batch: {edge_agtype}"
+                        f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}"
                     )

             source_agtype = item["source"]["properties"]
@@ -3685,7 +3709,7 @@ class PGGraphStorage(BaseGraphStorage):
                     source_agtype = json.loads(source_agtype)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"Failed to parse node string in batch: {source_agtype}"
+                        f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}"
                     )

             target_agtype = item["target"]["properties"]
@@ -3695,7 +3719,7 @@ class PGGraphStorage(BaseGraphStorage):
                     target_agtype = json.loads(target_agtype)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"Failed to parse node string in batch: {target_agtype}"
+                        f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}"
                     )

             if edge_agtype and source_agtype and target_agtype:
@@ -3964,7 +3988,9 @@ class PGGraphStorage(BaseGraphStorage):

         node_ids = [str(result["node_id"]) for result in node_results]

-        logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
+        logger.info(
+            f"[{self.workspace}] Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}"
+        )

         if node_ids:
             formatted_ids = ", ".join(node_ids)
@@ -4028,13 +4054,13 @@ class PGGraphStorage(BaseGraphStorage):

             kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
             logger.info(
-                f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
+                f"[{self.workspace}] Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
             )
         else:
             # For non-wildcard queries, use the BFS algorithm
             kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
             logger.info(
-                f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
+                f"[{self.workspace}] Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
             )

         return kg
@@ -4061,7 +4087,9 @@ class PGGraphStorage(BaseGraphStorage):
                 try:
                     node_dict = json.loads(node_dict)
                 except json.JSONDecodeError:
-                    logger.warning(f"Failed to parse node string: {node_dict}")
+                    logger.warning(
+                        f"[{self.workspace}] Failed to parse node string: {node_dict}"
+                    )

             # Add node id (entity_id) to the dictionary for easier access
             node_dict["id"] = node_dict.get("entity_id")
@@ -4091,7 +4119,7 @@ class PGGraphStorage(BaseGraphStorage):
                     edge_properties = json.loads(edge_properties)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"Failed to parse edge properties string: {edge_properties}"
+                        f"[{self.workspace}] Failed to parse edge properties string: {edge_properties}"
                     )
                     edge_properties = {}

@@ -4114,7 +4142,7 @@ class PGGraphStorage(BaseGraphStorage):
                 "message": f"workspace '{self.workspace}' graph data dropped",
             }
         except Exception as e:
-            logger.error(f"Error dropping graph: {e}")
+            logger.error(f"[{self.workspace}] Error dropping graph: {e}")
             return {"status": "error", "message": str(e)}
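
Note on the workspace change: the patch stops writing the resolved workspace back into the shared `PostgreSQLDB` client (`self.db.workspace`) and instead stores it on each storage instance (`self.workspace`), so storages sharing one client can no longer clobber each other's workspace. The same resolution block is now duplicated across `PGKVStorage`, `PGVectorStorage`, `PGDocStatusStorage`, and `PGGraphStorage`; a minimal sketch of how it could be factored out — `resolve_workspace` is a hypothetical helper, not part of the patch:

```python
# Hypothetical helper (not in the patch): the priority chain each
# initialize() method implements, expressed once.
def resolve_workspace(db_workspace: str | None, own_workspace: str | None) -> str:
    """Priority: PostgreSQLDB.workspace > storage-level workspace > "default"."""
    if db_workspace:
        return db_workspace   # highest priority: the shared client's workspace
    if own_workspace:
        return own_workspace  # medium priority: the storage's own workspace
    return "default"          # lowest priority: compatibility fallback
```

Each `initialize()` would then reduce to `self.workspace = resolve_workspace(self.db.workspace, getattr(self, "workspace", None))`.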
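For reference, the graph naming that the patched `_get_workspace_graph_name` applies, restated from the diff as a standalone function for illustration only:

```python
import re

def workspace_graph_name(workspace: str | None, namespace: str) -> str:
    # Non-default workspaces get a "{workspace}_{namespace}" graph; an empty
    # or "default" workspace falls back to the namespace alone, matching the
    # backward-compatibility comment in the patch.
    if workspace and workspace.strip() and workspace.strip().lower() != "default":
        safe_ws = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip())
        safe_ns = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
        return f"{safe_ws}_{safe_ns}"
    return re.sub(r"[^a-zA-Z0-9_]", "_", namespace)

assert workspace_graph_name("prod-eu", "chunk_entity_relation") == "prod_eu_chunk_entity_relation"
assert workspace_graph_name("default", "chunk_entity_relation") == "chunk_entity_relation"
```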
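On the new `get_graph_db_lock` guard in `PGGraphStorage.initialize`: serializing the `create_graph()` / `CREATE INDEX` DDL prevents concurrent initializers from racing on the AGE catalog. A toy illustration of the pattern, assuming the lock behaves like an `asyncio.Lock` shared within the process (the real lock comes from `shared_storage` and may coordinate more broadly):

```python
import asyncio

_graph_db_lock = asyncio.Lock()  # stand-in for get_graph_db_lock()

async def create_graph_once(worker: int, created: set[str], graph: str) -> None:
    async with _graph_db_lock:
        # DDL runs one initializer at a time instead of racing on the catalog.
        if graph not in created:
            created.add(graph)
            print(f"worker {worker} created graph {graph!r}")

async def main() -> None:
    created: set[str] = set()
    await asyncio.gather(*(create_graph_once(i, created, "ws1_graph") for i in range(4)))

asyncio.run(main())
```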