Refactor exception handling in MemgraphStorage label methods
(cherry picked from commit 4401f86f07)
This commit is contained in:
parent
ed79218550
commit
dcf88a8273
1 changed files with 32 additions and 80 deletions
|
|
@ -8,7 +8,6 @@ import configparser
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
|
||||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
|
@ -53,7 +52,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
def _get_workspace_label(self) -> str:
|
def _get_workspace_label(self) -> str:
|
||||||
"""Return workspace label (guaranteed non-empty during initialization)"""
|
"""Return workspace label (guaranteed non-empty during initialization)"""
|
||||||
return self._get_composite_workspace()
|
return self.workspace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
async with get_data_init_lock():
|
async with get_data_init_lock():
|
||||||
|
|
@ -134,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
||||||
|
|
@ -147,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
|
|
@ -171,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = (
|
query = (
|
||||||
|
|
@ -191,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
|
|
@ -313,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""
|
query = f"""
|
||||||
|
|
@ -329,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
|
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
|
|
@ -353,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
results = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||||
|
|
@ -390,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await results.consume() # Ensure results are consumed even on error
|
if results is not None:
|
||||||
|
await (
|
||||||
|
results.consume()
|
||||||
|
) # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -420,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""
|
query = f"""
|
||||||
|
|
@ -452,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
|
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
|
|
@ -784,79 +803,6 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
degrees = int(src_degree) + int(trg_degree)
|
degrees = int(src_degree) + int(trg_degree)
|
||||||
return degrees
|
return degrees
|
||||||
|
|
||||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
||||||
"""Get all nodes that are associated with the given chunk_ids.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunk_ids: List of chunk IDs to find associated nodes for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
|
||||||
An empty list if no matching nodes are found.
|
|
||||||
"""
|
|
||||||
if self._driver is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
|
||||||
)
|
|
||||||
workspace_label = self._get_workspace_label()
|
|
||||||
async with self._driver.session(
|
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
|
||||||
) as session:
|
|
||||||
query = f"""
|
|
||||||
UNWIND $chunk_ids AS chunk_id
|
|
||||||
MATCH (n:`{workspace_label}`)
|
|
||||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
|
|
||||||
RETURN DISTINCT n
|
|
||||||
"""
|
|
||||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
|
||||||
nodes = []
|
|
||||||
async for record in result:
|
|
||||||
node = record["n"]
|
|
||||||
node_dict = dict(node)
|
|
||||||
node_dict["id"] = node_dict.get("entity_id")
|
|
||||||
nodes.append(node_dict)
|
|
||||||
await result.consume()
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
||||||
"""Get all edges that are associated with the given chunk_ids.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunk_ids: List of chunk IDs to find associated edges for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[dict]: A list of edges, where each edge is a dictionary of its properties.
|
|
||||||
An empty list if no matching edges are found.
|
|
||||||
"""
|
|
||||||
if self._driver is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
|
||||||
)
|
|
||||||
workspace_label = self._get_workspace_label()
|
|
||||||
async with self._driver.session(
|
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
|
||||||
) as session:
|
|
||||||
query = f"""
|
|
||||||
UNWIND $chunk_ids AS chunk_id
|
|
||||||
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
|
|
||||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
|
||||||
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
|
|
||||||
// Ensure we only return each unique edge once by ordering the source and target
|
|
||||||
WITH a, b, r,
|
|
||||||
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
|
|
||||||
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
|
|
||||||
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
|
|
||||||
"""
|
|
||||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
|
||||||
edges = []
|
|
||||||
async for record in result:
|
|
||||||
edge_properties = record["properties"]
|
|
||||||
edge_properties["source"] = record["source"]
|
|
||||||
edge_properties["target"] = record["target"]
|
|
||||||
edges.append(edge_properties)
|
|
||||||
await result.consume()
|
|
||||||
return edges
|
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self,
|
self,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
|
|
@ -1104,6 +1050,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
|
|
@ -1130,6 +1077,8 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
|
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
|
||||||
|
if result is not None:
|
||||||
|
await result.consume()
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||||
|
|
@ -1152,6 +1101,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
if not query_lower:
|
if not query_lower:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
|
|
@ -1185,4 +1135,6 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
|
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
|
||||||
|
if result is not None:
|
||||||
|
await result.consume()
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue