import asyncio
import os
import json
import re
import webbrowser
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Set

import cognee
from cognee.shared.logging_utils import get_logger

# Configure logger
logger = get_logger(name="enhanced_graph_visualization")

# Type aliases for clarity
NodeData = Dict[str, Any]
EdgeData = Dict[str, Any]
GraphData = Tuple[List[Tuple[Any, Dict]], List[Tuple[Any, Any, Optional[str], Dict]]]


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``datetime`` values as ISO-8601 strings."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


class NodeSetVisualizer:
    """Create an HTML visualization of the knowledge graph, highlighting NodeSet nodes."""

    # Color mapping for different node types
    NODE_COLORS = {
        "Entity": "#f47710",
        "EntityType": "#6510f4",
        "DocumentChunk": "#801212",
        "TextDocument": "#a83232",  # Darker red for documents
        "TextSummary": "#1077f4",
        "NodeSet": "#ff00ff",  # Bright magenta for NodeSet nodes
        "Unknown": "#999999",
        "default": "#D3D3D3",
    }

    # Size mapping for different node types
    NODE_SIZES = {
        "NodeSet": 20,  # Larger size for NodeSet nodes
        "TextDocument": 18,  # Larger size for document nodes
        "DocumentChunk": 18,  # Larger size for document nodes
        "TextSummary": 16,  # Medium size for TextSummary nodes
        "default": 13,  # Default size
    }

    # Template placeholders filled in by generate_html. They are matched in one
    # regex pass so substituted data is never re-scanned for further placeholders.
    _PLACEHOLDER_PATTERN = re.compile(r"\{(nodes|links|node_count|edge_count|nodeset_count)\}")

    def __init__(self):
        """Initialize the visualizer with empty graph state; no I/O happens here."""
        self.graph_engine = None
        self.nodes_data = []
        self.edges_data = []
        self.node_count = 0
        self.edge_count = 0
        self.nodeset_count = 0

    async def get_graph_data(self) -> bool:
        """Fetch graph data from the graph engine.

        Populates ``nodes_data``, ``edges_data`` and the node/edge/NodeSet counters.

        Returns:
            bool: True if data was successfully retrieved, False if the graph is empty.
        """
        self.graph_engine = await cognee.infrastructure.databases.graph.get_graph_engine()

        # NOTE(review): reading ``.graph.nodes()`` / ``.graph.edges()`` assumes the
        # engine exposes an in-memory graph object with these methods — TODO confirm
        # this holds for every configured graph backend.
        self.node_count = len(self.graph_engine.graph.nodes())
        self.edge_count = len(self.graph_engine.graph.edges())

        logger.info(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
        print(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")

        if self.node_count == 0:
            logger.error("The graph is empty! Please run a test script first to generate data.")
            print("ERROR: The graph is empty! Please run a test script first to generate data.")
            return False

        graph_data = await self.graph_engine.get_graph_data()
        self.nodes_data, self.edges_data = graph_data

        # Count NodeSets for status display
        self.nodeset_count = sum(1 for _, info in self.nodes_data if info.get("type") == "NodeSet")

        return True

    def prepare_node_data(self) -> List[NodeData]:
        """Process raw node data to prepare for visualization.

        Each node is shallow-copied, stripped of fields that do not serialize,
        and given ``id``, ``color``, ``size``, ``name`` and ``display_name``.

        Returns:
            List[NodeData]: List of prepared node data objects.
        """
        nodes_list = []

        for node_id, node_info in self.nodes_data:
            # Create a clean copy to avoid modifying the original
            processed_node = node_info.copy()

            # Remove fields that cause JSON serialization issues
            self._clean_node_data(processed_node)

            # Add required visualization properties
            processed_node["id"] = str(node_id)
            node_type = processed_node.get("type", "default")

            # Apply visual styling based on node type
            processed_node["color"] = self.NODE_COLORS.get(node_type, self.NODE_COLORS["default"])
            processed_node["size"] = self.NODE_SIZES.get(node_type, self.NODE_SIZES["default"])

            # Create display names
            self._format_node_display_name(processed_node, node_type)

            nodes_list.append(processed_node)

        return nodes_list

    @staticmethod
    def _clean_node_data(node: NodeData) -> None:
        """Remove, in place, fields that might cause JSON serialization issues.

        Args:
            node: The node data to clean
        """
        for key in ("created_at", "updated_at", "raw_data_location"):
            node.pop(key, None)

    @staticmethod
    def _format_node_display_name(node: NodeData, node_type: str) -> None:
        """Set ``name`` and a (possibly truncated) ``display_name`` on the node, in place.

        Args:
            node: The node data to process; expected to already carry an ``id``.
            node_type: The type of the node
        """
        # Fall back to the id when the name is missing *or* explicitly None;
        # a None name previously crashed the len() check below.
        name = node.get("name")
        if name is None:
            name = node.get("id", "Unknown")
        node["name"] = name

        # Special formatting for NodeSet nodes
        if node_type == "NodeSet" and "node_id" in node:
            display_name = f"NodeSet: {node['node_id']}"
        else:
            # str() so non-string names (e.g. numbers) can be measured and truncated
            display_name = str(name)

        # Truncate long display names
        if len(display_name) > 30:
            display_name = f"{display_name[:27]}..."
        node["display_name"] = display_name

    def prepare_edge_data(self, nodes_list: List[NodeData]) -> List[EdgeData]:
        """Process raw edge data to prepare for visualization.

        Args:
            nodes_list: The processed node data

        Returns:
            List[EdgeData]: Edges whose endpoints both appear in ``nodes_list``, each
            with ``source``, ``target``, ``relation`` and ``connection_type``.
        """
        links_list = []

        # Create a lookup for node types for faster access
        node_type_lookup = {node["id"]: node.get("type", "Unknown") for node in nodes_list}

        for source, target, relation, _edge_info in self.edges_data:
            source_str = str(source)
            target_str = str(target)

            # Skip edges whose endpoints were not among the prepared nodes
            if source_str not in node_type_lookup or target_str not in node_type_lookup:
                continue

            link_data = {
                "source": source_str,
                "target": target_str,
                "relation": relation or "UNKNOWN",
            }

            # Categorize the edge for styling
            link_data["connection_type"] = self._categorize_edge(
                node_type_lookup[source_str], node_type_lookup[target_str]
            )

            links_list.append(link_data)

        return links_list

    @staticmethod
    def _categorize_edge(source_type: str, target_type: str) -> str:
        """Categorize an edge based on the connected node types.

        Args:
            source_type: The type of the source node
            target_type: The type of the target node

        Returns:
            str: The category of the edge
        """
        if source_type == "NodeSet" and target_type != "NodeSet":
            return "nodeset_to_value"
        elif (source_type in ["TextDocument", "DocumentChunk", "TextSummary"]) and target_type == "NodeSet":
            return "document_to_nodeset"
        elif target_type == "NodeSet":
            return "to_nodeset"
        elif source_type == "NodeSet":
            return "from_nodeset"
        else:
            return "standard"

    @staticmethod
    def _to_script_safe_json(data: Any) -> str:
        """Serialize ``data`` to JSON that can be embedded in HTML without breaking out.

        ``<``, ``>`` and ``&`` can only appear inside JSON string literals, so
        rewriting them as ``\\u003c`` / ``\\u003e`` / ``\\u0026`` yields an equivalent
        JSON value while stopping graph text such as ``</script>`` from closing an
        enclosing script element. (Assumes the placeholders are meant to sit inside
        a ``<script>`` block — TODO confirm against the intended template.)
        """
        serialized = json.dumps(data, cls=DateTimeEncoder)
        return serialized.replace("<", "\\u003c").replace(">", "\\u003e").replace("&", "\\u0026")

    def generate_html(self, nodes_list: List[NodeData], links_list: List[EdgeData]) -> str:
        """Generate the HTML visualization with D3.js.

        Args:
            nodes_list: The processed node data
            links_list: The processed edge data

        Returns:
            str: The HTML content for the visualization
        """
        # Use embedded template directly - more reliable than file access
        html_template = self._get_embedded_html_template()

        replacements = {
            "nodes": self._to_script_safe_json(nodes_list),
            "links": self._to_script_safe_json(links_list),
            "node_count": str(self.node_count),
            "edge_count": str(self.edge_count),
            "nodeset_count": str(self.nodeset_count),
        }

        # One pass over the template: a node name containing "{links}" or
        # "{node_count}" is no longer rewritten by a later substitution. A callable
        # replacement is used so backslashes in the JSON are inserted literally.
        return self._PLACEHOLDER_PATTERN.sub(lambda match: replacements[match.group(1)], html_template)

    def save_html(self, html_content: str) -> str:
        """Save the HTML content to a file in the current directory and open it in the browser.

        Args:
            html_content: The HTML content to save

        Returns:
            str: The path to the saved file
        """
        output_path = Path.cwd() / "enhanced_nodeset_visualization.html"

        # Explicit UTF-8: the platform default encoding may not represent every
        # character in the page (e.g. the arrows in the legend).
        output_path.write_text(html_content, encoding="utf-8")

        logger.info(f"Enhanced visualization saved to: {output_path}")
        print(f"Enhanced visualization saved to: {output_path}")

        # as_uri() builds a correct file URL on every platform (drive letters,
        # backslashes, spaces); Path.cwd() is absolute, which as_uri() requires.
        file_url = output_path.as_uri()
        logger.info(f"Opening visualization in browser: {file_url}")
        print(f"Opening enhanced visualization in browser: {file_url}")
        webbrowser.open(file_url)

        return str(output_path)

    @staticmethod
    def _get_embedded_html_template() -> str:
        """Return the HTML template that generate_html fills in.

        Recognized placeholders are ``{nodes}``, ``{links}``, ``{node_count}``,
        ``{edge_count}`` and ``{nodeset_count}``.

        NOTE(review): the template text below contains no ``{nodes}``/``{links}``
        placeholders and no HTML/D3 markup, so the graph data is not emitted into
        the page — the markup appears to be missing; TODO restore it.

        Returns:
            str: The HTML template
        """
        return """ Cognee NodeSet Visualization

Node Types

NodeSet
TextDocument
DocumentChunk
TextSummary
Entity
EntityType
Unknown

Edge Types

Document → NodeSet
NodeSet → Value
Any → NodeSet
Standard Connection
Nodes: {node_count}
Edges: {edge_count}
NodeSets: {nodeset_count}
"""

    async def create_visualization(self) -> Optional[str]:
        """Fetch, process, render, save and open the visualization.

        Returns:
            Optional[str]: Path to the saved visualization file, or None if the
            graph is empty or any processing step raised.
        """
        print("Creating enhanced NodeSet visualization...")

        # Get graph data
        if not await self.get_graph_data():
            return None

        try:
            nodes_list = self.prepare_node_data()
            links_list = self.prepare_edge_data(nodes_list)
            html_content = self.generate_html(nodes_list, links_list)

            # Save to file and open in browser
            return self.save_html(html_content)
        except Exception as e:
            # Top-level boundary for the script: report and signal failure with None.
            logger.error(f"Error creating visualization: {str(e)}")
            print(f"Error creating visualization: {str(e)}")
            return None


async def main():
    """Main entry point for the script."""
    visualizer = NodeSetVisualizer()
    await visualizer.create_visualization()


if __name__ == "__main__":
    # asyncio.run creates, runs and properly closes the event loop
    # (including async-generator and executor shutdown).
    asyncio.run(main())