Fix visualization

This commit is contained in:
vasilije 2025-01-07 15:21:08 +01:00
parent 1b96a71d5a
commit 4c21dd0cce
5 changed files with 75 additions and 26 deletions

View file

@ -24,6 +24,7 @@ from cognee.api.v1.users.routers import (
get_reset_password_router,
get_verify_router,
get_users_router,
get_visualize_router
)
from contextlib import asynccontextmanager
@ -166,6 +167,8 @@ app.include_router(get_search_router(), prefix="/api/v1/search", tags=["search"]
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
"""

View file

@ -3,3 +3,4 @@ from .get_register_router import get_register_router
from .get_reset_password_router import get_reset_password_router
from .get_users_router import get_users_router
from .get_verify_router import get_verify_router
from .get_visualize_router import get_visualize_router

View file

@ -0,0 +1,31 @@
from fastapi import Form, UploadFile, Depends
from fastapi.responses import JSONResponse
from fastapi import APIRouter
from typing import List
import aiohttp
import subprocess
import logging
import os
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
logger = logging.getLogger(__name__)
def get_visualize_router() -> APIRouter:
router = APIRouter()
@router.post("/", response_model=None)
async def visualize(
user: User = Depends(get_authenticated_user),
):
"""This endpoint is responsible for adding data to the graph."""
from cognee.api.v1.visualize import visualize_graph
try:
html_visualization = await visualize_graph()
return html_visualization
except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)})
return router

View file

@ -4,12 +4,12 @@ import logging
async def visualize_graph(bokeh_object):
async def visualize_graph(label:str="name"):
""" """
graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data()
logging.info(graph_data)
graph = await create_cognee_style_network_with_logo(graph_data, bokeh_object=bokeh_object)
graph = await create_cognee_style_network_with_logo(graph_data, label=label)
return graph

View file

@ -283,21 +283,28 @@ async def convert_to_serializable_graph(G):
"""
(nodes, edges) = G
networkx_graph = nx.MultiDiGraph()
networkx_graph = nx.MultiDiGraph()
networkx_graph.add_nodes_from(nodes)
networkx_graph.add_edges_from(edges)
new_G = nx.MultiDiGraph() if isinstance(G, nx.MultiDiGraph) else nx.Graph()
for node, data in new_G.nodes(data=True):
# Create a new graph to store the serializable version
new_G = nx.MultiDiGraph()
# Serialize nodes
for node, data in networkx_graph.nodes(data=True):
serializable_data = {k: str(v) for k, v in data.items()}
new_G.add_node(str(node), **serializable_data)
for u, v, data in new_G.edges(data=True):
# Serialize edges
for u, v, data in networkx_graph.edges(data=True):
serializable_data = {k: str(v) for k, v in data.items()}
new_G.add_edge(str(u), str(v), **serializable_data)
return new_G
def generate_layout_positions(G, layout_func, layout_scale):
"""
Generate layout positions for the graph using the specified layout function.
@ -315,7 +322,7 @@ def assign_node_colors(G, node_attribute, palette):
return [color_map[G.nodes[node].get(node_attribute, "Unknown")] for node in G.nodes], color_map
def embed_logo(p, layout_scale, logo_alpha):
def embed_logo(p, layout_scale, logo_alpha, position):
"""
Embed a logo into the graph visualization as a watermark.
"""
@ -336,7 +343,7 @@ def embed_logo(p, layout_scale, logo_alpha):
y=layout_scale * 0.5,
w=layout_scale,
h=layout_scale,
anchor="center",
anchor=position,
global_alpha=logo_alpha,
)
@ -369,7 +376,7 @@ async def create_cognee_style_network_with_logo(
G,
output_filename="cognee_network_with_logo.html",
title="Cognee-Style Network",
node_attribute="group",
label="group",
layout_func=nx.spring_layout,
layout_scale=3.0,
logo_alpha=0.1,
@ -387,7 +394,7 @@ async def create_cognee_style_network_with_logo(
logging.info("Assigning node colors...")
palette = ["#6510F4", "#0DFF00", "#FFFFFF"]
node_colors, color_map = assign_node_colors(G, node_attribute, palette)
node_colors, color_map = assign_node_colors(G, label, palette)
logging.info("Calculating centrality...")
centrality = nx.degree_centrality(G)
@ -409,16 +416,18 @@ async def create_cognee_style_network_with_logo(
p.grid.visible = False
logging.info("Embedding logo into visualization...")
embed_logo(p, layout_scale, logo_alpha)
embed_logo(p, layout_scale, logo_alpha, "bottom_right")
embed_logo(p, layout_scale, logo_alpha, "top_left")
logging.info("Styling and rendering graph...")
style_and_render_graph(p, G, layout_positions, node_attribute, node_colors, centrality)
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
logging.info("Adding hover tool...")
hover_tool = HoverTool(
tooltips=[
("Node", "@index"),
(node_attribute.capitalize(), f"@{node_attribute}"),
(label.capitalize(), f"@{label}"),
("Centrality", "@radius{0.00}"),
],
)
@ -429,16 +438,10 @@ async def create_cognee_style_network_with_logo(
html_content = file_html(p, CDN, title)
with open(output_filename, "w") as f:
f.write(html_content)
from bokeh.io import export_png
from IPython.display import Image, display
logging.info("Visualization complete.")
png_filename = output_filename.replace(".html", ".png")
export_png(p, filename=png_filename)
logging.info(f"Saved PNG image to {png_filename}")
# Display the PNG image as a popup
display(Image(png_filename))
if bokeh_object:
return p
@ -478,16 +481,27 @@ if __name__ == "__main__":
G.add_nodes_from(nodes)
G.add_edges_from(edges)
# Call the function
output_html = create_cognee_style_network_with_logo(
G=G,
output_filename="example_network.html",
G = graph_to_tuple(G)
import asyncio
output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html",
title="Example Cognee Network",
node_attribute="group", # Attribute to use for coloring nodes
layout_func=nx.spring_layout, # Layout function
layout_scale=3.0, # Scale for the layout
logo_alpha=0.2, # Transparency of the logo
)
logo_alpha=0.2, ))
# Call the function
# output_html = await create_cognee_style_network_with_logo(
# G=G,
# output_filename="example_network.html",
# title="Example Cognee Network",
# node_attribute="group", # Attribute to use for coloring nodes
# layout_func=nx.spring_layout, # Layout function
# layout_scale=3.0, # Scale for the layout
# logo_alpha=0.2, # Transparency of the logo
# )
# Print the output filename
print("Network visualization saved as example_network.html")