From bd158d096bb26215b283496b4bf2aebe4d0e2292 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 14:47:23 +0200 Subject: [PATCH] polish Memgraph implementation --- lightrag/kg/memgraph_impl.py | 803 ++++++++++++++++++++++++----------- 1 file changed, 551 insertions(+), 252 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index df28b8b2..bf870154 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -89,183 +89,419 @@ class MemgraphStorage(BaseGraphStorage): pass async def has_node(self, node_id: str) -> bool: + """ + Check if a node exists in the graph. + + Args: + node_id: The ID of the node to check. + + Returns: + bool: True if the node exists, False otherwise. + + Raises: + Exception: If there is an error checking the node existence. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" - result = await session.run(query, entity_id=node_id) - single_result = await result.single() - await result.consume() - return single_result["node_exists"] + try: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id=node_id) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["node_exists"] + except Exception as e: + logger.error(f"Error checking node existence for {node_id}: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """ + Check if an edge exists between two nodes in the graph. + + Args: + source_node_id: The ID of the source node. + target_node_id: The ID of the target node. + + Returns: + bool: True if the edge exists, False otherwise. + + Raises: + Exception: If there is an error checking the edge existence. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = ( - "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " - "RETURN COUNT(r) > 0 AS edgeExists" - ) - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) - single_result = await result.single() - await result.consume() - return single_result["edgeExists"] + try: + query = ( + "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["edgeExists"] + except Exception as e: + logger.error( + f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" - result = await session.run(query, entity_id=node_id) - records = await result.fetch(2) - await result.consume() - if records: - node = records[0]["n"] - node_dict = dict(node) - if "labels" in node_dict: - node_dict["labels"] = [ - label for label in node_dict["labels"] if label != "base" - ] - return node_dict - return None + try: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + result = await session.run(query, entity_id=node_id) + try: + records = await result.fetch( + 2 + ) # Get 2 records for duplication check + + if len(records) > 1: + logger.warning( + f"Multiple nodes found with label '{node_id}'. Using first node." + ) + if records: + node = records[0]["n"] + node_dict = dict(node) + # Remove base label from labels list if it exists + if "labels" in node_dict: + node_dict["labels"] = [ + label + for label in node_dict["labels"] + if label != "base" + ] + return node_dict + return None + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node for {node_id}: {str(e)}") + raise + + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + + Raises: + Exception: If there is an error executing the query + """ + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + query = """ + MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await session.run(query, entity_id=node_id) + try: + record = await result.single() + + if not record: + logger.warning(f"No node found with label '{node_id}'") + return 0 + + degree = record["degree"] + return degree + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node degree for {node_id}: {str(e)}") + raise async def get_all_labels(self) -> list[str]: + """ + Get all existing node labels in the database + Returns: + ["Person", "Company", ...] # Alphabetically sorted label list + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (n:base) - WHERE n.entity_id IS NOT NULL - RETURN DISTINCT n.entity_id AS label - ORDER BY label - """ - result = await session.run(query) - labels = [] - async for record in result: - labels.append(record["label"]) - await result.consume() - return labels + try: + query = """ + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + await result.consume() + return labels + except Exception as e: + logger.error(f"Error getting all labels: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = """ - MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-(connected:base) - WHERE connected.entity_id IS NOT NULL - RETURN n, r, connected - """ - results = await session.run(query, entity_id=source_node_id) - edges = [] - async for record in results: - source_node = record["n"] - connected_node = record["connected"] - if not source_node or not connected_node: - continue - source_label = source_node.get("entity_id") - target_label = connected_node.get("entity_id") - if source_label and target_label: - edges.append((source_label, target_label)) - await results.consume() - return edges + """Retrieves all edges (relationships) for a particular node identified by its label. + + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + Exception: If there is an error executing the query + """ + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + query = """MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-(connected:base) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected""" + results = await session.run(query, entity_id=source_node_id) + + edges = [] + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + + # Skip if either node is None + if not source_node or not connected_node: + continue + + source_label = ( + source_node.get("entity_id") + if source_node.get("entity_id") + else None + ) + target_label = ( + connected_node.get("entity_id") + if connected_node.get("entity_id") + else None + ) + + if source_label and target_label: + edges.append((source_label, target_label)) + + await results.consume() # Ensure results are consumed + return edges + except Exception as e: + logger.error( + f"Error getting edges for node {source_node_id}: {str(e)}" + ) + await results.consume() # Ensure results are consumed even on error + raise + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get edge properties between two nodes. + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) - RETURN properties(r) as edge_properties - """ - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) - records = await result.fetch(2) - await result.consume() - if records: - edge_result = dict(records[0]["edge_properties"]) - for key, default_value in { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - }.items(): - if key not in edge_result: - edge_result[key] = default_value - return edge_result - return None + try: + query = """ + MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + RETURN properties(r) as edge_properties + """ + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + records = await result.fetch(2) + await result.consume() + if records: + edge_result = dict(records[0]["edge_properties"]) + for key, default_value in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}" + ) + return edge_result + return None + except Exception as e: + logger.error( + f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """ + Upsert a node in the Neo4j database. + + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ properties = node_data - entity_type = properties.get("entity_type", "base") + entity_type = properties["entity_type"] if "entity_id" not in properties: - raise ValueError( - "Memgraph: node properties must contain an 'entity_id' field" - ) - async with self._driver.session(database=self._DATABASE) as session: + raise ValueError("Neo4j: node properties must contain an 'entity_id' field") - async def execute_upsert(tx: AsyncManagedTransaction): - query = f""" - MERGE (n:base {{entity_id: $entity_id}}) + try: + async with self._driver.session(database=self._DATABASE) as session: + + async def execute_upsert(tx: AsyncManagedTransaction): + query = ( + """ + MERGE (n:base {entity_id: $entity_id}) SET n += $properties - SET n:`{entity_type}` + SET n:`%s` """ - result = await tx.run(query, entity_id=node_id, properties=properties) - await result.consume() + % entity_type + ) + result = await tx.run( + query, entity_id=node_id, properties=properties + ) + await result.consume() # Ensure result is fully consumed - await session.execute_write(execute_upsert) + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during upsert: {str(e)}") + raise async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - edge_properties = edge_data - async with self._driver.session(database=self._DATABASE) as session: + """ + Upsert an edge and its properties between two nodes identified by their labels. + Ensures both source and target nodes exist and are unique before creating the edge. + Uses entity_id property to uniquely identify nodes. - async def execute_upsert(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id}) - WITH source - MATCH (target:base {entity_id: $target_entity_id}) - MERGE (source)-[r:DIRECTED]-(target) - SET r += $properties - RETURN r, source, target - """ - result = await tx.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - properties=edge_properties, - ) - await result.consume() + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge - await session.execute_write(execute_upsert) + Raises: + Exception: If there is an error executing the query + """ + try: + edge_properties = edge_data + async with self._driver.session(database=self._DATABASE) as session: + + async def execute_upsert(tx: AsyncManagedTransaction): + query = """ + MATCH (source:base {entity_id: $source_entity_id}) + WITH source + MATCH (target:base {entity_id: $target_entity_id}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + result = await tx.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + properties=edge_properties, + ) + try: + await result.fetch(2) + finally: + await result.consume() # Ensure result is consumed + + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during edge upsert: {str(e)}") + raise async def delete_node(self, node_id: str) -> None: + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + + Raises: + Exception: If there is an error executing the query + """ + async def _do_delete(tx: AsyncManagedTransaction): query = """ MATCH (n:base {entity_id: $entity_id}) DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) + logger.debug(f"Deleted node with label {node_id}") await result.consume() - async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_delete) + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ for node in nodes: await self.delete_node(node) async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + + Raises: + Exception: If there is an error executing the query + """ for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): @@ -276,15 +512,32 @@ class MemgraphStorage(BaseGraphStorage): result = await tx.run( query, source_entity_id=source, target_entity_id=target ) - await result.consume() + logger.debug(f"Deleted edge from '{source}' to '{target}'") + await result.consume() # Ensure result is fully consumed - async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_delete_edge) + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def drop(self) -> dict[str, str]: + """Drop all data from storage and clean up resources + + This method will delete all nodes and relationships in the Neo4j database. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + + Raises: + Exception: If there is an error executing the query + """ try: async with self._driver.session(database=self._DATABASE) as session: - query = "MATCH (n) DETACH DELETE n" + query = "DROP GRAPH" result = await session.run(query) await result.consume() logger.info( @@ -295,30 +548,36 @@ class MemgraphStorage(BaseGraphStorage): logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} - async def node_degree(self, node_id: str) -> int: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = """ - MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-() - RETURN COUNT(r) AS degree - """ - result = await session.run(query, entity_id=node_id) - record = await result.single() - await result.consume() - if not record: - return 0 - return record["degree"] - async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition src_degree = 0 if src_degree is None else src_degree trg_degree = 0 if trg_degree is None else trg_degree - return int(src_degree) + int(trg_degree) + + degrees = int(src_degree) + int(trg_degree) + return degrees async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated nodes for + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -335,10 +594,19 @@ class MemgraphStorage(BaseGraphStorage): node_dict = dict(node) node_dict["id"] = node_dict.get("entity_id") nodes.append(node_dict) - await result.consume() - return nodes + await result.consume() + return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated edges for + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -364,118 +632,149 @@ class MemgraphStorage(BaseGraphStorage): max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + + Raises: + Exception: If there is an error executing the query + """ result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - if node_label == "*": - count_query = "MATCH (n) RETURN count(n) as total" - count_result = await session.run(count_query) - count_record = await count_result.single() - await count_result.consume() - if count_record and count_record["total"] > max_nodes: - result.is_truncated = True - logger.info( - f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" - ) - main_query = """ - MATCH (n) - OPTIONAL MATCH (n)-[r]-() - WITH n, COALESCE(count(r), 0) AS degree - ORDER BY degree DESC - LIMIT $max_nodes - WITH collect({node: n}) AS filtered_nodes - UNWIND filtered_nodes AS node_info - WITH collect(node_info.node) AS kept_nodes, filtered_nodes - OPTIONAL MATCH (a)-[r]-(b) - WHERE a IN kept_nodes AND b IN kept_nodes - RETURN filtered_nodes AS node_info, - collect(DISTINCT r) AS relationships - """ - result_set = await session.run(main_query, {"max_nodes": max_nodes}) - record = await result_set.single() - await result_set.consume() - else: - # BFS fallback for Memgraph (no APOC) - from collections import deque + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + if node_label == "*": + count_query = "MATCH (n) RETURN count(n) as total" + count_result = None + try: + count_result = await session.run(count_query) + count_record = await count_result.single() + if count_record and count_record["total"] > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() - # Get the starting node - start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" - node_result = await session.run(start_query, entity_id=node_label) - node_record = await node_result.single() - await node_result.consume() - if not node_record: - return result - start_node = node_record["n"] - queue = deque([(start_node, 0)]) - visited = set() - bfs_nodes = [] - while queue and len(bfs_nodes) < max_nodes: - current_node, depth = queue.popleft() - node_id = current_node.get("entity_id") - if node_id in visited: - continue - visited.add(node_id) - bfs_nodes.append(current_node) - if depth < max_depth: - # Get neighbors - neighbor_query = """ - MATCH (n:base {entity_id: $entity_id})-[]-(m:base) - RETURN m - """ - neighbors_result = await session.run( - neighbor_query, entity_id=node_id - ) - neighbors = [ - rec["m"] for rec in await neighbors_result.to_list() - ] - await neighbors_result.consume() - for neighbor in neighbors: - neighbor_id = neighbor.get("entity_id") - if neighbor_id not in visited: - queue.append((neighbor, depth + 1)) - # Build subgraph - subgraph_ids = [n.get("entity_id") for n in bfs_nodes] - # Nodes - for n in bfs_nodes: - node_id = n.get("entity_id") - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties=dict(n), - ) - ) - seen_nodes.add(node_id) - # Edges - if subgraph_ids: - edge_query = """ - MATCH (a:base)-[r]-(b:base) - WHERE a.entity_id IN $ids AND b.entity_id IN $ids - RETURN DISTINCT r, a, b + # Run the main query to get nodes with highest degree + main_query = """ + MATCH (n) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect({node: n}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + OPTIONAL MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships """ - edge_result = await session.run(edge_query, ids=subgraph_ids) - async for record in edge_result: - r = record["r"] - a = record["a"] - b = record["b"] - edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=a.get("entity_id"), - target=b.get("entity_id"), - properties=dict(r), + result_set = None + try: + result_set = await session.run( + main_query, {"max_nodes": max_nodes} + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() + + else: + bfs_query = """ + MATCH (start) WHERE start.entity_id = $entity_id + WITH start + CALL { + WITH start + MATCH path = (start)-[*0..$max_depth]-(node) + WITH nodes(path) AS path_nodes, relationships(path) AS path_rels + UNWIND path_nodes AS n + WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists + WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels + RETURN all_nodes, all_rels + } + WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes + + // Apply node limiting here + WITH CASE + WHEN total_nodes <= $max_nodes THEN nodes + ELSE nodes[0..$max_nodes] + END AS limited_nodes, + relationships, + total_nodes, + total_nodes > $max_nodes AS is_truncated + UNWIND limited_nodes AS node + WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated + RETURN node_info, relationships, total_nodes, is_truncated + """ + result_set = None + try: + result_set = await session.run( + bfs_query, + { + "entity_id": node_label, + "max_depth": max_depth, + "max_nodes": max_nodes, + }, + ) + record = await result_set.single() + if not record: + logger.debug(f"No record found for node {node_label}") + return result + + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) ) - ) - seen_edges.add(edge_id) - await edge_result.consume() - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - return result + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + return result + + finally: + if result_set: + await result_set.consume() + + except Exception as e: + logger.error(f"Error getting knowledge graph: {str(e)}") + return result