Fix visualization
This commit is contained in:
parent
1b96a71d5a
commit
4c21dd0cce
5 changed files with 75 additions and 26 deletions
|
|
@ -24,6 +24,7 @@ from cognee.api.v1.users.routers import (
|
|||
get_reset_password_router,
|
||||
get_verify_router,
|
||||
get_users_router,
|
||||
get_visualize_router
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
|
@ -166,6 +167,8 @@ app.include_router(get_search_router(), prefix="/api/v1/search", tags=["search"]
|
|||
|
||||
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
|
||||
|
||||
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
|
||||
|
||||
|
||||
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ from .get_register_router import get_register_router
|
|||
from .get_reset_password_router import get_reset_password_router
|
||||
from .get_users_router import get_users_router
|
||||
from .get_verify_router import get_verify_router
|
||||
from .get_visualize_router import get_visualize_router
|
||||
|
|
|
|||
31
cognee/api/v1/users/routers/get_visualize_router.py
Normal file
31
cognee/api/v1/users/routers/get_visualize_router.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from fastapi import Form, UploadFile, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter
|
||||
from typing import List
|
||||
import aiohttp
|
||||
import subprocess
|
||||
import logging
|
||||
import os
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_visualize_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=None)
|
||||
async def visualize(
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
"""This endpoint is responsible for adding data to the graph."""
|
||||
from cognee.api.v1.visualize import visualize_graph
|
||||
|
||||
try:
|
||||
html_visualization = await visualize_graph()
|
||||
return html_visualization
|
||||
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
return router
|
||||
|
|
@ -4,12 +4,12 @@ import logging
|
|||
|
||||
|
||||
|
||||
async def visualize_graph(bokeh_object):
|
||||
async def visualize_graph(label:str="name"):
|
||||
""" """
|
||||
graph_engine = await get_graph_engine()
|
||||
graph_data = await graph_engine.get_graph_data()
|
||||
logging.info(graph_data)
|
||||
|
||||
graph = await create_cognee_style_network_with_logo(graph_data, bokeh_object=bokeh_object)
|
||||
graph = await create_cognee_style_network_with_logo(graph_data, label=label)
|
||||
|
||||
return graph
|
||||
|
|
@ -283,21 +283,28 @@ async def convert_to_serializable_graph(G):
|
|||
"""
|
||||
|
||||
(nodes, edges) = G
|
||||
networkx_graph = nx.MultiDiGraph()
|
||||
|
||||
networkx_graph = nx.MultiDiGraph()
|
||||
networkx_graph.add_nodes_from(nodes)
|
||||
networkx_graph.add_edges_from(edges)
|
||||
|
||||
new_G = nx.MultiDiGraph() if isinstance(G, nx.MultiDiGraph) else nx.Graph()
|
||||
for node, data in new_G.nodes(data=True):
|
||||
# Create a new graph to store the serializable version
|
||||
new_G = nx.MultiDiGraph()
|
||||
|
||||
# Serialize nodes
|
||||
for node, data in networkx_graph.nodes(data=True):
|
||||
serializable_data = {k: str(v) for k, v in data.items()}
|
||||
new_G.add_node(str(node), **serializable_data)
|
||||
for u, v, data in new_G.edges(data=True):
|
||||
|
||||
# Serialize edges
|
||||
for u, v, data in networkx_graph.edges(data=True):
|
||||
serializable_data = {k: str(v) for k, v in data.items()}
|
||||
new_G.add_edge(str(u), str(v), **serializable_data)
|
||||
|
||||
return new_G
|
||||
|
||||
|
||||
|
||||
def generate_layout_positions(G, layout_func, layout_scale):
|
||||
"""
|
||||
Generate layout positions for the graph using the specified layout function.
|
||||
|
|
@ -315,7 +322,7 @@ def assign_node_colors(G, node_attribute, palette):
|
|||
return [color_map[G.nodes[node].get(node_attribute, "Unknown")] for node in G.nodes], color_map
|
||||
|
||||
|
||||
def embed_logo(p, layout_scale, logo_alpha):
|
||||
def embed_logo(p, layout_scale, logo_alpha, position):
|
||||
"""
|
||||
Embed a logo into the graph visualization as a watermark.
|
||||
"""
|
||||
|
|
@ -336,7 +343,7 @@ def embed_logo(p, layout_scale, logo_alpha):
|
|||
y=layout_scale * 0.5,
|
||||
w=layout_scale,
|
||||
h=layout_scale,
|
||||
anchor="center",
|
||||
anchor=position,
|
||||
global_alpha=logo_alpha,
|
||||
)
|
||||
|
||||
|
|
@ -369,7 +376,7 @@ async def create_cognee_style_network_with_logo(
|
|||
G,
|
||||
output_filename="cognee_network_with_logo.html",
|
||||
title="Cognee-Style Network",
|
||||
node_attribute="group",
|
||||
label="group",
|
||||
layout_func=nx.spring_layout,
|
||||
layout_scale=3.0,
|
||||
logo_alpha=0.1,
|
||||
|
|
@ -387,7 +394,7 @@ async def create_cognee_style_network_with_logo(
|
|||
|
||||
logging.info("Assigning node colors...")
|
||||
palette = ["#6510F4", "#0DFF00", "#FFFFFF"]
|
||||
node_colors, color_map = assign_node_colors(G, node_attribute, palette)
|
||||
node_colors, color_map = assign_node_colors(G, label, palette)
|
||||
|
||||
logging.info("Calculating centrality...")
|
||||
centrality = nx.degree_centrality(G)
|
||||
|
|
@ -409,16 +416,18 @@ async def create_cognee_style_network_with_logo(
|
|||
p.grid.visible = False
|
||||
|
||||
logging.info("Embedding logo into visualization...")
|
||||
embed_logo(p, layout_scale, logo_alpha)
|
||||
embed_logo(p, layout_scale, logo_alpha, "bottom_right")
|
||||
embed_logo(p, layout_scale, logo_alpha, "top_left")
|
||||
|
||||
|
||||
logging.info("Styling and rendering graph...")
|
||||
style_and_render_graph(p, G, layout_positions, node_attribute, node_colors, centrality)
|
||||
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
|
||||
|
||||
logging.info("Adding hover tool...")
|
||||
hover_tool = HoverTool(
|
||||
tooltips=[
|
||||
("Node", "@index"),
|
||||
(node_attribute.capitalize(), f"@{node_attribute}"),
|
||||
(label.capitalize(), f"@{label}"),
|
||||
("Centrality", "@radius{0.00}"),
|
||||
],
|
||||
)
|
||||
|
|
@ -429,16 +438,10 @@ async def create_cognee_style_network_with_logo(
|
|||
html_content = file_html(p, CDN, title)
|
||||
with open(output_filename, "w") as f:
|
||||
f.write(html_content)
|
||||
from bokeh.io import export_png
|
||||
from IPython.display import Image, display
|
||||
|
||||
logging.info("Visualization complete.")
|
||||
png_filename = output_filename.replace(".html", ".png")
|
||||
export_png(p, filename=png_filename)
|
||||
logging.info(f"Saved PNG image to {png_filename}")
|
||||
|
||||
# Display the PNG image as a popup
|
||||
display(Image(png_filename))
|
||||
|
||||
|
||||
if bokeh_object:
|
||||
return p
|
||||
|
|
@ -478,16 +481,27 @@ if __name__ == "__main__":
|
|||
G.add_nodes_from(nodes)
|
||||
G.add_edges_from(edges)
|
||||
|
||||
# Call the function
|
||||
output_html = create_cognee_style_network_with_logo(
|
||||
G=G,
|
||||
output_filename="example_network.html",
|
||||
G = graph_to_tuple(G)
|
||||
|
||||
import asyncio
|
||||
|
||||
output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html",
|
||||
title="Example Cognee Network",
|
||||
node_attribute="group", # Attribute to use for coloring nodes
|
||||
layout_func=nx.spring_layout, # Layout function
|
||||
layout_scale=3.0, # Scale for the layout
|
||||
logo_alpha=0.2, # Transparency of the logo
|
||||
)
|
||||
logo_alpha=0.2, ))
|
||||
|
||||
# Call the function
|
||||
# output_html = await create_cognee_style_network_with_logo(
|
||||
# G=G,
|
||||
# output_filename="example_network.html",
|
||||
# title="Example Cognee Network",
|
||||
# node_attribute="group", # Attribute to use for coloring nodes
|
||||
# layout_func=nx.spring_layout, # Layout function
|
||||
# layout_scale=3.0, # Scale for the layout
|
||||
# logo_alpha=0.2, # Transparency of the logo
|
||||
# )
|
||||
|
||||
# Print the output filename
|
||||
print("Network visualization saved as example_network.html")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue