From 15c4bac87f975e5192592e8959f1523daec52cc1 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 11 Jul 2025 17:02:42 +0200 Subject: [PATCH] wip fix Memgraph get_knowledge_graph issue by using mage function --- lightrag/kg/memgraph_impl.py | 284 ++++++++++++++++++++++++++--------- 1 file changed, 210 insertions(+), 74 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index d0044499..67c8a63f 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -742,7 +742,7 @@ class MemgraphStorage(BaseGraphStorage): Args: node_label: Label of the starting node, * means all nodes max_depth: Maximum depth of the subgraph, Defaults to 3 - max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + max_nodes: Maximum nodes to return by BFS, Defaults to 1000 Returns: KnowledgeGraph object containing nodes and edges, with an is_truncated flag @@ -796,7 +796,7 @@ class MemgraphStorage(BaseGraphStorage): OPTIONAL MATCH (a)-[r]-(b) WHERE a IN kept_nodes AND b IN kept_nodes RETURN filtered_nodes AS node_info, - collect(DISTINCT r) AS relationships + collect(DISTINCT r) AS relationships """ result_set = None try: @@ -810,99 +810,64 @@ class MemgraphStorage(BaseGraphStorage): await result_set.consume() else: - # return await self._robust_fallback(node_label, max_depth, max_nodes) - # First try without limit to check if we need to truncate - full_query = f""" + # For specific node queries, use path.subgraph_all with the refined query pattern + subgraph_query = f""" MATCH (start:`{workspace_label}`) WHERE start.entity_id = $entity_id WITH start - MATCH path = (start)-[*BFS ..{max_depth}]-(node:`{workspace_label}`) - WITH nodes(path) AS path_nodes, relationships(path) AS path_rels - UNWIND path_nodes AS n - WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists - WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels + CALL path.subgraph_all(start, {{ + relationshipFilter: [], + labelFilter: ['{workspace_label}'], + minHops: 0, + maxHops: $max_depth + }}) + YIELD nodes, rels + WITH + CASE + WHEN size(nodes) <= $max_nodes THEN nodes + ELSE nodes[0..$max_nodes] + END AS limited_nodes, + rels, + size(nodes) > $max_nodes AS is_truncated + UNWIND rels AS rel + WITH limited_nodes, rel, is_truncated + WHERE startNode(rel) IN limited_nodes AND endNode(rel) IN limited_nodes + WITH limited_nodes, collect(DISTINCT rel) AS limited_relationships, is_truncated RETURN - [node IN all_nodes | {{node: node}}] AS node_info, - all_rels AS relationships, - size(all_nodes) AS total_nodes + [node IN limited_nodes | {{node: node}}] AS node_info, + limited_relationships AS relationships, + is_truncated """ - # Try to get full result - full_result = None + result_set = None try: - full_result = await session.run( - full_query, + result_set = await session.run( + subgraph_query, { "entity_id": node_label, + "max_depth": max_depth, + "max_nodes": max_nodes, }, ) - full_record = await full_result.single() + record = await result_set.single() # If no record found, return empty KnowledgeGraph - if not full_record: + if not record: logger.debug(f"No nodes found for entity_id: {node_label}") return result - # If record found, check node count - total_nodes = full_record["total_nodes"] - - if total_nodes <= max_nodes: - # If node count is within limit, use full result directly - logger.debug( - f"Using full result with {total_nodes} nodes (no truncation needed)" - ) - record = full_record - else: - # If node count exceeds limit, set truncated flag and run limited query + # Check if the result was truncated + if record.get("is_truncated"): result.is_truncated = True logger.info( - f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" + f"Graph truncated: breadth-first search limited to {max_nodes} nodes" ) - # Run limited query - limited_query = f""" - MATCH (start:`{workspace_label}`) - WHERE start.entity_id = $entity_id - WITH start - MATCH path = (start)-[*BFS ..{max_depth}]-(node:`{workspace_label}`) - WITH nodes(path) AS path_nodes, relationships(path) AS path_rels - UNWIND path_nodes AS n - WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists - WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels - WITH - CASE - WHEN size(all_nodes) <= $max_nodes THEN all_nodes - ELSE all_nodes[0..$max_nodes] - END AS limited_nodes, - all_rels - UNWIND all_rels AS rel - WITH limited_nodes, rel - WHERE startNode(rel) IN limited_nodes AND endNode(rel) IN limited_nodes - WITH limited_nodes, collect(DISTINCT rel) AS limited_relationships - RETURN - [node IN limited_nodes | {{node: node}}] AS node_info, - limited_relationships AS relationships - """ - - result_set = None - try: - result_set = await session.run( - limited_query, - { - "entity_id": node_label, - "max_nodes": max_nodes, - }, - ) - record = await result_set.single() - finally: - if result_set: - await result_set.consume() finally: - if full_result: - await full_result.consume() + if result_set: + await result_set.consume() if record: - # Handle nodes (compatible with multi-label cases) for node_info in record["node_info"]: node = node_info["node"] node_id = node.id @@ -916,7 +881,6 @@ class MemgraphStorage(BaseGraphStorage): ) seen_nodes.add(node_id) - # Handle relationships (including direction information) for rel in record["relationships"]: edge_id = rel.id if edge_id not in seen_edges: @@ -938,6 +902,178 @@ class MemgraphStorage(BaseGraphStorage): ) except Exception as e: - logger.error(f"Error during subgraph query for {node_label}: {str(e)}") + logger.warning(f"Memgraph error during subgraph query: {str(e)}") + if node_label != "*": + logger.warning( + "Memgraph: falling back to basic Cypher recursive search..." + ) + return await self._robust_fallback(node_label, max_depth, max_nodes) + else: + logger.warning( + "Memgraph: Mage plugin error with wildcard query, returning empty result" + ) return result + + async def _robust_fallback( + self, node_label: str, max_depth: int, max_nodes: int + ) -> KnowledgeGraph: + """ + Fallback implementation when MAGE plugin is not available or incompatible. + This method implements the same functionality as get_knowledge_graph but uses + only basic Cypher queries and true breadth-first traversal instead of MAGE procedures. + """ + from collections import deque + + result = KnowledgeGraph() + visited_nodes = set() + visited_edges = set() + visited_edge_pairs = set() + + # Get the starting node's data + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + RETURN id(n) as node_id, n + """ + node_result = await session.run(query, entity_id=node_label) + try: + node_record = await node_result.single() + if not node_record: + return result + + # Create initial KnowledgeGraphNode + start_node = KnowledgeGraphNode( + id=f"{node_record['n'].get('entity_id')}", + labels=[node_record["n"].get("entity_id")], + properties=dict(node_record["n"]._properties), + ) + finally: + await node_result.consume() # Ensure results are consumed + + # Initialize queue for BFS with (node, depth) tuples + queue = deque([(start_node, 0)]) + + # Keep track of all nodes we've discovered (including those we might not add due to limits) + discovered_nodes = {} # node_id -> KnowledgeGraphNode + discovered_nodes[start_node.id] = start_node + + # True BFS implementation using a queue + while queue: + # Dequeue the next node to process + current_node, current_depth = queue.popleft() + + # Skip if already processed or exceeds max depth + if current_node.id in visited_nodes: + continue + + if current_depth > max_depth: + logger.debug( + f"Skipping node at depth {current_depth} (max_depth: {max_depth})" + ) + continue + + # Check if we've reached the node limit + if len(visited_nodes) >= max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: breadth-first search limited to: {max_nodes} nodes" + ) + break + + # Add current node to result + result.nodes.append(current_node) + visited_nodes.add(current_node.id) + + # Only continue exploring if we haven't reached max depth + if current_depth < max_depth: + # Get all edges and target nodes for the current node + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (a:`{workspace_label}` {{entity_id: $entity_id}})-[r]-(b:`{workspace_label}`) + WHERE b.entity_id IS NOT NULL + RETURN r, b, id(r) as edge_id + """ + results = await session.run(query, entity_id=current_node.id) + + # Get all records and release database connection + records = await results.fetch( + 1000 + ) # Max neighbor nodes we can handle + await results.consume() # Ensure results are consumed + + # Process all neighbors + for record in records: + rel = record["r"] + edge_id = str(record["edge_id"]) + b_node = record["b"] + target_id = b_node.get("entity_id") + + if target_id and edge_id not in visited_edges: + # Create KnowledgeGraphNode for target if not already discovered + if target_id not in discovered_nodes: + target_node = KnowledgeGraphNode( + id=f"{target_id}", + labels=[target_id], + properties=dict(b_node._properties), + ) + discovered_nodes[target_id] = target_node + + # Add to queue for further exploration + queue.append((target_node, current_depth + 1)) + + # Second pass: Add edges only between nodes that are actually in the result + final_node_ids = {node.id for node in result.nodes} + + # Now collect all edges between the nodes we actually included + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + # Use a parameterized query to get all edges between our final nodes + query = f""" + UNWIND $node_ids AS node_id + MATCH (a:`{workspace_label}` {{entity_id: node_id}})-[r]-(b:`{workspace_label}`) + WHERE b.entity_id IN $node_ids + RETURN DISTINCT r, a.entity_id AS source_id, b.entity_id AS target_id, id(r) AS edge_id + """ + results = await session.run(query, node_ids=list(final_node_ids)) + + edges_to_add = [] + async for record in results: + rel = record["r"] + edge_id = str(record["edge_id"]) + source_id = record["source_id"] + target_id = record["target_id"] + + if edge_id not in visited_edges: + # Create edge pair for deduplication (undirected) + sorted_pair = tuple(sorted([source_id, target_id])) + + if sorted_pair not in visited_edge_pairs: + edges_to_add.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{source_id}", + target=f"{target_id}", + properties=dict(rel), + ) + ) + visited_edges.add(edge_id) + visited_edge_pairs.add(sorted_pair) + + await results.consume() + + # Add all valid edges to the result + result.edges.extend(edges_to_add) + + logger.info( + f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result