diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py
index 0800f4a8..e82aceec 100644
--- a/lightrag/kg/memgraph_impl.py
+++ b/lightrag/kg/memgraph_impl.py
@@ -8,7 +8,6 @@ import configparser
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..constants import GRAPH_FIELD_SEP
 from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
 
 import pipmaster as pm
@@ -53,7 +52,7 @@ class MemgraphStorage(BaseGraphStorage):
 
     def _get_workspace_label(self) -> str:
         """Return workspace label (guaranteed non-empty during initialization)"""
-        return self._get_composite_workspace()
+        return self.workspace
 
     async def initialize(self):
         async with get_data_init_lock():
@@ -134,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
+            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
@@ -147,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
                 )
-                await result.consume()  # Ensure the result is consumed even on error
+                if result is not None:
+                    await (
+                        result.consume()
+                    )  # Ensure the result is consumed even on error
                 raise
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@@ -171,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
+            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = (
@@ -191,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                await result.consume()  # Ensure the result is consumed even on error
+                if result is not None:
+                    await (
+                        result.consume()
+                    )  # Ensure the result is consumed even on error
                 raise
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
@@ -313,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
+            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""
@@ -329,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage):
                 return labels
             except Exception as e:
                 logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
-                await result.consume()  # Ensure the result is consumed even on error
+                if result is not None:
+                    await (
+                        result.consume()
+                    )  # Ensure the result is consumed even on error
                 raise
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -353,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage):
             async with self._driver.session(
                 database=self._DATABASE, default_access_mode="READ"
             ) as session:
+                results = None
                 try:
                     workspace_label = self._get_workspace_label()
                     query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@@ -390,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage):
                     logger.error(
                         f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
                     )
-                    await results.consume()  # Ensure results are consumed even on error
+                    if results is not None:
+                        await (
+                            results.consume()
+                        )  # Ensure results are consumed even on error
                     raise
         except Exception as e:
             logger.error(
@@ -420,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
+            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""
@@ -452,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                await result.consume()  # Ensure the result is consumed even on error
+                if result is not None:
+                    await (
+                        result.consume()
+                    )  # Ensure the result is consumed even on error
                 raise
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
@@ -784,79 +803,6 @@ class MemgraphStorage(BaseGraphStorage):
         degrees = int(src_degree) + int(trg_degree)
         return degrees
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids: List of chunk IDs to find associated nodes for
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-        if self._driver is None:
-            raise RuntimeError(
-                "Memgraph driver is not initialized. Call 'await initialize()' first."
-            )
-        workspace_label = self._get_workspace_label()
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            query = f"""
-                UNWIND $chunk_ids AS chunk_id
-                MATCH (n:`{workspace_label}`)
-                WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
-                RETURN DISTINCT n
-            """
-            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
-            nodes = []
-            async for record in result:
-                node = record["n"]
-                node_dict = dict(node)
-                node_dict["id"] = node_dict.get("entity_id")
-                nodes.append(node_dict)
-            await result.consume()
-            return nodes
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids: List of chunk IDs to find associated edges for
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        if self._driver is None:
-            raise RuntimeError(
-                "Memgraph driver is not initialized. Call 'await initialize()' first."
-            )
-        workspace_label = self._get_workspace_label()
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            query = f"""
-                UNWIND $chunk_ids AS chunk_id
-                MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
-                WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
-                WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
-                // Ensure we only return each unique edge once by ordering the source and target
-                WITH a, b, r,
-                     CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
-                     CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
-                RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
-            """
-            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
-            edges = []
-            async for record in result:
-                edge_properties = record["properties"]
-                edge_properties["source"] = record["source"]
-                edge_properties["target"] = record["target"]
-                edges.append(edge_properties)
-            await result.consume()
-            return edges
-
     async def get_knowledge_graph(
         self,
         node_label: str,
@@ -1104,6 +1050,7 @@ class MemgraphStorage(BaseGraphStorage):
                 "Memgraph driver is not initialized. Call 'await initialize()' first."
             )
 
+        result = None
         try:
             workspace_label = self._get_workspace_label()
             async with self._driver.session(
@@ -1130,6 +1077,8 @@ class MemgraphStorage(BaseGraphStorage):
                 return labels
         except Exception as e:
             logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
+            if result is not None:
+                await result.consume()
             return []
 
     async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@@ -1152,6 +1101,7 @@ class MemgraphStorage(BaseGraphStorage):
         if not query_lower:
             return []
 
+        result = None
        try:
             workspace_label = self._get_workspace_label()
             async with self._driver.session(
@@ -1185,4 +1135,6 @@ class MemgraphStorage(BaseGraphStorage):
                 return labels
         except Exception as e:
             logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
+            if result is not None:
+                await result.consume()
             return []
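The hunks above all apply the same defensive pattern: bind `result` (or `results`) to `None` before the `try`, and in the error path call `.consume()` only when `session.run()` actually assigned it, so the cleanup cannot raise an `UnboundLocalError` that masks the original exception. Below is a minimal standalone sketch of that pattern; the `demo_has_node` helper and the connection details are illustrative assumptions, not code from this repository, and assume Memgraph is reachable over Bolt via the neo4j async driver.

from neo4j import AsyncGraphDatabase


async def demo_has_node(driver, label: str, entity_id: str) -> bool:
    # Illustrative only: the guarded-consume pattern used in the patch.
    async with driver.session(default_access_mode="READ") as session:
        result = None  # bound up front so the except block can test it safely
        try:
            result = await session.run(
                f"MATCH (n:`{label}` {{entity_id: $entity_id}}) "
                "RETURN count(n) > 0 AS node_exists",
                entity_id=entity_id,
            )
            record = await result.single()
            await result.consume()
            return bool(record and record["node_exists"])
        except Exception:
            if result is not None:
                await result.consume()  # consume only if session.run() succeeded
            raise


# Usage sketch (hypothetical URI and credentials):
#   driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("", ""))
#   exists = await demo_has_node(driver, "base", "some-entity")
#   await driver.close()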