Fix visualization

This commit is contained in:
vasilije 2025-01-08 13:13:52 +01:00
parent 0ff9ffa11b
commit 41b1486cff
9 changed files with 38 additions and 41 deletions

View file

@ -24,7 +24,7 @@ from cognee.api.v1.users.routers import (
get_reset_password_router, get_reset_password_router,
get_verify_router, get_verify_router,
get_users_router, get_users_router,
get_visualize_router get_visualize_router,
) )
from contextlib import asynccontextmanager from contextlib import asynccontextmanager

View file

@ -11,6 +11,7 @@ from cognee.modules.users.methods import get_authenticated_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_visualize_router() -> APIRouter: def get_visualize_router() -> APIRouter:
router = APIRouter() router = APIRouter()

View file

@ -3,7 +3,6 @@ from cognee.infrastructure.databases.graph import get_graph_engine
import logging import logging
async def visualize_graph(label: str = "name"): async def visualize_graph(label: str = "name"):
""" """ """ """
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()

View file

@ -8,6 +8,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config from cognee.infrastructure.llm.config import get_llm_config
import litellm import litellm
class GenericAPIAdapter(LLMInterface): class GenericAPIAdapter(LLMInterface):
"""Adapter for Generic API LLM provider API""" """Adapter for Generic API LLM provider API"""
@ -29,7 +30,6 @@ class GenericAPIAdapter(LLMInterface):
else: else:
self.aclient = instructor.from_litellm(litellm.acompletion) self.aclient = instructor.from_litellm(litellm.acompletion)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:
@ -50,7 +50,5 @@ class GenericAPIAdapter(LLMInterface):
], ],
max_retries=5, max_retries=5,
api_base=self.endpoint, api_base=self.endpoint,
response_model=response_model response_model=response_model,
) )

View file

@ -305,7 +305,6 @@ async def convert_to_serializable_graph(G):
return new_G return new_G
def generate_layout_positions(G, layout_func, layout_scale): def generate_layout_positions(G, layout_func, layout_scale):
""" """
Generate layout positions for the graph using the specified layout function. Generate layout positions for the graph using the specified layout function.
@ -389,7 +388,6 @@ async def create_cognee_style_network_with_logo(
logging.info("Converting graph to serializable format...") logging.info("Converting graph to serializable format...")
G = await convert_to_serializable_graph(G) G = await convert_to_serializable_graph(G)
logging.info("Generating layout positions...") logging.info("Generating layout positions...")
layout_positions = generate_layout_positions(G, layout_func, layout_scale) layout_positions = generate_layout_positions(G, layout_func, layout_scale)
@ -420,7 +418,6 @@ async def create_cognee_style_network_with_logo(
embed_logo(p, layout_scale, logo_alpha, "bottom_right") embed_logo(p, layout_scale, logo_alpha, "bottom_right")
embed_logo(p, layout_scale, logo_alpha, "top_left") embed_logo(p, layout_scale, logo_alpha, "top_left")
logging.info("Styling and rendering graph...") logging.info("Styling and rendering graph...")
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality) style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
@ -434,7 +431,6 @@ async def create_cognee_style_network_with_logo(
) )
p.add_tools(hover_tool) p.add_tools(hover_tool)
logging.info(f"Saving visualization to {output_filename}...") logging.info(f"Saving visualization to {output_filename}...")
html_content = file_html(p, CDN, title) html_content = file_html(p, CDN, title)
with open(output_filename, "w") as f: with open(output_filename, "w") as f:
@ -442,8 +438,6 @@ async def create_cognee_style_network_with_logo(
logging.info("Visualization complete.") logging.info("Visualization complete.")
if bokeh_object: if bokeh_object:
return p return p
return html_content return html_content
@ -461,7 +455,6 @@ def graph_to_tuple(graph):
return (nodes, edges) return (nodes, edges)
def setup_logging(log_level=logging.INFO): def setup_logging(log_level=logging.INFO):
"""This method sets up the logging configuration.""" """This method sets up the logging configuration."""
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n") formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n")
@ -498,12 +491,17 @@ if __name__ == "__main__":
import asyncio import asyncio
output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html", output_html = asyncio.run(
create_cognee_style_network_with_logo(
G,
output_filename="example_network.html",
title="Example Cognee Network", title="Example Cognee Network",
node_attribute="group", # Attribute to use for coloring nodes node_attribute="group", # Attribute to use for coloring nodes
layout_func=nx.spring_layout, # Layout function layout_func=nx.spring_layout, # Layout function
layout_scale=3.0, # Scale for the layout layout_scale=3.0, # Scale for the layout
logo_alpha=0.2, )) logo_alpha=0.2,
)
)
# Call the function # Call the function
# output_html = await create_cognee_style_network_with_logo( # output_html = await create_cognee_style_network_with_logo(

View file

@ -102,6 +102,7 @@ def test_prepare_nodes():
assert isinstance(nodes_df, pd.DataFrame) assert isinstance(nodes_df, pd.DataFrame)
assert len(nodes_df) == 1 assert len(nodes_df) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_cognee_style_network_with_logo(): async def test_create_cognee_style_network_with_logo():
import networkx as nx import networkx as nx

View file

@ -191,7 +191,7 @@ async def main(enable_steps):
print(format_triplets(results)) print(format_triplets(results))
if __name__ == '__main__': if __name__ == "__main__":
setup_logging(logging.ERROR) setup_logging(logging.ERROR)
rebuild_kg = True rebuild_kg = True