From bcd326518d8a1c0c565ff4469847b70f2d95d054 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 6 Feb 2025 11:22:17 +0100 Subject: [PATCH] feat: implements graph visualization method for cognee (#493) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR contains the improvement of the visualization endpoint ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin ## Summary by CodeRabbit - **New Features** - Launched an enhanced interactive network visualization utility that renders dynamic, browser-based graphs. The new feature simplifies execution by directly generating an HTML file showcasing the visualization—complete with interactive elements and an on-screen confirmation—providing a more intuitive and efficient experience. --- cognee/__init__.py | 4 +- cognee/api/v1/visualize/visualize.py | 23 ++- cognee/modules/visualization/__init__.py | 0 .../cognee_network_visualization.py | 180 ++++++++++++++++++ cognee/shared/utils.py | 146 +------------- .../visualization/visualization_test.py | 33 ++++ .../tests/unit/processing/utils/utils_test.py | 33 ---- 7 files changed, 236 insertions(+), 183 deletions(-) create mode 100644 cognee/modules/visualization/__init__.py create mode 100644 cognee/modules/visualization/cognee_network_visualization.py create mode 100644 cognee/tests/unit/modules/visualization/visualization_test.py diff --git a/cognee/__init__.py b/cognee/__init__.py index 241dd1ad6..420a7d88a 100644 --- a/cognee/__init__.py +++ b/cognee/__init__.py @@ -5,7 +5,9 @@ from .api.v1.datasets.datasets import datasets from .api.v1.prune import prune from .api.v1.search import SearchType, get_search_history, search from .api.v1.visualize import visualize_graph -from .shared.utils import create_cognee_style_network_with_logo +from cognee.modules.visualization.cognee_network_visualization import ( + cognee_network_visualization, +) # Pipelines from .modules import pipelines diff --git a/cognee/api/v1/visualize/visualize.py b/cognee/api/v1/visualize/visualize.py index 4c4c613c8..a8cb70491 100644 --- a/cognee/api/v1/visualize/visualize.py +++ b/cognee/api/v1/visualize/visualize.py @@ -1,15 +1,30 @@ -from cognee.shared.utils import create_cognee_style_network_with_logo, graph_to_tuple +from cognee.modules.visualization.cognee_network_visualization import ( + cognee_network_visualization, +) from cognee.infrastructure.databases.graph import get_graph_engine import logging -async def visualize_graph(label: str = "name"): - """ """ +import asyncio +from cognee.shared.utils import setup_logging + + +async def visualize_graph(): graph_engine = await get_graph_engine() graph_data = await graph_engine.get_graph_data() logging.info(graph_data) - graph = await create_cognee_style_network_with_logo(graph_data, label=label) + graph = await cognee_network_visualization(graph_data) logging.info("The HTML file has been stored on your home directory! Navigate there with cd ~") return graph + + +if __name__ == "__main__": + setup_logging(logging.ERROR) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(visualize_graph()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) diff --git a/cognee/modules/visualization/__init__.py b/cognee/modules/visualization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cognee/modules/visualization/cognee_network_visualization.py b/cognee/modules/visualization/cognee_network_visualization.py new file mode 100644 index 000000000..df06b2094 --- /dev/null +++ b/cognee/modules/visualization/cognee_network_visualization.py @@ -0,0 +1,180 @@ +import networkx as nx +import json +import os + + +async def cognee_network_visualization(graph_data): + nodes_data, edges_data = graph_data + + G = nx.DiGraph() + + nodes_list = [] + color_map = { + "Entity": "#f47710", + "EntityType": "#6510f4", + "DocumentChunk": "#801212", + "default": "#D3D3D3", + } + + for node_id, node_info in nodes_data: + node_info = node_info.copy() + node_info["id"] = str(node_id) + node_info["color"] = color_map.get(node_info.get("pydantic_type", "default"), "#D3D3D3") + node_info["name"] = node_info.get("name", str(node_id)) + del node_info[ + "updated_at" + ] #:TODO: We should decide what properties to show on the nodes and edges, we dont necessarily need all. + del node_info["created_at"] + nodes_list.append(node_info) + G.add_node(node_id, **node_info) + + edge_labels = {} + links_list = [] + for source, target, relation, edge_info in edges_data: + source = str(source) + target = str(target) + G.add_edge(source, target) + edge_labels[(source, target)] = relation + links_list.append({"source": source, "target": target, "relation": relation}) + + html_template = """ + + +
+ + + + + + + + + + + + """ + + html_content = html_template.replace("{nodes}", json.dumps(nodes_list)) + html_content = html_content.replace("{links}", json.dumps(links_list)) + + home_dir = os.path.expanduser("~") + output_file = os.path.join(home_dir, "graph_visualization.html") + + with open(output_file, "w") as f: + f.write(html_content) + + print(f"Graph visualization saved as {output_file}") + + return html_content diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index 4e5523fd2..c8064efdf 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt import logging import sys +import json from cognee.base_config import get_base_config from cognee.infrastructure.databases.graph import get_graph_engine @@ -336,86 +337,6 @@ def style_and_render_graph(p, G, layout_positions, node_attribute, node_colors, return graph_renderer -async def create_cognee_style_network_with_logo( - G, - output_filename="cognee_network_with_logo.html", - title="Cognee-Style Network", - label="group", - layout_func=nx.spring_layout, - layout_scale=3.0, - logo_alpha=0.1, - bokeh_object=False, -): - """ - Create a Cognee-inspired network visualization with an embedded logo. - """ - from bokeh.plotting import figure, from_networkx - from bokeh.models import Circle, MultiLine, HoverTool, ColumnDataSource, Range1d - from bokeh.plotting import output_file, show - - from bokeh.embed import file_html - from bokeh.resources import CDN - from bokeh.io import export_png - - logging.info("Converting graph to serializable format...") - G = await convert_to_serializable_graph(G) - - logging.info("Generating layout positions...") - layout_positions = generate_layout_positions(G, layout_func, layout_scale) - - logging.info("Assigning node colors...") - palette = ["#6510F4", "#0DFF00", "#FFFFFF"] - node_colors, color_map = assign_node_colors(G, label, palette) - - logging.info("Calculating centrality...") - centrality = nx.degree_centrality(G) - - logging.info("Preparing Bokeh output...") - output_file(output_filename) - p = figure( - title=title, - tools="pan,wheel_zoom,save,reset,hover", - active_scroll="wheel_zoom", - width=1200, - height=900, - background_fill_color="#F4F4F4", - x_range=Range1d(-layout_scale, layout_scale), - y_range=Range1d(-layout_scale, layout_scale), - ) - p.toolbar.logo = None - p.axis.visible = False - p.grid.visible = False - - logging.info("Embedding logo into visualization...") - embed_logo(p, layout_scale, logo_alpha, "bottom_right") - embed_logo(p, layout_scale, logo_alpha, "top_left") - - logging.info("Styling and rendering graph...") - style_and_render_graph(p, G, layout_positions, label, node_colors, centrality) - - logging.info("Adding hover tool...") - hover_tool = HoverTool( - tooltips=[ - ("Node", "@index"), - (label.capitalize(), f"@{label}"), - ("Centrality", "@radius{0.00}"), - ], - ) - p.add_tools(hover_tool) - - logging.info(f"Saving visualization to {output_filename}...") - html_content = file_html(p, CDN, title) - - home_dir = os.path.expanduser("~") - - # Construct the final output file path - output_filepath = os.path.join(home_dir, output_filename) - with open(output_filepath, "w") as f: - f.write(html_content) - - return html_content - - def graph_to_tuple(graph): """ Converts a networkx graph to a tuple of (nodes, edges). @@ -443,68 +364,3 @@ def setup_logging(log_level=logging.INFO): root_logger.addHandler(stream_handler) root_logger.setLevel(log_level) - - -# ---------------- Example Usage ---------------- -if __name__ == "__main__": - import networkx as nx - - # Create a sample graph - nodes = [ - (1, {"group": "A"}), - (2, {"group": "A"}), - (3, {"group": "B"}), - (4, {"group": "B"}), - (5, {"group": "C"}), - ] - edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 1)] - - # Create a NetworkX graph - G = nx.Graph() - G.add_nodes_from(nodes) - G.add_edges_from(edges) - - G = graph_to_tuple(G) - - import asyncio - - output_html = asyncio.run( - create_cognee_style_network_with_logo( - G, - output_filename="example_network.html", - title="Example Cognee Network", - label="group", # Attribute to use for coloring nodes - layout_func=nx.spring_layout, # Layout function - layout_scale=3.0, # Scale for the layout - logo_alpha=0.2, - ) - ) - - # Call the function - # output_html = await create_cognee_style_network_with_logo( - # G=G, - # output_filename="example_network.html", - # title="Example Cognee Network", - # node_attribute="group", # Attribute to use for coloring nodes - # layout_func=nx.spring_layout, # Layout function - # layout_scale=3.0, # Scale for the layout - # logo_alpha=0.2, # Transparency of the logo - # ) - - # Print the output filename - print("Network visualization saved as example_network.html") - -# # Create a random geometric graph -# G = nx.random_geometric_graph(50, 0.3) -# # Assign random group attributes for coloring -# for i, node in enumerate(G.nodes()): -# G.nodes[node]["group"] = f"Group {i % 3 + 1}" -# -# create_cognee_graph( -# G, -# output_filename="cognee_style_network_with_logo.html", -# title="Cognee-Graph Network", -# node_attribute="group", -# layout_func=nx.spring_layout, -# layout_scale=3.0, # Replace with your logo file path -# ) diff --git a/cognee/tests/unit/modules/visualization/visualization_test.py b/cognee/tests/unit/modules/visualization/visualization_test.py new file mode 100644 index 000000000..06efb6234 --- /dev/null +++ b/cognee/tests/unit/modules/visualization/visualization_test.py @@ -0,0 +1,33 @@ +import pytest +from cognee.modules.visualization.cognee_network_visualization import ( + cognee_network_visualization, +) + + +@pytest.mark.asyncio +async def test_create_cognee_style_network_with_logo(): + nodes_data = [ + (1, {"pydantic_type": "Entity", "name": "Node1", "updated_at": 123, "created_at": 123}), + ( + 2, + { + "pydantic_type": "DocumentChunk", + "name": "Node2", + "updated_at": 123, + "created_at": 123, + }, + ), + ] + edges_data = [ + (1, 2, "related_to", {}), + ] + graph_data = (nodes_data, edges_data) + + html_output = await cognee_network_visualization(graph_data) + + assert isinstance(html_output, str) + + assert "" in html_output + assert '' in html_output + assert "var nodes =" in html_output + assert "var links =" in html_output diff --git a/cognee/tests/unit/processing/utils/utils_test.py b/cognee/tests/unit/processing/utils/utils_test.py index 067ab6ea7..cfdae0f34 100644 --- a/cognee/tests/unit/processing/utils/utils_test.py +++ b/cognee/tests/unit/processing/utils/utils_test.py @@ -5,17 +5,12 @@ import pandas as pd from unittest.mock import patch, mock_open from io import BytesIO from uuid import uuid4 -from datetime import datetime, timezone -from cognee.shared.exceptions import IngestionError from cognee.shared.utils import ( get_anonymous_id, - send_telemetry, get_file_content_hash, prepare_edges, prepare_nodes, - create_cognee_style_network_with_logo, - graph_to_tuple, ) @@ -78,31 +73,3 @@ def test_prepare_nodes(): assert isinstance(nodes_df, pd.DataFrame) assert len(nodes_df) == 1 - - -@pytest.mark.asyncio -async def test_create_cognee_style_network_with_logo(): - import networkx as nx - from unittest.mock import patch - from io import BytesIO - - # Create a sample graph - graph = nx.Graph() - graph.add_node(1, group="A") - graph.add_node(2, group="B") - graph.add_edge(1, 2) - - # Convert the graph to a tuple format for serialization - graph_tuple = graph_to_tuple(graph) - - result = await create_cognee_style_network_with_logo( - graph_tuple, - title="Test Network", - layout_func=nx.spring_layout, - layout_scale=3.0, - logo_alpha=0.5, - ) - - assert result is not None - assert isinstance(result, str) - assert len(result) > 0