diff --git a/cognee/api/client.py b/cognee/api/client.py index c39eef121..8752e2318 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -24,6 +24,7 @@ from cognee.api.v1.users.routers import ( get_reset_password_router, get_verify_router, get_users_router, + get_visualize_router ) from contextlib import asynccontextmanager @@ -166,6 +167,8 @@ app.include_router(get_search_router(), prefix="/api/v1/search", tags=["search"] app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"]) +app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"]) + def start_api_server(host: str = "0.0.0.0", port: int = 8000): """ diff --git a/cognee/api/v1/users/routers/__init__.py b/cognee/api/v1/users/routers/__init__.py index 482aac265..ae4b7ca56 100644 --- a/cognee/api/v1/users/routers/__init__.py +++ b/cognee/api/v1/users/routers/__init__.py @@ -3,3 +3,4 @@ from .get_register_router import get_register_router from .get_reset_password_router import get_reset_password_router from .get_users_router import get_users_router from .get_verify_router import get_verify_router +from .get_visualize_router import get_visualize_router diff --git a/cognee/api/v1/users/routers/get_visualize_router.py b/cognee/api/v1/users/routers/get_visualize_router.py new file mode 100644 index 000000000..4c4d0b6f7 --- /dev/null +++ b/cognee/api/v1/users/routers/get_visualize_router.py @@ -0,0 +1,31 @@ +from fastapi import Form, UploadFile, Depends +from fastapi.responses import JSONResponse +from fastapi import APIRouter +from typing import List +import aiohttp +import subprocess +import logging +import os +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_authenticated_user + +logger = logging.getLogger(__name__) + +def get_visualize_router() -> APIRouter: + router = APIRouter() + + @router.post("/", response_model=None) + async def visualize( + user: User = Depends(get_authenticated_user), + ): + """This endpoint is responsible for adding data to the graph.""" + from cognee.api.v1.visualize import visualize_graph + + try: + html_visualization = await visualize_graph() + return html_visualization + + except Exception as error: + return JSONResponse(status_code=409, content={"error": str(error)}) + + return router \ No newline at end of file diff --git a/cognee/api/v1/visualize/visualize.py b/cognee/api/v1/visualize/visualize.py index 5e8290723..d9be2a3af 100644 --- a/cognee/api/v1/visualize/visualize.py +++ b/cognee/api/v1/visualize/visualize.py @@ -4,12 +4,12 @@ import logging -async def visualize_graph(bokeh_object): +async def visualize_graph(label:str="name"): """ """ graph_engine = await get_graph_engine() graph_data = await graph_engine.get_graph_data() logging.info(graph_data) - graph = await create_cognee_style_network_with_logo(graph_data, bokeh_object=bokeh_object) + graph = await create_cognee_style_network_with_logo(graph_data, label=label) return graph \ No newline at end of file diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index a34055d8e..a08b3d536 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -283,21 +283,28 @@ async def convert_to_serializable_graph(G): """ (nodes, edges) = G - networkx_graph = nx.MultiDiGraph() + networkx_graph = nx.MultiDiGraph() networkx_graph.add_nodes_from(nodes) networkx_graph.add_edges_from(edges) - new_G = nx.MultiDiGraph() if isinstance(G, nx.MultiDiGraph) else nx.Graph() - for node, data in new_G.nodes(data=True): + # Create a new graph to store the serializable version + new_G = nx.MultiDiGraph() + + # Serialize nodes + for node, data in networkx_graph.nodes(data=True): serializable_data = {k: str(v) for k, v in data.items()} new_G.add_node(str(node), **serializable_data) - for u, v, data in new_G.edges(data=True): + + # Serialize edges + for u, v, data in networkx_graph.edges(data=True): serializable_data = {k: str(v) for k, v in data.items()} new_G.add_edge(str(u), str(v), **serializable_data) + return new_G + def generate_layout_positions(G, layout_func, layout_scale): """ Generate layout positions for the graph using the specified layout function. @@ -315,7 +322,7 @@ def assign_node_colors(G, node_attribute, palette): return [color_map[G.nodes[node].get(node_attribute, "Unknown")] for node in G.nodes], color_map -def embed_logo(p, layout_scale, logo_alpha): +def embed_logo(p, layout_scale, logo_alpha, position): """ Embed a logo into the graph visualization as a watermark. """ @@ -336,7 +343,7 @@ def embed_logo(p, layout_scale, logo_alpha): y=layout_scale * 0.5, w=layout_scale, h=layout_scale, - anchor="center", + anchor=position, global_alpha=logo_alpha, ) @@ -369,7 +376,7 @@ async def create_cognee_style_network_with_logo( G, output_filename="cognee_network_with_logo.html", title="Cognee-Style Network", - node_attribute="group", + label="group", layout_func=nx.spring_layout, layout_scale=3.0, logo_alpha=0.1, @@ -387,7 +394,7 @@ async def create_cognee_style_network_with_logo( logging.info("Assigning node colors...") palette = ["#6510F4", "#0DFF00", "#FFFFFF"] - node_colors, color_map = assign_node_colors(G, node_attribute, palette) + node_colors, color_map = assign_node_colors(G, label, palette) logging.info("Calculating centrality...") centrality = nx.degree_centrality(G) @@ -409,16 +416,18 @@ async def create_cognee_style_network_with_logo( p.grid.visible = False logging.info("Embedding logo into visualization...") - embed_logo(p, layout_scale, logo_alpha) + embed_logo(p, layout_scale, logo_alpha, "bottom_right") + embed_logo(p, layout_scale, logo_alpha, "top_left") + logging.info("Styling and rendering graph...") - style_and_render_graph(p, G, layout_positions, node_attribute, node_colors, centrality) + style_and_render_graph(p, G, layout_positions, label, node_colors, centrality) logging.info("Adding hover tool...") hover_tool = HoverTool( tooltips=[ ("Node", "@index"), - (node_attribute.capitalize(), f"@{node_attribute}"), + (label.capitalize(), f"@{label}"), ("Centrality", "@radius{0.00}"), ], ) @@ -429,16 +438,10 @@ async def create_cognee_style_network_with_logo( html_content = file_html(p, CDN, title) with open(output_filename, "w") as f: f.write(html_content) - from bokeh.io import export_png - from IPython.display import Image, display logging.info("Visualization complete.") - png_filename = output_filename.replace(".html", ".png") - export_png(p, filename=png_filename) - logging.info(f"Saved PNG image to {png_filename}") - # Display the PNG image as a popup - display(Image(png_filename)) + if bokeh_object: return p @@ -478,16 +481,27 @@ if __name__ == "__main__": G.add_nodes_from(nodes) G.add_edges_from(edges) - # Call the function - output_html = create_cognee_style_network_with_logo( - G=G, - output_filename="example_network.html", + G = graph_to_tuple(G) + + import asyncio + + output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html", title="Example Cognee Network", node_attribute="group", # Attribute to use for coloring nodes layout_func=nx.spring_layout, # Layout function layout_scale=3.0, # Scale for the layout - logo_alpha=0.2, # Transparency of the logo - ) + logo_alpha=0.2, )) + + # Call the function + # output_html = await create_cognee_style_network_with_logo( + # G=G, + # output_filename="example_network.html", + # title="Example Cognee Network", + # node_attribute="group", # Attribute to use for coloring nodes + # layout_func=nx.spring_layout, # Layout function + # layout_scale=3.0, # Scale for the layout + # logo_alpha=0.2, # Transparency of the logo + # ) # Print the output filename print("Network visualization saved as example_network.html")