Fix visualization

This commit is contained in:
vasilije 2025-01-08 13:13:52 +01:00
parent 0ff9ffa11b
commit 41b1486cff
9 changed files with 38 additions and 41 deletions

View file

@ -24,7 +24,7 @@ from cognee.api.v1.users.routers import (
get_reset_password_router,
get_verify_router,
get_users_router,
get_visualize_router
get_visualize_router,
)
from contextlib import asynccontextmanager

View file

@ -110,7 +110,7 @@ async def run_cognify_pipeline(
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
),
Task(add_data_points, only_root = True, task_config = { "batch_size": 10 }),
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
]
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")

View file

@ -11,6 +11,7 @@ from cognee.modules.users.methods import get_authenticated_user
logger = logging.getLogger(__name__)
def get_visualize_router() -> APIRouter:
router = APIRouter()

View file

@ -3,8 +3,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
import logging
async def visualize_graph(label:str="name"):
async def visualize_graph(label: str = "name"):
""" """
graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data()

View file

@ -8,6 +8,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
import litellm
class GenericAPIAdapter(LLMInterface):
"""Adapter for Generic API LLM provider API"""
@ -29,7 +30,6 @@ class GenericAPIAdapter(LLMInterface):
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -49,8 +49,6 @@ class GenericAPIAdapter(LLMInterface):
},
],
max_retries=5,
api_base = self.endpoint,
response_model=response_model
api_base=self.endpoint,
response_model=response_model,
)

View file

@ -305,7 +305,6 @@ async def convert_to_serializable_graph(G):
return new_G
def generate_layout_positions(G, layout_func, layout_scale):
"""
Generate layout positions for the graph using the specified layout function.
@ -381,7 +380,7 @@ async def create_cognee_style_network_with_logo(
layout_func=nx.spring_layout,
layout_scale=3.0,
logo_alpha=0.1,
bokeh_object = False,
bokeh_object=False,
):
"""
Create a Cognee-inspired network visualization with an embedded logo.
@ -389,7 +388,6 @@ async def create_cognee_style_network_with_logo(
logging.info("Converting graph to serializable format...")
G = await convert_to_serializable_graph(G)
logging.info("Generating layout positions...")
layout_positions = generate_layout_positions(G, layout_func, layout_scale)
@ -420,7 +418,6 @@ async def create_cognee_style_network_with_logo(
embed_logo(p, layout_scale, logo_alpha, "bottom_right")
embed_logo(p, layout_scale, logo_alpha, "top_left")
logging.info("Styling and rendering graph...")
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
@ -434,7 +431,6 @@ async def create_cognee_style_network_with_logo(
)
p.add_tools(hover_tool)
logging.info(f"Saving visualization to {output_filename}...")
html_content = file_html(p, CDN, title)
with open(output_filename, "w") as f:
@ -442,8 +438,6 @@ async def create_cognee_style_network_with_logo(
logging.info("Visualization complete.")
if bokeh_object:
return p
return html_content
@ -461,9 +455,8 @@ def graph_to_tuple(graph):
return (nodes, edges)
def setup_logging(log_level=logging.INFO):
""" This method sets up the logging configuration. """
"""This method sets up the logging configuration."""
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n")
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter)
@ -498,12 +491,17 @@ if __name__ == "__main__":
import asyncio
output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html",
output_html = asyncio.run(
create_cognee_style_network_with_logo(
G,
output_filename="example_network.html",
title="Example Cognee Network",
node_attribute="group", # Attribute to use for coloring nodes
layout_func=nx.spring_layout, # Layout function
layout_scale=3.0, # Scale for the layout
logo_alpha=0.2, ))
logo_alpha=0.2,
)
)
# Call the function
# output_html = await create_cognee_style_network_with_logo(

View file

@ -20,7 +20,7 @@ async def add_data_points(data_points: list[DataPoint], only_root=False):
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
only_root = only_root,
only_root=only_root,
)
for data_point in data_points
]

View file

@ -102,6 +102,7 @@ def test_prepare_nodes():
assert isinstance(nodes_df, pd.DataFrame)
assert len(nodes_df) == 1
@pytest.mark.asyncio
async def test_create_cognee_style_network_with_logo():
import networkx as nx

View file

@ -191,7 +191,7 @@ async def main(enable_steps):
print(format_triplets(results))
if __name__ == '__main__':
if __name__ == "__main__":
setup_logging(logging.ERROR)
rebuild_kg = True