From c0a3638d011ab2b3df586fbc0aaf970d37aeee2c Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 15:35:20 +0200 Subject: [PATCH] fix memgraph_impl.py according to test_graph_storage.py --- lightrag/kg/memgraph_impl.py | 96 +++++++++++++++++++----------------- tests/test_graph_storage.py | 1 + 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 36f0186b..41a1129b 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -594,8 +594,8 @@ class MemgraphStorage(BaseGraphStorage): node_dict = dict(node) node_dict["id"] = node_dict.get("entity_id") nodes.append(node_dict) - await result.consume() - return nodes + await result.consume() + return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: """Get all edges that are associated with the given chunk_ids. @@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage): UNWIND $chunk_ids AS chunk_id MATCH (a:base)-[r]-(b:base) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) - RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id + // Ensure we only return each unique edge once by ordering the source and target + WITH a, b, r, + CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source, + CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target + RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties """ result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) edges = [] @@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - try: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: if node_label == "*": count_query = "MATCH (n) RETURN count(n) as total" count_result = None @@ -736,45 +741,46 @@ class MemgraphStorage(BaseGraphStorage): logger.debug(f"No record found for node {node_label}") return result - for node_info in record["node_info"]: - node = node_info["node"] - node_id = node.id - if node_id not in seen_nodes: - seen_nodes.add(node_id) - result.nodes.append( - KnowledgeGraphNode( - id=f"{node_id}", - labels=[node.get("entity_id")], - properties=dict(node), - ) - ) - - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - seen_edges.add(edge_id) - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), - ) - ) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - - return result - finally: if result_set: await result_set.consume() - except Exception as e: - logger.error(f"Error getting knowledge graph: {str(e)}") - return result + if record: + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + return result + + except Exception as e: + logger.error(f"Error getting knowledge graph: {str(e)}") + return result diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 64e66f48..3fd1abbc 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -9,6 +9,7 @@ - NetworkXStorage - Neo4JStorage - PGGraphStorage +- MemgraphStorage """ import asyncio