Fix visualization
This commit is contained in:
parent
0ff9ffa11b
commit
41b1486cff
9 changed files with 38 additions and 41 deletions
|
|
@ -24,7 +24,7 @@ from cognee.api.v1.users.routers import (
|
|||
get_reset_password_router,
|
||||
get_verify_router,
|
||||
get_users_router,
|
||||
get_visualize_router
|
||||
get_visualize_router,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_visualize_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
import logging
|
||||
|
||||
|
||||
|
||||
async def visualize_graph(label: str = "name"):
|
||||
""" """
|
||||
graph_engine = await get_graph_engine()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
|
|||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
import litellm
|
||||
|
||||
|
||||
class GenericAPIAdapter(LLMInterface):
|
||||
"""Adapter for Generic API LLM provider API"""
|
||||
|
||||
|
|
@ -29,7 +30,6 @@ class GenericAPIAdapter(LLMInterface):
|
|||
else:
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
|
||||
|
||||
async def acreate_structured_output(
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
) -> BaseModel:
|
||||
|
|
@ -50,7 +50,5 @@ class GenericAPIAdapter(LLMInterface):
|
|||
],
|
||||
max_retries=5,
|
||||
api_base=self.endpoint,
|
||||
response_model=response_model
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -305,7 +305,6 @@ async def convert_to_serializable_graph(G):
|
|||
return new_G
|
||||
|
||||
|
||||
|
||||
def generate_layout_positions(G, layout_func, layout_scale):
|
||||
"""
|
||||
Generate layout positions for the graph using the specified layout function.
|
||||
|
|
@ -389,7 +388,6 @@ async def create_cognee_style_network_with_logo(
|
|||
logging.info("Converting graph to serializable format...")
|
||||
G = await convert_to_serializable_graph(G)
|
||||
|
||||
|
||||
logging.info("Generating layout positions...")
|
||||
layout_positions = generate_layout_positions(G, layout_func, layout_scale)
|
||||
|
||||
|
|
@ -420,7 +418,6 @@ async def create_cognee_style_network_with_logo(
|
|||
embed_logo(p, layout_scale, logo_alpha, "bottom_right")
|
||||
embed_logo(p, layout_scale, logo_alpha, "top_left")
|
||||
|
||||
|
||||
logging.info("Styling and rendering graph...")
|
||||
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
|
||||
|
||||
|
|
@ -434,7 +431,6 @@ async def create_cognee_style_network_with_logo(
|
|||
)
|
||||
p.add_tools(hover_tool)
|
||||
|
||||
|
||||
logging.info(f"Saving visualization to {output_filename}...")
|
||||
html_content = file_html(p, CDN, title)
|
||||
with open(output_filename, "w") as f:
|
||||
|
|
@ -442,8 +438,6 @@ async def create_cognee_style_network_with_logo(
|
|||
|
||||
logging.info("Visualization complete.")
|
||||
|
||||
|
||||
|
||||
if bokeh_object:
|
||||
return p
|
||||
return html_content
|
||||
|
|
@ -461,7 +455,6 @@ def graph_to_tuple(graph):
|
|||
return (nodes, edges)
|
||||
|
||||
|
||||
|
||||
def setup_logging(log_level=logging.INFO):
|
||||
"""This method sets up the logging configuration."""
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n")
|
||||
|
|
@ -498,12 +491,17 @@ if __name__ == "__main__":
|
|||
|
||||
import asyncio
|
||||
|
||||
output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html",
|
||||
output_html = asyncio.run(
|
||||
create_cognee_style_network_with_logo(
|
||||
G,
|
||||
output_filename="example_network.html",
|
||||
title="Example Cognee Network",
|
||||
node_attribute="group", # Attribute to use for coloring nodes
|
||||
layout_func=nx.spring_layout, # Layout function
|
||||
layout_scale=3.0, # Scale for the layout
|
||||
logo_alpha=0.2, ))
|
||||
logo_alpha=0.2,
|
||||
)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
# output_html = await create_cognee_style_network_with_logo(
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ def test_prepare_nodes():
|
|||
assert isinstance(nodes_df, pd.DataFrame)
|
||||
assert len(nodes_df) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_cognee_style_network_with_logo():
|
||||
import networkx as nx
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ async def main(enable_steps):
|
|||
print(format_triplets(results))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
setup_logging(logging.ERROR)
|
||||
|
||||
rebuild_kg = True
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue