Fix visualization

This commit is contained in:
vasilije 2025-01-08 13:13:52 +01:00
parent 0ff9ffa11b
commit 41b1486cff
9 changed files with 38 additions and 41 deletions

View file

@ -24,7 +24,7 @@ from cognee.api.v1.users.routers import (
get_reset_password_router, get_reset_password_router,
get_verify_router, get_verify_router,
get_users_router, get_users_router,
get_visualize_router get_visualize_router,
) )
from contextlib import asynccontextmanager from contextlib import asynccontextmanager

View file

@ -110,7 +110,7 @@ async def run_cognify_pipeline(
summarization_model=cognee_config.summarization_model, summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10}, task_config={"batch_size": 10},
), ),
Task(add_data_points, only_root = True, task_config = { "batch_size": 10 }), Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
] ]
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline") pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")

View file

@ -11,21 +11,22 @@ from cognee.modules.users.methods import get_authenticated_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_visualize_router() -> APIRouter: def get_visualize_router() -> APIRouter:
router = APIRouter() router = APIRouter()
@router.post("/", response_model=None) @router.post("/", response_model=None)
async def visualize( async def visualize(
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
): ):
"""This endpoint is responsible for adding data to the graph.""" """This endpoint is responsible for adding data to the graph."""
from cognee.api.v1.visualize import visualize_graph from cognee.api.v1.visualize import visualize_graph
try: try:
html_visualization = await visualize_graph() html_visualization = await visualize_graph()
return html_visualization return html_visualization
except Exception as error: except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)}) return JSONResponse(status_code=409, content={"error": str(error)})
return router return router

View file

@ -3,8 +3,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
import logging import logging
async def visualize_graph(label: str = "name"):
async def visualize_graph(label:str="name"):
""" """ """ """
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data() graph_data = await graph_engine.get_graph_data()
@ -12,4 +11,4 @@ async def visualize_graph(label:str="name"):
graph = await create_cognee_style_network_with_logo(graph_data, label=label) graph = await create_cognee_style_network_with_logo(graph_data, label=label)
return graph return graph

View file

@ -8,6 +8,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config from cognee.infrastructure.llm.config import get_llm_config
import litellm import litellm
class GenericAPIAdapter(LLMInterface): class GenericAPIAdapter(LLMInterface):
"""Adapter for Generic API LLM provider API""" """Adapter for Generic API LLM provider API"""
@ -27,8 +28,7 @@ class GenericAPIAdapter(LLMInterface):
self.aclient = instructor.from_litellm(litellm.acompletion) self.aclient = instructor.from_litellm(litellm.acompletion)
else: else:
self.aclient = instructor.from_litellm(litellm.acompletion) self.aclient = instructor.from_litellm(litellm.acompletion)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
@ -49,8 +49,6 @@ class GenericAPIAdapter(LLMInterface):
}, },
], ],
max_retries=5, max_retries=5,
api_base = self.endpoint, api_base=self.endpoint,
response_model=response_model response_model=response_model,
) )

View file

@ -305,7 +305,6 @@ async def convert_to_serializable_graph(G):
return new_G return new_G
def generate_layout_positions(G, layout_func, layout_scale): def generate_layout_positions(G, layout_func, layout_scale):
""" """
Generate layout positions for the graph using the specified layout function. Generate layout positions for the graph using the specified layout function.
@ -381,7 +380,7 @@ async def create_cognee_style_network_with_logo(
layout_func=nx.spring_layout, layout_func=nx.spring_layout,
layout_scale=3.0, layout_scale=3.0,
logo_alpha=0.1, logo_alpha=0.1,
bokeh_object = False, bokeh_object=False,
): ):
""" """
Create a Cognee-inspired network visualization with an embedded logo. Create a Cognee-inspired network visualization with an embedded logo.
@ -389,7 +388,6 @@ async def create_cognee_style_network_with_logo(
logging.info("Converting graph to serializable format...") logging.info("Converting graph to serializable format...")
G = await convert_to_serializable_graph(G) G = await convert_to_serializable_graph(G)
logging.info("Generating layout positions...") logging.info("Generating layout positions...")
layout_positions = generate_layout_positions(G, layout_func, layout_scale) layout_positions = generate_layout_positions(G, layout_func, layout_scale)
@ -420,7 +418,6 @@ async def create_cognee_style_network_with_logo(
embed_logo(p, layout_scale, logo_alpha, "bottom_right") embed_logo(p, layout_scale, logo_alpha, "bottom_right")
embed_logo(p, layout_scale, logo_alpha, "top_left") embed_logo(p, layout_scale, logo_alpha, "top_left")
logging.info("Styling and rendering graph...") logging.info("Styling and rendering graph...")
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality) style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
@ -434,7 +431,6 @@ async def create_cognee_style_network_with_logo(
) )
p.add_tools(hover_tool) p.add_tools(hover_tool)
logging.info(f"Saving visualization to {output_filename}...") logging.info(f"Saving visualization to {output_filename}...")
html_content = file_html(p, CDN, title) html_content = file_html(p, CDN, title)
with open(output_filename, "w") as f: with open(output_filename, "w") as f:
@ -442,8 +438,6 @@ async def create_cognee_style_network_with_logo(
logging.info("Visualization complete.") logging.info("Visualization complete.")
if bokeh_object: if bokeh_object:
return p return p
return html_content return html_content
@ -461,9 +455,8 @@ def graph_to_tuple(graph):
return (nodes, edges) return (nodes, edges)
def setup_logging(log_level=logging.INFO): def setup_logging(log_level=logging.INFO):
""" This method sets up the logging configuration. """ """This method sets up the logging configuration."""
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n") formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n")
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter) stream_handler.setFormatter(formatter)
@ -498,12 +491,17 @@ if __name__ == "__main__":
import asyncio import asyncio
output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html", output_html = asyncio.run(
title="Example Cognee Network", create_cognee_style_network_with_logo(
node_attribute="group", # Attribute to use for coloring nodes G,
layout_func=nx.spring_layout, # Layout function output_filename="example_network.html",
layout_scale=3.0, # Scale for the layout title="Example Cognee Network",
logo_alpha=0.2, )) node_attribute="group", # Attribute to use for coloring nodes
layout_func=nx.spring_layout, # Layout function
layout_scale=3.0, # Scale for the layout
logo_alpha=0.2,
)
)
# Call the function # Call the function
# output_html = await create_cognee_style_network_with_logo( # output_html = await create_cognee_style_network_with_logo(

View file

@ -20,7 +20,7 @@ async def add_data_points(data_points: list[DataPoint], only_root=False):
added_nodes=added_nodes, added_nodes=added_nodes,
added_edges=added_edges, added_edges=added_edges,
visited_properties=visited_properties, visited_properties=visited_properties,
only_root = only_root, only_root=only_root,
) )
for data_point in data_points for data_point in data_points
] ]

View file

@ -102,6 +102,7 @@ def test_prepare_nodes():
assert isinstance(nodes_df, pd.DataFrame) assert isinstance(nodes_df, pd.DataFrame)
assert len(nodes_df) == 1 assert len(nodes_df) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_cognee_style_network_with_logo(): async def test_create_cognee_style_network_with_logo():
import networkx as nx import networkx as nx

View file

@ -191,7 +191,7 @@ async def main(enable_steps):
print(format_triplets(results)) print(format_triplets(results))
if __name__ == '__main__': if __name__ == "__main__":
setup_logging(logging.ERROR) setup_logging(logging.ERROR)
rebuild_kg = True rebuild_kg = True