Fix null reference errors in graph database error handling

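- Initialize `result`/`results` to None before each try block
- Check `is not None` before calling `consume()` in except blocks
- Prevent crashes in except blocks when the query fails before a result is assigned
- Apply the fix to both the Neo4J and Memgraph storage backends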
This commit is contained in:
yangdx 2025-11-14 10:39:04 +08:00
parent 2f2f35b883
commit 423e4e927a
2 changed files with 63 additions and 25 deletions

View file

@ -133,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
@ -146,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@ -170,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = ( query = (
@ -190,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
@ -312,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f""" query = f"""
@ -328,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage):
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}") logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@ -352,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
results = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -389,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
) )
await results.consume() # Ensure results are consumed even on error if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -419,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f""" query = f"""
@ -451,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise raise
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
@ -1030,11 +1050,12 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
try: workspace_label = self._get_workspace_label()
workspace_label = self._get_workspace_label() async with self._driver.session(
async with self._driver.session( database=self._DATABASE, default_access_mode="READ"
database=self._DATABASE, default_access_mode="READ" ) as session:
) as session: result = None
try:
query = f""" query = f"""
MATCH (n:`{workspace_label}`) MATCH (n:`{workspace_label}`)
WHERE n.entity_id IS NOT NULL WHERE n.entity_id IS NOT NULL
@ -1054,9 +1075,13 @@ class MemgraphStorage(BaseGraphStorage):
f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})" f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
) )
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}") logger.error(
return [] f"[{self.workspace}] Error getting popular labels: {str(e)}"
)
if result is not None:
await result.consume()
return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]: async def search_labels(self, query: str, limit: int = 50) -> list[str]:
"""Search labels with fuzzy matching """Search labels with fuzzy matching
@ -1078,11 +1103,12 @@ class MemgraphStorage(BaseGraphStorage):
if not query_lower: if not query_lower:
return [] return []
try: workspace_label = self._get_workspace_label()
workspace_label = self._get_workspace_label() async with self._driver.session(
async with self._driver.session( database=self._DATABASE, default_access_mode="READ"
database=self._DATABASE, default_access_mode="READ" ) as session:
) as session: result = None
try:
cypher_query = f""" cypher_query = f"""
MATCH (n:`{workspace_label}`) MATCH (n:`{workspace_label}`)
WHERE n.entity_id IS NOT NULL WHERE n.entity_id IS NOT NULL
@ -1109,6 +1135,8 @@ class MemgraphStorage(BaseGraphStorage):
f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})" f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
) )
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
return [] if result is not None:
await result.consume()
return []

View file

@ -371,6 +371,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id) result = await session.run(query, entity_id=node_id)
@ -381,7 +382,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
) )
await result.consume() # Ensure results are consumed even on error if result is not None:
await result.consume() # Ensure results are consumed even on error
raise raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@ -403,6 +405,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
query = ( query = (
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
@ -420,7 +423,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure results are consumed even on error if result is not None:
await result.consume() # Ensure results are consumed even on error
raise raise
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
@ -799,6 +803,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
results = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -836,7 +841,10 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
) )
await results.consume() # Ensure results are consumed even on error if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -1592,6 +1600,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
query = f""" query = f"""
MATCH (n:`{workspace_label}`) MATCH (n:`{workspace_label}`)
@ -1616,7 +1625,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting popular labels: {str(e)}" f"[{self.workspace}] Error getting popular labels: {str(e)}"
) )
await result.consume() if result is not None:
await result.consume()
raise raise
async def search_labels(self, query: str, limit: int = 50) -> list[str]: async def search_labels(self, query: str, limit: int = 50) -> list[str]: