feat: implements graph visualization method for cognee (#493)

<!-- .github/pull_request_template.md -->

## Description
This PR improves the graph visualization endpoint.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Launched an enhanced interactive network visualization utility that
renders dynamic, browser-based graphs. The new feature simplifies
execution by directly generating an HTML file showcasing the
visualization—complete with interactive elements and an on-screen
confirmation—providing a more intuitive and efficient experience.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
hajdul88 2025-02-06 11:22:17 +01:00 committed by GitHub
parent d56fd8d925
commit bcd326518d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 236 additions and 183 deletions

View file

@ -5,7 +5,9 @@ from .api.v1.datasets.datasets import datasets
from .api.v1.prune import prune
from .api.v1.search import SearchType, get_search_history, search
from .api.v1.visualize import visualize_graph
from .shared.utils import create_cognee_style_network_with_logo
from cognee.modules.visualization.cognee_network_visualization import (
cognee_network_visualization,
)
# Pipelines
from .modules import pipelines

View file

@ -1,15 +1,30 @@
from cognee.shared.utils import create_cognee_style_network_with_logo, graph_to_tuple
from cognee.modules.visualization.cognee_network_visualization import (
cognee_network_visualization,
)
from cognee.infrastructure.databases.graph import get_graph_engine
import logging
import asyncio

from cognee.shared.utils import setup_logging


async def visualize_graph(label: str = "name"):
    """
    Fetch the graph from the configured graph engine and render it as HTML.

    Parameters:
        label: Kept so that existing callers which pass a label keep working;
            the value is not used by the renderer.

    Returns:
        Whatever ``cognee_network_visualization`` returns for the fetched
        graph data.
    """
    graph_engine = await get_graph_engine()
    graph_data = await graph_engine.get_graph_data()
    logging.info(graph_data)

    graph = await cognee_network_visualization(graph_data)
    logging.info("The HTML file has been stored on your home directory! Navigate there with cd ~")

    return graph
if __name__ == "__main__":
    # Only errors should reach the console when run as a script.
    setup_logging(logging.ERROR)

    # asyncio.run() creates a fresh event loop, runs the coroutine, shuts down
    # outstanding async generators and -- unlike the hand-rolled
    # new_event_loop()/run_until_complete() sequence -- also closes the loop.
    asyncio.run(visualize_graph())

View file

View file

@ -0,0 +1,180 @@
import networkx as nx
import json
import os
# Fill colour per node ``pydantic_type``; unknown types fall back to "default".
_NODE_COLORS = {
    "Entity": "#f47710",
    "EntityType": "#6510f4",
    "DocumentChunk": "#801212",
    "default": "#D3D3D3",
}

# Node properties that are dropped before the node is embedded in the page.
# TODO: decide which properties to show on nodes and edges; not all are needed.
_HIDDEN_NODE_PROPERTIES = ("updated_at", "created_at")

# Standalone D3 (v5) force-layout page. ``{nodes}`` and ``{links}`` are the two
# placeholders that get replaced with JSON arrays.
_GRAPH_HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://d3js.org/d3.v5.min.js"></script>
<style>
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', sans-serif; }
svg { width: 100vw; height: 100vh; display: block; }
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
</style>
</head>
<body>
<svg></svg>
<script>
var nodes = {nodes};
var links = {links};
var svg = d3.select("svg"),
width = window.innerWidth,
height = window.innerHeight;
var container = svg.append("g");
var simulation = d3.forceSimulation(nodes)
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
.force("charge", d3.forceManyBody().strength(-275))
.force("center", d3.forceCenter(width / 2, height / 2))
.force("x", d3.forceX().strength(0.1).x(width / 2))
.force("y", d3.forceY().strength(0.1).y(height / 2));
var link = container.append("g")
.attr("class", "links")
.selectAll("line")
.data(links)
.enter().append("line")
.attr("stroke-width", 2);
var edgeLabels = container.append("g")
.attr("class", "edge-labels")
.selectAll("text")
.data(links)
.enter().append("text")
.attr("class", "edge-label")
.text(d => d.relation);
var nodeGroup = container.append("g")
.attr("class", "nodes")
.selectAll("g")
.data(nodes)
.enter().append("g");
var node = nodeGroup.append("circle")
.attr("r", 13)
.attr("fill", d => d.color)
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("end", dragended));
nodeGroup.append("text")
.attr("class", "node-label")
.attr("dy", 4)
.attr("text-anchor", "middle")
.text(d => d.name);
node.append("title").text(d => JSON.stringify(d));
simulation.on("tick", function() {
link.attr("x1", d => d.source.x)
.attr("y1", d => d.source.y)
.attr("x2", d => d.target.x)
.attr("y2", d => d.target.y);
edgeLabels
.attr("x", d => (d.source.x + d.target.x) / 2)
.attr("y", d => (d.source.y + d.target.y) / 2 - 5);
node.attr("cx", d => d.x)
.attr("cy", d => d.y);
nodeGroup.select("text")
.attr("x", d => d.x)
.attr("y", d => d.y)
.attr("dy", 4)
.attr("text-anchor", "middle");
});
svg.call(d3.zoom().on("zoom", function() {
container.attr("transform", d3.event.transform);
}));
function dragstarted(d) {
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
}
function dragged(d) {
d.fx = d3.event.x;
d.fy = d3.event.y;
}
function dragended(d) {
if (!d3.event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
}
window.addEventListener("resize", function() {
width = window.innerWidth;
height = window.innerHeight;
svg.attr("width", width).attr("height", height);
simulation.force("center", d3.forceCenter(width / 2, height / 2));
simulation.alpha(1).restart();
});
</script>
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.7496 4.92654C7.83308 4.92654 4.8585 7.94279 4.8585 11.3612V14.9304C4.8585 18.3488 7.83308 21.3651 11.7496 21.3651C13.6831 21.3651 15.0217 20.8121 16.9551 19.3543C18.0458 18.5499 19.5331 18.8013 20.3263 19.9072C21.1195 21.0132 20.8717 22.5213 19.781 23.3257C17.3518 25.0851 15.0217 26.2414 11.7 26.2414C5.35425 26.2414 0 21.2646 0 14.9304V11.3612C0 4.97681 5.35425 0.0502739 11.7 0.0502739C15.0217 0.0502739 17.3518 1.2065 19.781 2.96598C20.8717 3.77032 21.1195 5.27843 20.3263 6.38439C19.5331 7.49035 18.0458 7.69144 16.9551 6.93737C15.0217 5.52979 13.6831 4.92654 11.7496 4.92654ZM35.5463 4.92654C31.7289 4.92654 28.6552 8.04333 28.6552 11.8639V14.478C28.6552 18.2986 31.7289 21.4154 35.5463 21.4154C39.3141 21.4154 42.3878 18.2986 42.3878 14.478V11.8639C42.3878 8.04333 39.3141 4.92654 35.5463 4.92654ZM23.7967 11.8639C23.7967 5.32871 29.0518 0 35.5463 0C42.0408 0 47.2463 5.32871 47.2463 11.8639V14.478C47.2463 21.0132 42.0408 26.3419 35.5463 26.3419C29.0518 26.3419 23.7967 21.0635 23.7967 14.478V11.8639ZM63.3091 5.07736C59.4917 5.07736 56.418 8.19415 56.418 12.0147C56.418 15.8353 59.4917 18.9521 63.3091 18.9521C67.1265 18.9521 70.1506 15.8856 70.1506 12.0147C70.1506 8.14388 67.0769 5.07736 63.3091 5.07736ZM51.5595 11.9645C51.5595 5.42925 56.8146 0.150814 63.3091 0.150814C66.0854 0.150814 68.5642 1.10596 70.5968 2.71463L72.4311 0.904876C73.3731 -0.0502693 74.9099 -0.0502693 75.8519 0.904876C76.7938 1.86002 76.7938 3.41841 75.8519 4.37356L73.7201 6.53521C74.5629 8.19414 75.0587 10.0542 75.0587 12.0147C75.0587 18.4997 69.8532 23.8284 63.3587 23.8284C63.3091 23.8284 63.2099 23.8284 63.1603 23.8284H58.0044C57.1616 23.8284 56.4675 24.5322 56.4675 25.3868C56.4675 26.2414 57.1616 26.9452 58.0044 26.9452H64.6476H66.7794C68.5146 26.9452 70.3489 27.4479 71.7866 28.6041C73.2739 29.8106 74.2159 31.5701 74.4142 33.7317C74.7116 37.6026 72.0345 40.2166 69.8532 41.0713L63.8048 43.7859C62.5654 44.3389 61.1277 43.7859 60.6319 42.5291C60.0866 
41.2723 60.6319 39.8648 61.8714 39.3118L68.0188 36.5972C68.0684 36.5972 68.118 36.5469 68.1675 36.5469C68.4154 36.4463 68.8616 36.1447 69.2087 35.6923C69.5061 35.2398 69.7044 34.7371 69.6548 34.1339C69.6053 33.229 69.2582 32.7263 68.8616 32.4247C68.4154 32.0728 67.7214 31.8214 66.8786 31.8214H58.2027C58.1531 31.8214 58.1531 31.8214 58.1035 31.8214H58.054C54.534 31.8214 51.6586 28.956 51.6586 25.3868C51.6586 23.0743 52.8485 21.0635 54.6828 19.9072C52.6997 17.7959 51.5595 15.031 51.5595 11.9645ZM90.8736 5.07736C87.0562 5.07736 83.9824 8.19415 83.9824 12.0147V23.9289C83.9824 25.2862 82.8917 26.3922 81.5532 26.3922C80.2146 26.3922 79.1239 25.2862 79.1239 23.9289V11.9645C79.1239 5.42925 84.379 0.150814 90.824 0.150814C97.2689 0.150814 102.524 5.42925 102.524 11.9645V23.8786C102.524 25.2359 101.433 26.3419 100.095 26.3419C98.7562 26.3419 97.6655 25.2359 97.6655 23.8786V11.9645C97.7647 8.14387 94.6414 5.07736 90.8736 5.07736ZM119.43 5.07736C115.513 5.07736 112.39 8.24441 112.39 12.065V14.5785C112.39 18.4494 115.513 21.5662 119.43 21.5662C120.768 21.5662 122.057 21.164 123.098 20.5105C124.238 19.8067 125.726 20.1586 126.42 21.3148C127.114 22.4711 126.767 23.9792 125.627 24.683C123.842 25.7889 121.71 26.4425 119.43 26.4425C112.885 26.4425 107.581 21.1137 107.581 14.5785V12.065C107.581 5.47952 112.935 0.201088 119.43 0.201088C125.032 0.201088 129.692 4.07194 130.931 9.3001L131.427 11.3612L121.115 15.584C119.876 16.0867 118.488 15.4834 117.942 14.2266C117.447 12.9699 118.041 11.5623 119.281 11.0596L125.478 8.54604C124.238 6.43466 122.008 5.07736 119.43 5.07736ZM146.003 5.07736C142.086 5.07736 138.963 8.24441 138.963 12.065V14.5785C138.963 18.4494 142.086 21.5662 146.003 21.5662C147.341 21.5662 148.63 21.164 149.671 20.5105C150.217 20.1586 150.663 19.8067 151.109 19.304C152.001 18.2986 153.538 18.2483 154.53 19.2034C155.521 20.1083 155.571 21.6667 154.629 22.6721C153.935 23.4262 153.092 24.13 152.2 24.683C150.415 25.7889 148.283 26.4425 146.003 26.4425C139.458 26.4425 134.154 
21.1137 134.154 14.5785V12.065C134.154 5.47952 139.508 0.201088 146.003 0.201088C151.605 0.201088 156.265 4.07194 157.504 9.3001L158 11.3612L147.688 15.584C146.449 16.0867 145.061 15.4834 144.515 14.2266C144.019 12.9699 144.614 11.5623 145.854 11.0596L152.051 8.54604C150.762 6.43466 148.58 5.07736 146.003 5.07736Z" fill="white"/>
</svg>
</body>
</html>
"""


def _to_script_json(payload):
    """Serialise *payload* to JSON that is safe to inline in a ``<script>`` element."""
    # A literal "</script>" inside a string value would terminate the script
    # element early; "<\/" is an equivalent JSON/JavaScript escape for "</".
    return json.dumps(payload).replace("</", "<\\/")


async def cognee_network_visualization(graph_data, destination_file_path=None):
    """
    Render a graph as a standalone, interactive D3 force-layout HTML page.

    Parameters:
        graph_data: A ``(nodes, edges)`` pair. ``nodes`` is an iterable of
            ``(node_id, properties)`` tuples whose ``properties`` is a dict;
            ``edges`` is an iterable of
            ``(source, target, relation, edge_info)`` tuples.
        destination_file_path: Optional path of the HTML file to write.
            Defaults to ``graph_visualization.html`` in the user's home
            directory.

    Returns:
        The generated HTML document as a string. The same text is written to
        the destination file.
    """
    nodes_data, edges_data = graph_data

    nodes_list = []
    for node_id, node_info in nodes_data:
        node_info = node_info.copy()  # never mutate the caller's dicts
        node_info["id"] = str(node_id)
        node_info["color"] = _NODE_COLORS.get(
            node_info.get("pydantic_type", "default"), _NODE_COLORS["default"]
        )
        node_info["name"] = node_info.get("name", str(node_id))
        # pop() instead of del: nodes without timestamps must not raise KeyError.
        for hidden_property in _HIDDEN_NODE_PROPERTIES:
            node_info.pop(hidden_property, None)
        nodes_list.append(node_info)

    # Ids are stringified on both sides so that links match the node "id" values.
    links_list = [
        {"source": str(source), "target": str(target), "relation": relation}
        for source, target, relation, _edge_info in edges_data
    ]

    # Split the template at its two placeholders instead of chaining
    # str.replace(): a "{links}" that happens to occur inside the node data
    # must not be substituted a second time.
    head, _, rest = _GRAPH_HTML_TEMPLATE.partition("{nodes}")
    middle, _, tail = rest.partition("{links}")
    html_content = "".join(
        (head, _to_script_json(nodes_list), middle, _to_script_json(links_list), tail)
    )

    if destination_file_path is None:
        destination_file_path = os.path.join(
            os.path.expanduser("~"), "graph_visualization.html"
        )

    # The page declares <meta charset="utf-8">, so write it as UTF-8 regardless
    # of the platform's default encoding.
    with open(destination_file_path, "w", encoding="utf-8") as f:
        f.write(html_content)

    print(f"Graph visualization saved as {destination_file_path}")
    return html_content

View file

@ -13,6 +13,7 @@ import matplotlib.pyplot as plt
import logging
import sys
import json
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph import get_graph_engine
@ -336,86 +337,6 @@ def style_and_render_graph(p, G, layout_positions, node_attribute, node_colors,
return graph_renderer
async def create_cognee_style_network_with_logo(
    G,
    output_filename="cognee_network_with_logo.html",
    title="Cognee-Style Network",
    label="group",
    layout_func=nx.spring_layout,
    layout_scale=3.0,
    logo_alpha=0.1,
    bokeh_object=False,
):
    """
    Create a Cognee-inspired network visualization with an embedded logo.

    Parameters:
        G: Graph input handed to ``convert_to_serializable_graph``.
        output_filename: File name of the HTML page; it is joined onto the
            user's home directory.
        title: Title of the Bokeh figure and of the HTML document.
        label: Node attribute passed to ``assign_node_colors`` and shown in
            the hover tooltip.
        layout_func: Layout function passed to ``generate_layout_positions``.
        layout_scale: Layout scale; also the half-width of both axis ranges.
        logo_alpha: Alpha value passed to ``embed_logo``.
        bokeh_object: Not used by this function.

    Returns:
        The rendered HTML document as a string (also written to disk).
    """
    # Only the Bokeh names this function actually uses are imported.
    from bokeh.embed import file_html
    from bokeh.models import HoverTool, Range1d
    from bokeh.plotting import figure, output_file
    from bokeh.resources import CDN

    logging.info("Converting graph to serializable format...")
    G = await convert_to_serializable_graph(G)

    logging.info("Generating layout positions...")
    layout_positions = generate_layout_positions(G, layout_func, layout_scale)

    logging.info("Assigning node colors...")
    palette = ["#6510F4", "#0DFF00", "#FFFFFF"]
    node_colors, _color_map = assign_node_colors(G, label, palette)

    logging.info("Calculating centrality...")
    centrality = nx.degree_centrality(G)

    logging.info("Preparing Bokeh output...")
    output_file(output_filename)
    p = figure(
        title=title,
        tools="pan,wheel_zoom,save,reset,hover",
        active_scroll="wheel_zoom",
        width=1200,
        height=900,
        background_fill_color="#F4F4F4",
        x_range=Range1d(-layout_scale, layout_scale),
        y_range=Range1d(-layout_scale, layout_scale),
    )
    p.toolbar.logo = None
    p.axis.visible = False
    p.grid.visible = False

    logging.info("Embedding logo into visualization...")
    embed_logo(p, layout_scale, logo_alpha, "bottom_right")
    embed_logo(p, layout_scale, logo_alpha, "top_left")

    logging.info("Styling and rendering graph...")
    style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)

    logging.info("Adding hover tool...")
    hover_tool = HoverTool(
        tooltips=[
            ("Node", "@index"),
            (label.capitalize(), f"@{label}"),
            ("Centrality", "@radius{0.00}"),
        ],
    )
    p.add_tools(hover_tool)

    logging.info(f"Saving visualization to {output_filename}...")
    html_content = file_html(p, CDN, title)

    # Construct the final output file path
    output_filepath = os.path.join(os.path.expanduser("~"), output_filename)

    # Explicit UTF-8 so the result does not depend on the platform's default
    # encoding.
    with open(output_filepath, "w", encoding="utf-8") as f:
        f.write(html_content)

    return html_content
def graph_to_tuple(graph):
"""
Converts a networkx graph to a tuple of (nodes, edges).
@ -443,68 +364,3 @@ def setup_logging(log_level=logging.INFO):
root_logger.addHandler(stream_handler)
root_logger.setLevel(log_level)
# ---------------- Example Usage ----------------
if __name__ == "__main__":
    import asyncio

    import networkx as nx

    # Sample graph: a five-node ring whose "group" attribute drives the colours.
    nodes = [
        (1, {"group": "A"}),
        (2, {"group": "A"}),
        (3, {"group": "B"}),
        (4, {"group": "B"}),
        (5, {"group": "C"}),
    ]
    edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 1)]

    # Create a NetworkX graph and convert it to the (nodes, edges) tuple form.
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    G = graph_to_tuple(G)

    output_html = asyncio.run(
        create_cognee_style_network_with_logo(
            G,
            output_filename="example_network.html",
            title="Example Cognee Network",
            label="group",  # Attribute to use for coloring nodes
            layout_func=nx.spring_layout,  # Layout function
            layout_scale=3.0,  # Scale for the layout
            logo_alpha=0.2,  # Transparency of the logo
        )
    )

    # Print the output filename
    print("Network visualization saved as example_network.html")

View file

@ -0,0 +1,33 @@
import pytest
from cognee.modules.visualization.cognee_network_visualization import (
cognee_network_visualization,
)
@pytest.mark.asyncio
async def test_create_cognee_style_network_with_logo(monkeypatch, tmp_path):
    """Render a two-node graph and check the generated D3 page."""
    # The renderer writes graph_visualization.html into the user's home
    # directory. Point "~" at a temporary directory so the test never touches
    # the real one (HOME is consulted on POSIX, USERPROFILE on Windows).
    monkeypatch.setenv("HOME", str(tmp_path))
    monkeypatch.setenv("USERPROFILE", str(tmp_path))

    nodes_data = [
        (1, {"pydantic_type": "Entity", "name": "Node1", "updated_at": 123, "created_at": 123}),
        (
            2,
            {
                "pydantic_type": "DocumentChunk",
                "name": "Node2",
                "updated_at": 123,
                "created_at": 123,
            },
        ),
    ]
    edges_data = [
        (1, 2, "related_to", {}),
    ]
    graph_data = (nodes_data, edges_data)

    html_output = await cognee_network_visualization(graph_data)

    assert isinstance(html_output, str)
    assert "<html>" in html_output
    assert '<script src="https://d3js.org/d3.v5.min.js"></script>' in html_output
    assert "var nodes =" in html_output
    assert "var links =" in html_output
    assert (tmp_path / "graph_visualization.html").exists()

View file

@ -5,17 +5,12 @@ import pandas as pd
from unittest.mock import patch, mock_open
from io import BytesIO
from uuid import uuid4
from datetime import datetime, timezone
from cognee.shared.exceptions import IngestionError
from cognee.shared.utils import (
get_anonymous_id,
send_telemetry,
get_file_content_hash,
prepare_edges,
prepare_nodes,
create_cognee_style_network_with_logo,
graph_to_tuple,
)
@ -78,31 +73,3 @@ def test_prepare_nodes():
assert isinstance(nodes_df, pd.DataFrame)
assert len(nodes_df) == 1
@pytest.mark.asyncio
async def test_create_cognee_style_network_with_logo():
    """Render a two-node sample graph and check that non-empty HTML text comes back."""
    import networkx as nx

    # Create a sample graph
    graph = nx.Graph()
    graph.add_node(1, group="A")
    graph.add_node(2, group="B")
    graph.add_edge(1, 2)

    # Convert the graph to a tuple format for serialization
    graph_tuple = graph_to_tuple(graph)

    result = await create_cognee_style_network_with_logo(
        graph_tuple,
        title="Test Network",
        layout_func=nx.spring_layout,
        layout_scale=3.0,
        logo_alpha=0.5,
    )

    assert result is not None
    assert isinstance(result, str)
    assert len(result) > 0