diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index d81c2ebd..c9a96064 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -133,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: workspace_label = self._get_workspace_label() query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" @@ -146,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" ) - await result.consume() # Ensure the result is consumed even on error + if result is not None: + await ( + result.consume() + ) # Ensure the result is consumed even on error raise async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: @@ -170,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: workspace_label = self._get_workspace_label() query = ( @@ -190,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" ) - await result.consume() # Ensure the result is consumed even on error + if result is not None: + await ( + result.consume() + ) # Ensure the result is consumed even on error raise async def get_node(self, node_id: str) -> dict[str, str] | None: @@ -312,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: workspace_label = self._get_workspace_label() query = f""" @@ -328,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage): return labels except Exception as e: logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}") - await result.consume() # Ensure the result is consumed even on error + if result is not None: + await ( + result.consume() + ) # Ensure the result is consumed even on error raise async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: @@ -352,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + results = None try: workspace_label = self._get_workspace_label() query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) @@ -389,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" ) - await results.consume() # Ensure results are consumed even on error + if results is not None: + await ( + results.consume() + ) # Ensure results are consumed even on error raise except Exception as e: logger.error( @@ -419,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: workspace_label = self._get_workspace_label() query = f""" @@ -451,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" ) - await result.consume() # Ensure the result is consumed even on error + if result is not None: + await ( + result.consume() + ) # Ensure the result is consumed even on error raise async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: @@ -1030,11 +1050,12 @@ class MemgraphStorage(BaseGraphStorage): "Memgraph driver is not initialized. Call 'await initialize()' first." ) - try: - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + result = None + try: query = f""" MATCH (n:`{workspace_label}`) WHERE n.entity_id IS NOT NULL @@ -1054,9 +1075,13 @@ class MemgraphStorage(BaseGraphStorage): f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})" ) return labels - except Exception as e: - logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}") - return [] + except Exception as e: + logger.error( + f"[{self.workspace}] Error getting popular labels: {str(e)}" + ) + if result is not None: + await result.consume() + return [] async def search_labels(self, query: str, limit: int = 50) -> list[str]: """Search labels with fuzzy matching @@ -1078,11 +1103,12 @@ class MemgraphStorage(BaseGraphStorage): if not query_lower: return [] - try: - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + result = None + try: cypher_query = f""" MATCH (n:`{workspace_label}`) WHERE n.entity_id IS NOT NULL @@ -1109,6 +1135,8 @@ class MemgraphStorage(BaseGraphStorage): f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})" ) return labels - except Exception as e: - logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") - return [] + except Exception as e: + logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") + if result is not None: + await result.consume() + return [] diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 76fa11f2..31df4623 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -371,6 +371,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) @@ -381,7 +382,8 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" ) - await result.consume() # Ensure results are consumed even on error + if result is not None: + await result.consume() # Ensure results are consumed even on error raise async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: @@ -403,6 +405,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: query = ( f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " @@ -420,7 +423,8 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" ) - await result.consume() # Ensure results are consumed even on error + if result is not None: + await result.consume() # Ensure results are consumed even on error raise async def get_node(self, node_id: str) -> dict[str, str] | None: @@ -799,6 +803,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + results = None try: workspace_label = self._get_workspace_label() query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) @@ -836,7 +841,10 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" ) - await results.consume() # Ensure results are consumed even on error + if results is not None: + await ( + results.consume() + ) # Ensure results are consumed even on error raise except Exception as e: logger.error( @@ -1592,6 +1600,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: query = f""" MATCH (n:`{workspace_label}`) @@ -1616,7 +1625,8 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting popular labels: {str(e)}" ) - await result.consume() + if result is not None: + await result.consume() raise async def search_labels(self, query: str, limit: int = 50) -> list[str]: