Fix visualization
This commit is contained in:
parent
0ff9ffa11b
commit
41b1486cff
9 changed files with 38 additions and 41 deletions
|
|
@ -24,7 +24,7 @@ from cognee.api.v1.users.routers import (
|
||||||
get_reset_password_router,
|
get_reset_password_router,
|
||||||
get_verify_router,
|
get_verify_router,
|
||||||
get_users_router,
|
get_users_router,
|
||||||
get_visualize_router
|
get_visualize_router,
|
||||||
)
|
)
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,7 @@ async def run_cognify_pipeline(
|
||||||
summarization_model=cognee_config.summarization_model,
|
summarization_model=cognee_config.summarization_model,
|
||||||
task_config={"batch_size": 10},
|
task_config={"batch_size": 10},
|
||||||
),
|
),
|
||||||
Task(add_data_points, only_root = True, task_config = { "batch_size": 10 }),
|
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
|
||||||
]
|
]
|
||||||
|
|
||||||
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
|
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
|
||||||
|
|
|
||||||
|
|
@ -11,21 +11,22 @@ from cognee.modules.users.methods import get_authenticated_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_visualize_router() -> APIRouter:
|
def get_visualize_router() -> APIRouter:
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.post("/", response_model=None)
|
@router.post("/", response_model=None)
|
||||||
async def visualize(
|
async def visualize(
|
||||||
user: User = Depends(get_authenticated_user),
|
user: User = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
"""This endpoint is responsible for adding data to the graph."""
|
"""This endpoint is responsible for adding data to the graph."""
|
||||||
from cognee.api.v1.visualize import visualize_graph
|
from cognee.api.v1.visualize import visualize_graph
|
||||||
|
|
||||||
try:
|
try:
|
||||||
html_visualization = await visualize_graph()
|
html_visualization = await visualize_graph()
|
||||||
return html_visualization
|
return html_visualization
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
async def visualize_graph(label: str = "name"):
|
||||||
async def visualize_graph(label:str="name"):
|
|
||||||
""" """
|
""" """
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
graph_data = await graph_engine.get_graph_data()
|
graph_data = await graph_engine.get_graph_data()
|
||||||
|
|
@ -12,4 +11,4 @@ async def visualize_graph(label:str="name"):
|
||||||
|
|
||||||
graph = await create_cognee_style_network_with_logo(graph_data, label=label)
|
graph = await create_cognee_style_network_with_logo(graph_data, label=label)
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
"""Adapter for Generic API LLM provider API"""
|
"""Adapter for Generic API LLM provider API"""
|
||||||
|
|
||||||
|
|
@ -27,8 +28,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
|
|
||||||
|
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
|
@ -49,8 +49,6 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
max_retries=5,
|
max_retries=5,
|
||||||
api_base = self.endpoint,
|
api_base=self.endpoint,
|
||||||
response_model=response_model
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -305,7 +305,6 @@ async def convert_to_serializable_graph(G):
|
||||||
return new_G
|
return new_G
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_layout_positions(G, layout_func, layout_scale):
|
def generate_layout_positions(G, layout_func, layout_scale):
|
||||||
"""
|
"""
|
||||||
Generate layout positions for the graph using the specified layout function.
|
Generate layout positions for the graph using the specified layout function.
|
||||||
|
|
@ -381,7 +380,7 @@ async def create_cognee_style_network_with_logo(
|
||||||
layout_func=nx.spring_layout,
|
layout_func=nx.spring_layout,
|
||||||
layout_scale=3.0,
|
layout_scale=3.0,
|
||||||
logo_alpha=0.1,
|
logo_alpha=0.1,
|
||||||
bokeh_object = False,
|
bokeh_object=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a Cognee-inspired network visualization with an embedded logo.
|
Create a Cognee-inspired network visualization with an embedded logo.
|
||||||
|
|
@ -389,7 +388,6 @@ async def create_cognee_style_network_with_logo(
|
||||||
logging.info("Converting graph to serializable format...")
|
logging.info("Converting graph to serializable format...")
|
||||||
G = await convert_to_serializable_graph(G)
|
G = await convert_to_serializable_graph(G)
|
||||||
|
|
||||||
|
|
||||||
logging.info("Generating layout positions...")
|
logging.info("Generating layout positions...")
|
||||||
layout_positions = generate_layout_positions(G, layout_func, layout_scale)
|
layout_positions = generate_layout_positions(G, layout_func, layout_scale)
|
||||||
|
|
||||||
|
|
@ -420,7 +418,6 @@ async def create_cognee_style_network_with_logo(
|
||||||
embed_logo(p, layout_scale, logo_alpha, "bottom_right")
|
embed_logo(p, layout_scale, logo_alpha, "bottom_right")
|
||||||
embed_logo(p, layout_scale, logo_alpha, "top_left")
|
embed_logo(p, layout_scale, logo_alpha, "top_left")
|
||||||
|
|
||||||
|
|
||||||
logging.info("Styling and rendering graph...")
|
logging.info("Styling and rendering graph...")
|
||||||
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
|
style_and_render_graph(p, G, layout_positions, label, node_colors, centrality)
|
||||||
|
|
||||||
|
|
@ -434,7 +431,6 @@ async def create_cognee_style_network_with_logo(
|
||||||
)
|
)
|
||||||
p.add_tools(hover_tool)
|
p.add_tools(hover_tool)
|
||||||
|
|
||||||
|
|
||||||
logging.info(f"Saving visualization to {output_filename}...")
|
logging.info(f"Saving visualization to {output_filename}...")
|
||||||
html_content = file_html(p, CDN, title)
|
html_content = file_html(p, CDN, title)
|
||||||
with open(output_filename, "w") as f:
|
with open(output_filename, "w") as f:
|
||||||
|
|
@ -442,8 +438,6 @@ async def create_cognee_style_network_with_logo(
|
||||||
|
|
||||||
logging.info("Visualization complete.")
|
logging.info("Visualization complete.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if bokeh_object:
|
if bokeh_object:
|
||||||
return p
|
return p
|
||||||
return html_content
|
return html_content
|
||||||
|
|
@ -461,9 +455,8 @@ def graph_to_tuple(graph):
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(log_level=logging.INFO):
|
def setup_logging(log_level=logging.INFO):
|
||||||
""" This method sets up the logging configuration. """
|
"""This method sets up the logging configuration."""
|
||||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n")
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n")
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
stream_handler.setFormatter(formatter)
|
stream_handler.setFormatter(formatter)
|
||||||
|
|
@ -498,12 +491,17 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
output_html = asyncio.run(create_cognee_style_network_with_logo(G, output_filename="example_network.html",
|
output_html = asyncio.run(
|
||||||
title="Example Cognee Network",
|
create_cognee_style_network_with_logo(
|
||||||
node_attribute="group", # Attribute to use for coloring nodes
|
G,
|
||||||
layout_func=nx.spring_layout, # Layout function
|
output_filename="example_network.html",
|
||||||
layout_scale=3.0, # Scale for the layout
|
title="Example Cognee Network",
|
||||||
logo_alpha=0.2, ))
|
node_attribute="group", # Attribute to use for coloring nodes
|
||||||
|
layout_func=nx.spring_layout, # Layout function
|
||||||
|
layout_scale=3.0, # Scale for the layout
|
||||||
|
logo_alpha=0.2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Call the function
|
# Call the function
|
||||||
# output_html = await create_cognee_style_network_with_logo(
|
# output_html = await create_cognee_style_network_with_logo(
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ async def add_data_points(data_points: list[DataPoint], only_root=False):
|
||||||
added_nodes=added_nodes,
|
added_nodes=added_nodes,
|
||||||
added_edges=added_edges,
|
added_edges=added_edges,
|
||||||
visited_properties=visited_properties,
|
visited_properties=visited_properties,
|
||||||
only_root = only_root,
|
only_root=only_root,
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,7 @@ def test_prepare_nodes():
|
||||||
assert isinstance(nodes_df, pd.DataFrame)
|
assert isinstance(nodes_df, pd.DataFrame)
|
||||||
assert len(nodes_df) == 1
|
assert len(nodes_df) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_cognee_style_network_with_logo():
|
async def test_create_cognee_style_network_with_logo():
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ async def main(enable_steps):
|
||||||
print(format_triplets(results))
|
print(format_triplets(results))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
setup_logging(logging.ERROR)
|
setup_logging(logging.ERROR)
|
||||||
|
|
||||||
rebuild_kg = True
|
rebuild_kg = True
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue