Refactor exception handling in MemgraphStorage label methods

(cherry picked from commit 4401f86f07)
This commit is contained in:
yangdx 2025-11-14 11:01:26 +08:00 committed by Raphaël MANSUY
parent ed79218550
commit dcf88a8273

View file

@ -8,7 +8,6 @@ import configparser
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm import pipmaster as pm
@ -53,7 +52,7 @@ class MemgraphStorage(BaseGraphStorage):
def _get_workspace_label(self) -> str: def _get_workspace_label(self) -> str:
"""Return workspace label (guaranteed non-empty during initialization)""" """Return workspace label (guaranteed non-empty during initialization)"""
return self._get_composite_workspace() return self.workspace
async def initialize(self): async def initialize(self):
async with get_data_init_lock(): async with get_data_init_lock():
@ -134,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
@ -147,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@ -171,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = ( query = (
@ -191,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
@ -313,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f""" query = f"""
@ -329,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage):
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}") logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@ -353,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
results = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -390,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
) )
await results.consume() # Ensure results are consumed even on error if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -420,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f""" query = f"""
@ -452,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
@ -784,79 +803,6 @@ class MemgraphStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all nodes that are associated with the given chunk_ids.

    A node matches when its ``source_id`` property is set and, once split on
    ``GRAPH_FIELD_SEP``, contains at least one of the requested chunk IDs.

    Args:
        chunk_ids: List of chunk IDs to find associated nodes for

    Returns:
        list[dict]: A list of nodes, where each node is a dictionary of its properties
        plus an ``id`` key mirroring ``entity_id``.
        An empty list if no matching nodes are found or ``chunk_ids`` is empty.

    Raises:
        RuntimeError: If the Memgraph driver has not been initialized.
        Exception: Any error raised while running the query is logged and re-raised.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    # UNWIND over an empty list produces no rows, so the answer is already
    # known; skip the database round trip entirely.
    if not chunk_ids:
        return []

    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Initialised up front so the error path can tell whether a result
        # object exists before trying to consume it.
        result = None
        try:
            query = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (n:`{workspace_label}`)
            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
            RETURN DISTINCT n
            """
            result = await session.run(
                query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP
            )
            nodes = []
            async for record in result:
                node_dict = dict(record["n"])
                # Expose the entity identifier under the generic "id" key too.
                node_dict["id"] = node_dict.get("entity_id")
                nodes.append(node_dict)
            await result.consume()
            return nodes
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error getting nodes by chunk ids: {str(e)}"
            )
            if result is not None:
                await (
                    result.consume()
                )  # Ensure the result is consumed even on error
            raise
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all edges that are associated with the given chunk_ids.

    An edge matches when its ``source_id`` property is set and, once split on
    ``GRAPH_FIELD_SEP``, contains at least one of the requested chunk IDs.
    The pattern is matched without direction, so the query orders the two
    endpoint ``entity_id`` values to report each edge only once.

    Args:
        chunk_ids: List of chunk IDs to find associated edges for

    Returns:
        list[dict]: A list of edges, where each edge is a dictionary of its properties
        plus ``source`` and ``target`` keys holding the ordered endpoint IDs.
        An empty list if no matching edges are found or ``chunk_ids`` is empty.

    Raises:
        RuntimeError: If the Memgraph driver has not been initialized.
        Exception: Any error raised while running the query is logged and re-raised.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    # UNWIND over an empty list produces no rows, so the answer is already
    # known; skip the database round trip entirely.
    if not chunk_ids:
        return []

    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Initialised up front so the error path can tell whether a result
        # object exists before trying to consume it.
        result = None
        try:
            query = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
            WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
            // Ensure we only return each unique edge once by ordering the source and target
            WITH a, b, r,
                 CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
                 CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
            RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
            """
            result = await session.run(
                query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP
            )
            edges = []
            async for record in result:
                # Copy so the endpoint keys are added to our own dict rather
                # than to the map object held by the driver record.
                edge_properties = dict(record["properties"])
                edge_properties["source"] = record["source"]
                edge_properties["target"] = record["target"]
                edges.append(edge_properties)
            await result.consume()
            return edges
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error getting edges by chunk ids: {str(e)}"
            )
            if result is not None:
                await (
                    result.consume()
                )  # Ensure the result is consumed even on error
            raise
async def get_knowledge_graph( async def get_knowledge_graph(
self, self,
node_label: str, node_label: str,
@ -1104,6 +1050,7 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
async with self._driver.session( async with self._driver.session(
@ -1130,6 +1077,8 @@ class MemgraphStorage(BaseGraphStorage):
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}") logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
if result is not None:
await result.consume()
return [] return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]: async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@ -1152,6 +1101,7 @@ class MemgraphStorage(BaseGraphStorage):
if not query_lower: if not query_lower:
return [] return []
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
async with self._driver.session( async with self._driver.session(
@ -1185,4 +1135,6 @@ class MemgraphStorage(BaseGraphStorage):
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
if result is not None:
await result.consume()
return [] return []