From 95c12fbc1e4d95e1ae4bb266ef25fe2cb05910bf Mon Sep 17 00:00:00 2001 From: vasilije Date: Mon, 7 Apr 2025 20:33:34 +0200 Subject: [PATCH] Fix nodesets --- cognee/api/v1/add/add.py | 4 +- cognee/api/v1/cognify/cognify.py | 2 +- cognee/examples/layered_graph_db_example.py | 350 --------- cognee/examples/node_set_test.py | 146 ++++ .../simplified_layered_graph_example.py | 201 ------ cognee/tasks/node_set/apply_node_set.py | 493 +++++++++++-- enhanced_nodeset_visualization.py | 675 ++++++++++++++++++ 7 files changed, 1252 insertions(+), 619 deletions(-) delete mode 100644 cognee/examples/layered_graph_db_example.py create mode 100644 cognee/examples/node_set_test.py delete mode 100644 cognee/examples/simplified_layered_graph_example.py create mode 100644 enhanced_nodeset_visualization.py diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index 8bfc37ac9..9ca4d3c86 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -16,7 +16,7 @@ async def add( data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None, - NodeSet: Optional[List[str]] = None, + node_set: Optional[List[str]] = None, ): # Create tables for databases await create_relational_db_and_tables() @@ -37,7 +37,7 @@ async def add( if user is None: user = await get_default_user() - tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)] + tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, node_set)] dataset_id = uuid5(NAMESPACE_OID, dataset_name) pipeline = run_tasks( diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index e28b91f8f..0c4487648 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -140,7 +140,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's task_config={"batch_size": 10}, ), Task(add_data_points, task_config={"batch_size": 10}), - Task(apply_node_set, task_config={"batch_size": 10}), # Apply NodeSet values to DataPoints + Task(apply_node_set, task_config={"batch_size": 10}), # Apply NodeSet values and create set nodes ] return default_tasks diff --git a/cognee/examples/layered_graph_db_example.py b/cognee/examples/layered_graph_db_example.py deleted file mode 100644 index 1522fe040..000000000 --- a/cognee/examples/layered_graph_db_example.py +++ /dev/null @@ -1,350 +0,0 @@ -""" -Example demonstrating how to use LayeredKnowledgeGraphDP with database adapters. - -This example shows how to: -1. Create a layered knowledge graph -2. Set a database adapter -3. Add nodes, edges, and layers with automatic persistence to the database -4. 
Retrieve graph data from the database -""" - -import asyncio -import uuid -import logging -import json -from uuid import UUID - -from cognee.modules.graph.datapoint_layered_graph import ( - GraphNode, - GraphEdge, - GraphLayer, - LayeredKnowledgeGraphDP, -) -from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter -from cognee.infrastructure.databases.graph import get_graph_engine - -# Set up logging -logging.basicConfig( - level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - - -async def retrieve_graph_manually(graph_id, adapter): - """Retrieve a graph manually from the NetworkX adapter.""" - graph_db = adapter._graph_db - if not hasattr(graph_db, "graph") or not graph_db.graph: - await graph_db.load_graph_from_file() - - graph_id_str = str(graph_id) - logger.info(f"Looking for graph with ID: {graph_id_str}") - - if hasattr(graph_db, "graph") and graph_db.graph.has_node(graph_id_str): - # Get the graph node data - graph_data = graph_db.graph.nodes[graph_id_str] - logger.info(f"Found graph node data: {graph_data}") - - # Create the graph instance - graph = LayeredKnowledgeGraphDP( - id=graph_id, - name=graph_data.get("name", ""), - description=graph_data.get("description", ""), - metadata=graph_data.get("metadata", {}), - ) - - # Set the adapter - graph.set_adapter(adapter) - - # Find and add all layers, nodes, and edges - nx_graph = graph_db.graph - - # Find all layers for this graph - logger.info("Finding layers connected to the graph") - found_layers = set() - for source, target, key in nx_graph.edges(graph_id_str, keys=True): - if key == "CONTAINS_LAYER": - # Found layer - layer_data = nx_graph.nodes[target] - layer_id_str = target - logger.info(f"Found layer: {layer_id_str} with data: {layer_data}") - - # Convert parent layers - parent_layers = [] - if "parent_layers" in layer_data: - try: - if isinstance(layer_data["parent_layers"], str): - import json - - parent_layers = [ - UUID(p) for p in json.loads(layer_data["parent_layers"]) - ] - elif isinstance(layer_data["parent_layers"], list): - parent_layers = [UUID(str(p)) for p in layer_data["parent_layers"]] - except Exception as e: - logger.error(f"Error processing parent layers: {e}") - parent_layers = [] - - # Create layer - try: - layer = GraphLayer( - id=UUID(layer_id_str), - name=layer_data.get("name", ""), - description=layer_data.get("description", ""), - layer_type=layer_data.get("layer_type", "default"), - parent_layers=parent_layers, - properties=layer_data.get("properties", {}), - metadata=layer_data.get("metadata", {}), - ) - graph.layers[layer.id] = layer - found_layers.add(layer_id_str) - except Exception as e: - logger.error(f"Error creating layer object: {e}") - - # Helper function to safely get UUID - def safe_uuid(value): - if isinstance(value, UUID): - return value - try: - return UUID(str(value)) - except Exception as e: - logger.error(f"Error converting to UUID: {value} - {e}") - return None - - # Find all nodes for this graph - logger.info("Finding nodes in the layers") - for node_id, node_data in nx_graph.nodes(data=True): - # First check if this is a node by its metadata - if node_data.get("metadata", {}).get("type") == "GraphNode": - # Get the layer ID from the node data - if "layer_id" in node_data: - layer_id_str = node_data["layer_id"] - # Check if this layer ID is in our found layers - try: - layer_id = safe_uuid(layer_id_str) - if layer_id and layer_id in graph.layers: - logger.info(f"Found node with ID 
{node_id} in layer {layer_id_str}") - - # Create the node - node = GraphNode( - id=safe_uuid(node_id), - name=node_data.get("name", ""), - node_type=node_data.get("node_type", ""), - description=node_data.get("description", ""), - properties=node_data.get("properties", {}), - layer_id=layer_id, - metadata=node_data.get("metadata", {}), - ) - graph.nodes[node.id] = node - graph.node_layer_map[node.id] = layer_id - except Exception as e: - logger.error(f"Error processing node {node_id}: {e}") - - # Find all edges for this graph - logger.info("Finding edges in the layers") - for node_id, node_data in nx_graph.nodes(data=True): - # First check if this is an edge by its metadata - if node_data.get("metadata", {}).get("type") == "GraphEdge": - # Get the layer ID from the edge data - if "layer_id" in node_data: - layer_id_str = node_data["layer_id"] - # Check if this layer ID is in our found layers - try: - layer_id = safe_uuid(layer_id_str) - if layer_id and layer_id in graph.layers: - source_id = safe_uuid(node_data.get("source_node_id")) - target_id = safe_uuid(node_data.get("target_node_id")) - - if source_id and target_id: - logger.info(f"Found edge with ID {node_id} in layer {layer_id_str}") - - # Create the edge - edge = GraphEdge( - id=safe_uuid(node_id), - source_node_id=source_id, - target_node_id=target_id, - relationship_name=node_data.get("relationship_name", ""), - properties=node_data.get("properties", {}), - layer_id=layer_id, - metadata=node_data.get("metadata", {}), - ) - graph.edges[edge.id] = edge - graph.edge_layer_map[edge.id] = layer_id - except Exception as e: - logger.error(f"Error processing edge {node_id}: {e}") - - logger.info( - f"Manually retrieved graph with {len(graph.layers)} layers, {len(graph.nodes)} nodes, and {len(graph.edges)} edges" - ) - return graph - else: - logger.error(f"Could not find graph with ID {graph_id}") - return None - - -def get_nodes_and_edges_from_graph(graph, layer_id): - """ - Get nodes and edges from a layer in the graph, avoiding database calls. 
- - Args: - graph: The LayeredKnowledgeGraphDP instance - layer_id: The UUID of the layer - - Returns: - Tuple of (nodes, edges) lists - """ - nodes = [node for node in graph.nodes.values() if node.layer_id == layer_id] - edges = [edge for edge in graph.edges.values() if edge.layer_id == layer_id] - return nodes, edges - - -async def main(): - logger.info("Starting layered graph database example") - - # Get the default graph engine (typically NetworkXAdapter) - graph_db = await get_graph_engine() - logger.info(f"Using graph database adapter: {type(graph_db).__name__}") - - # Create an adapter using the graph engine - adapter = LayeredGraphDBAdapter(graph_db) - - # Create a new empty graph - graph = LayeredKnowledgeGraphDP.create_empty( - name="Example Database Graph", - description="A graph that persists to the database", - metadata={ - "type": "LayeredKnowledgeGraph", - "index_fields": ["name"], - }, # Ensure proper metadata - ) - logger.info(f"Created graph with ID: {graph.id}") - - # Set the adapter for this graph - graph.set_adapter(adapter) - - # Create and add a base layer - base_layer = GraphLayer.create( - name="Base Layer", description="The foundation layer of the graph", layer_type="base" - ) - graph.add_layer(base_layer) - logger.info(f"Added base layer with ID: {base_layer.id}") - - # Create and add a derived layer that extends the base layer - derived_layer = GraphLayer.create( - name="Derived Layer", - description="A layer that extends the base layer", - layer_type="derived", - parent_layers=[base_layer.id], - ) - graph.add_layer(derived_layer) - logger.info(f"Added derived layer with ID: {derived_layer.id}") - - # Create and add nodes to the base layer - node1 = GraphNode.create( - name="Concept A", node_type="concept", description="A foundational concept" - ) - graph.add_node(node1, base_layer.id) - logger.info(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}") - - node2 = GraphNode.create( - name="Concept B", node_type="concept", description="Another foundational concept" - ) - graph.add_node(node2, base_layer.id) - logger.info(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}") - - # Create and add a node to the derived layer - node3 = GraphNode.create( - name="Derived Concept", - node_type="concept", - description="A concept derived from foundational concepts", - ) - graph.add_node(node3, derived_layer.id) - logger.info(f"Added node3 with ID: {node3.id} to layer: {derived_layer.id}") - - # Create and add edges between nodes - edge1 = GraphEdge.create( - source_node_id=node1.id, target_node_id=node2.id, relationship_name="RELATES_TO" - ) - graph.add_edge(edge1, base_layer.id) - logger.info(f"Added edge1 with ID: {edge1.id} between {node1.id} and {node2.id}") - - edge2 = GraphEdge.create( - source_node_id=node1.id, target_node_id=node3.id, relationship_name="EXPANDS_TO" - ) - graph.add_edge(edge2, derived_layer.id) - logger.info(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}") - - edge3 = GraphEdge.create( - source_node_id=node2.id, target_node_id=node3.id, relationship_name="CONTRIBUTES_TO" - ) - graph.add_edge(edge3, derived_layer.id) - logger.info(f"Added edge3 with ID: {edge3.id} between {node2.id} and {node3.id}") - - # Save the graph state to a file for NetworkXAdapter - if hasattr(graph_db, "save_graph_to_file"): - logger.info(f"Saving graph to file: {getattr(graph_db, 'filename', 'unknown')}") - await graph_db.save_graph_to_file() - - # Persist the entire graph to the database - # This is optional since the graph is already 
being persisted incrementally - # when add_layer, add_node, and add_edge are called - logger.info("Persisting entire graph to database") - graph_id = await graph.persist() - logger.info(f"Graph persisted with ID: {graph_id}") - - # Check if the graph exists in the database - if hasattr(graph_db, "graph"): - logger.info(f"Checking if graph exists in memory: {graph_db.graph.has_node(str(graph.id))}") - - # Check the node data - if graph_db.graph.has_node(str(graph.id)): - node_data = graph_db.graph.nodes[str(graph.id)] - logger.info(f"Graph node data: {node_data}") - - # List all nodes in the graph - logger.info("Nodes in the graph:") - for node_id, node_data in graph_db.graph.nodes(data=True): - logger.info( - f" Node {node_id}: type={node_data.get('metadata', {}).get('type', 'unknown')}" - ) - - # Try to retrieve the graph using our from_database method - retrieved_graph = None - try: - logger.info(f"Retrieving graph from database with ID: {graph.id}") - retrieved_graph = await LayeredKnowledgeGraphDP.from_database(graph.id, adapter) - logger.info(f"Retrieved graph: {retrieved_graph}") - logger.info( - f"Retrieved {len(retrieved_graph.layers)} layers, {len(retrieved_graph.nodes)} nodes, and {len(retrieved_graph.edges)} edges" - ) - except ValueError as e: - logger.error(f"Error retrieving graph using from_database: {str(e)}") - - # Try using manual retrieval as a fallback - logger.info("Trying manual retrieval as a fallback") - retrieved_graph = await retrieve_graph_manually(graph.id, adapter) - - if retrieved_graph: - logger.info(f"Successfully retrieved graph manually: {retrieved_graph}") - else: - logger.error("Failed to retrieve graph manually") - return - - # Use our helper function to get nodes and edges - if retrieved_graph: - # Get nodes in the base layer - base_nodes, base_edges = get_nodes_and_edges_from_graph(retrieved_graph, base_layer.id) - logger.info(f"Nodes in base layer: {[node.name for node in base_nodes]}") - logger.info(f"Edges in base layer: {[edge.relationship_name for edge in base_edges]}") - - # Get nodes in the derived layer - derived_nodes, derived_edges = get_nodes_and_edges_from_graph( - retrieved_graph, derived_layer.id - ) - logger.info(f"Nodes in derived layer: {[node.name for node in derived_nodes]}") - logger.info(f"Edges in derived layer: {[edge.relationship_name for edge in derived_edges]}") - else: - logger.error("No graph was retrieved, cannot display nodes and edges") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/examples/node_set_test.py b/cognee/examples/node_set_test.py new file mode 100644 index 000000000..986a27e1e --- /dev/null +++ b/cognee/examples/node_set_test.py @@ -0,0 +1,146 @@ +import os +import uuid +import logging +import json +from typing import List + +import httpx +import asyncio +import cognee +from cognee.api.v1.search import SearchType +from cognee.shared.logging_utils import get_logger, ERROR + +# Replace incorrect import with proper cognee.prune module +# from cognee.infrastructure.data_storage import reset_cognee + +# Remove incorrect imports and use high-level API +# from cognee.services.cognify import cognify +# from cognee.services.search import search +from cognee.tasks.node_set.apply_node_set import apply_node_set +from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Data + +# Configure logging to see detailed output +logging.basicConfig(level=logging.INFO) +logger 
= logging.getLogger(__name__)
+
+# httpx client with no timeout, kept around for debugging HTTP calls
+debug_http_client = httpx.Client(timeout=None)
+
+
+async def generate_test_node_ids(count: int) -> List[str]:
+    """Generate unique node IDs for testing."""
+    return [str(uuid.uuid4()) for _ in range(count)]
+
+
+async def add_text_with_nodeset(text: str, node_set: List[str], text_id: str = None) -> str:
+    """Add text with NodeSet to cognee."""
+    # Generate text ID if not provided - but note we can't directly set the ID
+    if text_id is None:
+        text_id = str(uuid.uuid4())
+
+    # Print NodeSet details
+    logger.info("Adding text with NodeSet")
+    logger.info(f"NodeSet for this text: {node_set}")
+
+    # Use high-level cognee.add() with the correct parameters
+    await cognee.add(text, node_set=node_set)
+    logger.info("Saved text with NodeSet to database")
+
+    return text_id  # Note: we can't control the actual ID generated by cognee.add()
+
+
+async def check_data_records():
+    """Check for data records in the database to verify NodeSet storage."""
+    db_engine = get_relational_engine()
+    async with db_engine.get_async_session() as session:
+        from sqlalchemy import select
+
+        query = select(Data)
+        result = await session.execute(query)
+        records = result.scalars().all()
+
+        logger.info(f"Found {len(records)} records in the database")
+        for record in records:
+            logger.info(f"Record ID: {record.id}, name: {record.name}, node_set: {record.node_set}")
+            # Print raw_data_location to see where the text content is stored
+            logger.info(f"Raw data location: {record.raw_data_location}")
+
+
+async def run_simple_node_set_test():
+    """Run a simple test of NodeSet functionality."""
+    try:
+        # Reset cognee data using correct module
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+        logger.info("Cognee data reset complete")
+
+        # Build two NodeSets: one with hand-picked tag values, one with generated IDs
+        # nodeset1 = await generate_test_node_ids(3)
+        nodeset2 = await generate_test_node_ids(3)
+        nodeset1 = ["test", "horse", "mamamama"]
+
+        logger.info(f"Values for NodeSet 1: {nodeset1}")
+        logger.info(f"Values for NodeSet 2: {nodeset2}")
+
+        # Add two texts with NodeSets
+        text1 = "Natural Language Processing (NLP) is a field of artificial intelligence focused on enabling computers to understand and process human language. It involves techniques for language analysis, translation, and generation."
+        text2 = "Artificial Intelligence (AI) systems are designed to perform tasks that typically require human intelligence. These tasks include speech recognition, decision-making, and language translation."
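+        # The node_set tags passed to cognee.add() below are persisted on the
+        # relational Data record (Data.node_set); during cognify, the
+        # apply_node_set task copies them onto the resulting DataPoints and
+        # creates one NodeSet graph node per tag value.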
+ + text1_id = await add_text_with_nodeset(text1, nodeset1) + text2_id = await add_text_with_nodeset(text2, nodeset2) + + # Verify that the NodeSets were stored correctly + await check_data_records() + + # Run the cognify process to create a knowledge graph + pipeline_run_id = await cognee.cognify() + logger.info(f"Cognify process completed with pipeline run ID: {pipeline_run_id}") + + # Skip graph search as the NetworkXAdapter doesn't support Cypher + logger.info("Skipping graph search as NetworkXAdapter doesn't support Cypher") + + # Search for insights related to the added texts + search_query = "NLP and AI" + logger.info(f"Searching for insights with query: '{search_query}'") + search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=search_query) + + # Extract NodeSet information from search results + logger.info(f"Found {len(search_results)} search results for '{search_query}':") + for i, result in enumerate(search_results): + logger.info(f"Result {i+1} text: {getattr(result, 'text', 'No text')}") + + # Check for NodeSet and SetNodeId + node_set = getattr(result, "NodeSet", None) + set_node_id = getattr(result, "SetNodeId", None) + logger.info(f"Result {i+1} - NodeSet: {node_set}, SetNodeId: {set_node_id}") + + # Check id and type + logger.info(f"Result {i+1} - ID: {getattr(result, 'id', 'No ID')}, Type: {getattr(result, 'type', 'No type')}") + + # Check if this is a document chunk and has is_part_of property + if hasattr(result, "is_part_of") and result.is_part_of: + logger.info(f"Result {i+1} is a document chunk with parent ID: {result.is_part_of.id}") + + # Check if the parent has a NodeSet + parent_has_nodeset = hasattr(result.is_part_of, "NodeSet") and result.is_part_of.NodeSet + logger.info(f" Parent has NodeSet: {parent_has_nodeset}") + if parent_has_nodeset: + logger.info(f" Parent NodeSet: {result.is_part_of.NodeSet}") + + # Print all attributes of the result to see what's available + logger.info(f"Result {i+1} - All attributes: {dir(result)}") + + except Exception as e: + logger.error(f"Error in simple NodeSet test: {e}") + raise + + +if __name__ == "__main__": + logger = get_logger(level=ERROR) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(run_simple_node_set_test()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) \ No newline at end of file diff --git a/cognee/examples/simplified_layered_graph_example.py b/cognee/examples/simplified_layered_graph_example.py deleted file mode 100644 index 9ffb3644b..000000000 --- a/cognee/examples/simplified_layered_graph_example.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Example demonstrating how to use the simplified LayeredKnowledgeGraph with database adapters. - -This example shows how to: -1. Create a layered knowledge graph -2. Add nodes, edges, and layers -3. Retrieve layer data -4. 
Work with cumulative layers -""" - -import asyncio -import logging -import uuid -from uuid import UUID -import os - -from cognee.modules.graph.simplified_layered_graph import ( - LayeredKnowledgeGraph, - GraphNode, - GraphEdge, - GraphLayer, -) -from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter -from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter -from cognee.infrastructure.databases.graph import get_graph_engine - -# Set up logging -logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") -logger = logging.getLogger(__name__) - - -async def main(): - print("Starting simplified layered graph example") - - # Initialize file path for the NetworkXAdapter - db_dir = os.path.join(os.path.expanduser("~"), "cognee/cognee/.cognee_system/databases") - os.makedirs(db_dir, exist_ok=True) - db_file = os.path.join(db_dir, "cognee_graph.pkl") - - # Use NetworkXAdapter for the graph database - adapter = NetworkXAdapter(filename=db_file) - - # Initialize the adapter by creating or loading the graph - if not os.path.exists(db_file): - await adapter.create_empty_graph(db_file) - await adapter.load_graph_from_file() - - print(f"Using graph database adapter: {adapter.__class__.__name__}") - - # Create an empty graph - graph = LayeredKnowledgeGraph.create_empty("Test Graph") - graph.set_adapter(LayeredGraphDBAdapter(adapter)) - print(f"Created graph with ID: {graph.id}") - - # Create layers - base_layer = await graph.add_layer( - name="Base Layer", description="The foundation layer with base concepts", layer_type="base" - ) - print(f"Added base layer with ID: {base_layer.id}") - - intermediate_layer = await graph.add_layer( - name="Intermediate Layer", - description="Layer that builds on base concepts", - layer_type="intermediate", - parent_layers=[base_layer.id], - ) - print(f"Added intermediate layer with ID: {intermediate_layer.id}") - - derived_layer = await graph.add_layer( - name="Derived Layer", - description="Final layer with derived concepts", - layer_type="derived", - parent_layers=[intermediate_layer.id], - ) - print(f"Added derived layer with ID: {derived_layer.id}") - - # Add nodes to layers - node1 = await graph.add_node( - name="Base Concept A", - node_type="concept", - properties={"importance": "high"}, - metadata={"source": "example"}, - layer_id=base_layer.id, - ) - print(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}") - - node2 = await graph.add_node( - name="Base Concept B", - node_type="concept", - properties={"importance": "medium"}, - metadata={"source": "example"}, - layer_id=base_layer.id, - ) - print(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}") - - node3 = await graph.add_node( - name="Intermediate Concept", - node_type="concept", - properties={"derived_from": ["Base Concept A"]}, - metadata={"source": "example"}, - layer_id=intermediate_layer.id, - ) - print(f"Added node3 with ID: {node3.id} to layer: {intermediate_layer.id}") - - node4 = await graph.add_node( - name="Derived Concept", - node_type="concept", - properties={"derived_from": ["Intermediate Concept"]}, - metadata={"source": "example"}, - layer_id=derived_layer.id, - ) - print(f"Added node4 with ID: {node4.id} to layer: {derived_layer.id}") - - # Add edges between nodes - edge1 = await graph.add_edge( - source_id=node1.id, - target_id=node2.id, - edge_type="RELATES_TO", - properties={"strength": "high"}, - metadata={"source": "example"}, - layer_id=base_layer.id, - ) - print(f"Added edge1 with ID: 
{edge1.id} between {node1.id} and {node2.id}") - - edge2 = await graph.add_edge( - source_id=node1.id, - target_id=node3.id, - edge_type="SUPPORTS", - properties={"confidence": 0.9}, - metadata={"source": "example"}, - layer_id=intermediate_layer.id, - ) - print(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}") - - edge3 = await graph.add_edge( - source_id=node3.id, - target_id=node4.id, - edge_type="EXTENDS", - properties={"confidence": 0.8}, - metadata={"source": "example"}, - layer_id=derived_layer.id, - ) - print(f"Added edge3 with ID: {edge3.id} between {node3.id} and {node4.id}") - - # Save the graph to the database - # The graph is automatically saved when nodes and edges are added, - # but for NetworkXAdapter we'll save the file explicitly - if hasattr(adapter, "save_graph_to_file"): - await adapter.save_graph_to_file(adapter.filename) - print(f"Saving graph to file: {adapter.filename}") - - # Retrieve all layers - layers = await graph.get_layers() - print(f"Retrieved {len(layers)} layers") - - # Load the graph from the database - print(f"Loading graph with ID: {graph.id} from database") - # Create a new graph instance from the database - loaded_graph = LayeredKnowledgeGraph(id=graph.id, name="Test Graph") - loaded_graph.set_adapter(LayeredGraphDBAdapter(adapter)) - # Load layers, which will also load nodes and edges - loaded_layers = await loaded_graph.get_layers() - print(f"Successfully loaded graph: {loaded_graph} with {len(loaded_layers)} layers") - - # Display contents of each layer - print("\n===== Individual Layer Contents =====") - for layer in layers: - # Get nodes and edges in the layer - nodes = await graph.get_nodes_in_layer(layer.id) - edges = await graph.get_edges_in_layer(layer.id) - - # Print summary - print(f"Nodes in {layer.name.lower()} layer: {[node.name for node in nodes]}") - print(f"Edges in {layer.name.lower()} layer: {[edge.edge_type for edge in edges]}") - - # Display cumulative layer views - print("\n===== Cumulative Layer Views =====") - - # Intermediate layer - should include base layer nodes/edges - print("\nCumulative layer graph for intermediate layer:") - int_nodes, int_edges = await graph.get_cumulative_layer_graph(intermediate_layer.id) - print(f"Intermediate cumulative nodes: {[node.name for node in int_nodes]}") - print(f"Intermediate cumulative edges: {[edge.edge_type for edge in int_edges]}") - - # Derived layer - should include all nodes/edges - print("\nCumulative layer graph for derived layer:") - derived_nodes, derived_edges = await graph.get_cumulative_layer_graph(derived_layer.id) - print(f"Derived cumulative nodes: {[node.name for node in derived_nodes]}") - print(f"Derived cumulative edges: {[edge.edge_type for edge in derived_edges]}") - - # Test helper methods - print("\n===== Helper Method Results =====") - base_nodes = await graph.get_nodes_in_layer(base_layer.id) - base_edges = await graph.get_edges_in_layer(base_layer.id) - print(f"Base layer contains {len(base_nodes)} nodes and {len(base_edges)} edges") - - print("Example complete") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tasks/node_set/apply_node_set.py b/cognee/tasks/node_set/apply_node_set.py index facb67d46..8ac8156ed 100644 --- a/cognee/tasks/node_set/apply_node_set.py +++ b/cognee/tasks/node_set/apply_node_set.py @@ -1,88 +1,451 @@ +import uuid import json import logging -from sqlalchemy import select -from typing import List, Any +from typing import List, Dict, Any, Optional, Union, Tuple, Sequence, Protocol, Callable 
+from contextlib import asynccontextmanager + +from cognee.shared.logging_utils import get_logger +from sqlalchemy.future import select +from sqlalchemy import or_ from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter from cognee.modules.data.models import Data from cognee.infrastructure.engine.models.DataPoint import DataPoint -logger = logging.getLogger(__name__) +# Configure logger +logger = get_logger(name="apply_node_set") -async def apply_node_set(data_points: List[DataPoint]) -> List[DataPoint]: - """ - Apply NodeSet values from the relational store to DataPoint instances. - - This task fetches the NodeSet values from the Data model in the relational database - and applies them to the corresponding DataPoint instances. - +async def apply_node_set(data: Union[DataPoint, List[DataPoint]]) -> Union[DataPoint, List[DataPoint]]: + """Apply NodeSet values to DataPoint objects. + Args: - data_points: List of DataPoint instances to process - + data: Single DataPoint or list of DataPoints to process + Returns: - List of updated DataPoint instances with NodeSet values applied + The processed DataPoint(s) with updated NodeSet values """ - logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints") + if not data: + logger.warning("No data provided to apply NodeSet values") + return data - if not data_points: + # Convert single DataPoint to list for uniform processing + data_points = data if isinstance(data, list) else [data] + logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints") + + # Process DataPoint objects to apply NodeSet values + updated_data_points = await _process_data_points(data_points) + + # Create set nodes for each NodeSet + await _create_set_nodes(updated_data_points) + + # Return data in the same format it was received + return data_points if isinstance(data, list) else data_points[0] + + +async def _process_data_points(data_points: List[DataPoint]) -> List[DataPoint]: + """Process DataPoint objects to apply NodeSet values from the database. 
+ + Args: + data_points: List of DataPoint objects to process + + Returns: + The processed list of DataPoints with updated NodeSet values + """ + try: + if not data_points: + return [] + + # Extract IDs and collect document relationships + data_point_ids, parent_doc_map = _collect_ids_and_relationships(data_points) + + # Get NodeSet values from database + nodeset_map = await _fetch_nodesets_from_database(data_point_ids, parent_doc_map) + + # Apply NodeSet values to DataPoints + _apply_nodesets_to_datapoints(data_points, nodeset_map, parent_doc_map) + + return data_points + + except Exception as e: + logger.error(f"Error processing DataPoints: {str(e)}") return data_points - # Create a map of data_point IDs for efficient lookup - data_point_map = {str(dp.id): dp for dp in data_points} - - # Get the database engine - db_engine = get_relational_engine() - - # Get session (handles both sync and async cases for testing) - session = db_engine.get_async_session() - - try: - # Handle both AsyncMock and actual async context manager for testing - if hasattr(session, "__aenter__"): - # It's a real async context manager - async with session as sess: - await _process_data_points(sess, data_point_map) - else: - # It's an AsyncMock in tests - await _process_data_points(session, data_point_map) - - except Exception as e: - logger.error(f"Error applying NodeSet values: {e}") - - return data_points - - -async def _process_data_points(session, data_point_map): - """ - Process data points with the given session. - - This helper function handles the actual database query and NodeSet application. +def _collect_ids_and_relationships(data_points: List[DataPoint]) -> Tuple[List[str], Dict[str, str]]: + """Extract DataPoint IDs and document relationships. + Args: - session: Database session - data_point_map: Map of data point IDs to DataPoint objects + data_points: List of DataPoint objects + + Returns: + Tuple containing: + - List of DataPoint IDs + - Dictionary mapping DataPoint IDs to parent document IDs """ - # Get all data points from the Data table that have node_set values - # and correspond to the data_points we're processing - data_ids = list(data_point_map.keys()) + data_point_ids = [] + parent_doc_ids = [] + parent_doc_map = {} + + # Collect all IDs to look up + for dp in data_points: + # Get the DataPoint ID + if hasattr(dp, "id"): + dp_id = str(dp.id) + data_point_ids.append(dp_id) + + # Check if there's a parent document to get NodeSet from + if (hasattr(dp, "made_from") and + hasattr(dp.made_from, "is_part_of") and + hasattr(dp.made_from.is_part_of, "id")): + + parent_id = str(dp.made_from.is_part_of.id) + parent_doc_ids.append(parent_id) + parent_doc_map[dp_id] = parent_id + + logger.info(f"Found {len(data_point_ids)} DataPoint IDs and {len(parent_doc_ids)} parent document IDs") + + # Combine all IDs for database lookup + return data_point_ids + parent_doc_ids, parent_doc_map - query = select(Data).where(Data.id.in_(data_ids), Data.node_set.is_not(None)) - result = await session.execute(query) - data_records = result.scalars().all() +async def _fetch_nodesets_from_database(ids: List[str], parent_doc_map: Dict[str, str]) -> Dict[str, Any]: + """Fetch NodeSet values from the database for the given IDs. 
+ + Args: + ids: List of IDs to search for + parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs + + Returns: + Dictionary mapping document IDs to their NodeSet values + """ + # Convert string IDs to UUIDs for database lookup + uuid_objects = _convert_ids_to_uuids(ids) + if not uuid_objects: + return {} + + # Query the database for NodeSet values + nodeset_map = {} + + db_engine = get_relational_engine() + async with db_engine.get_async_session() as sess: + # Query for records matching the IDs + query = select(Data).where(Data.id.in_(uuid_objects)) + result = await sess.execute(query) + records = result.scalars().all() + + logger.info(f"Found {len(records)} total records matching the IDs") + + # Extract NodeSet values from records + for record in records: + if record.node_set is not None: + nodeset_map[str(record.id)] = record.node_set + + logger.info(f"Found {len(nodeset_map)} records with non-NULL NodeSet values") + + return nodeset_map - # Apply NodeSet values to corresponding DataPoint instances - for data_record in data_records: - data_point_id = str(data_record.id) - if data_point_id in data_point_map and data_record.node_set: - # Parse the JSON string to get the NodeSet - try: - node_set = json.loads(data_record.node_set) - data_point_map[data_point_id].NodeSet = node_set - logger.debug(f"Applied NodeSet {node_set} to DataPoint {data_point_id}") - except json.JSONDecodeError: - logger.warning(f"Failed to parse NodeSet JSON for DataPoint {data_point_id}") + +def _convert_ids_to_uuids(ids: List[str]) -> List[uuid.UUID]: + """Convert string IDs to UUID objects. + + Args: + ids: List of string IDs to convert + + Returns: + List of UUID objects + """ + uuid_objects = [] + for id_str in ids: + try: + uuid_objects.append(uuid.UUID(id_str)) + except Exception as e: + logger.warning(f"Failed to convert ID {id_str} to UUID: {str(e)}") + + logger.info(f"Converted {len(uuid_objects)} out of {len(ids)} IDs to UUID objects") + return uuid_objects + + +def _apply_nodesets_to_datapoints( + data_points: List[DataPoint], + nodeset_map: Dict[str, Any], + parent_doc_map: Dict[str, str] +) -> None: + """Apply NodeSet values to DataPoints. + + Args: + data_points: List of DataPoint objects to update + nodeset_map: Dictionary mapping document IDs to their NodeSet values + parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs + """ + for dp in data_points: + dp_id = str(dp.id) + + # Try direct match first + if dp_id in nodeset_map: + nodeset = nodeset_map[dp_id] + logger.info(f"Found NodeSet for {dp_id}: {nodeset}") + dp.NodeSet = nodeset + + # Then try parent document + elif dp_id in parent_doc_map and parent_doc_map[dp_id] in nodeset_map: + parent_id = parent_doc_map[dp_id] + nodeset = nodeset_map[parent_id] + logger.info(f"Found NodeSet from parent document {parent_id} for {dp_id}: {nodeset}") + dp.NodeSet = nodeset + + +async def _create_set_nodes(data_points: List[DataPoint]) -> None: + """Create set nodes for DataPoints with NodeSets. 
+ + Args: + data_points: List of DataPoint objects to process + """ + try: + logger.info(f"Creating set nodes for {len(data_points)} DataPoints") + + for dp in data_points: + if not hasattr(dp, "NodeSet") or not dp.NodeSet: continue + + try: + # Create set nodes for the NodeSet (one per value) + document_id = str(dp.id) if hasattr(dp, "id") else None + set_node_ids, edge_ids = await create_set_node(dp.NodeSet, document_id=document_id) + + if set_node_ids and len(set_node_ids) > 0: + logger.info(f"Created {len(set_node_ids)} set nodes for NodeSet with {len(dp.NodeSet)} values") + + # Store the set node IDs with the DataPoint if possible + try: + # Store as JSON string if multiple IDs, or single ID if only one + if len(set_node_ids) > 1: + dp.SetNodeIds = json.dumps(set_node_ids) + else: + dp.SetNodeId = set_node_ids[0] + except Exception as e: + logger.warning(f"Failed to store set node IDs for NodeSet: {str(e)}") + else: + logger.warning("Failed to create set nodes for NodeSet") + except Exception as e: + logger.error(f"Error creating set nodes: {str(e)}") + except Exception as e: + logger.error(f"Error processing NodeSets: {str(e)}") - logger.info(f"Successfully applied NodeSet values to DataPoints") + +async def create_set_node( + nodes: Union[List[str], str], + name: Optional[str] = None, + document_id: Optional[str] = None +) -> Tuple[Optional[List[str]], List[str]]: + """Create individual nodes for each value in the NodeSet. + + Args: + nodes: List of node values or JSON string representing node values + name: Base name for the NodeSet (optional) + document_id: ID of the document containing the NodeSet (optional) + + Returns: + Tuple containing: + - List of created set node IDs (or None if creation failed) + - List of created edge IDs + """ + try: + if not nodes: + logger.warning("No nodes provided to create set nodes") + return None, [] + + # Get the graph engine + graph_engine = await get_graph_engine() + + # Parse nodes if provided as JSON string + nodes_list = _parse_nodes_input(nodes) + if not nodes_list: + return None, [] + + # Base name for the set if not provided + base_name = name or f"NodeSet_{uuid.uuid4().hex[:8]}" + logger.info(f"Creating individual set nodes for {len(nodes_list)} values with base name '{base_name}'") + + # Create set nodes using the appropriate strategy + return await _create_set_nodes_unified(graph_engine, nodes_list, base_name, document_id) + + except Exception as e: + logger.error(f"Failed to create set nodes: {str(e)}") + return None, [] + + +def _parse_nodes_input(nodes: Union[List[str], str]) -> Optional[List[str]]: + """Parse the nodes input. + + Args: + nodes: List of node values or JSON string representing node values + + Returns: + List of node values or None if parsing failed + """ + if isinstance(nodes, str): + try: + parsed_nodes = json.loads(nodes) + logger.info(f"Parsed nodes string into list with {len(parsed_nodes)} items") + return parsed_nodes + except Exception as e: + logger.error(f"Failed to parse nodes as JSON: {str(e)}") + return None + return nodes + + +async def _create_set_nodes_unified( + graph_engine: Any, + nodes_list: List[str], + base_name: str, + document_id: Optional[str] +) -> Tuple[List[str], List[str]]: + """Create set nodes using either NetworkX or generic graph engine. 
+ + Args: + graph_engine: The graph engine instance + nodes_list: List of node values + base_name: Base name for the NodeSet + document_id: ID of the document containing the NodeSet (optional) + + Returns: + Tuple containing: + - List of created set node IDs + - List of created edge IDs + """ + all_set_node_ids = [] + all_edge_ids = [] + + # Define strategies for node and edge creation based on graph engine type + if isinstance(graph_engine, NetworkXAdapter): + # NetworkX-specific strategy + async def create_node(node_value: str) -> str: + set_node_id = str(uuid.uuid4()) + node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}" + + graph_engine.graph.add_node( + set_node_id, + id=set_node_id, + type="NodeSet", + name=node_name, + node_id=node_value + ) + + # Validate node creation + if set_node_id in graph_engine.graph.nodes(): + node_props = dict(graph_engine.graph.nodes[set_node_id]) + logger.info(f"Created set node for value '{node_value}': {json.dumps(node_props)}") + else: + logger.warning(f"Node {set_node_id} not found in graph after adding") + + return set_node_id + + async def create_value_edge(set_node_id: str, node_value: str) -> List[str]: + edge_ids = [] + try: + edge_id = str(uuid.uuid4()) + graph_engine.graph.add_edge( + set_node_id, + node_value, + id=edge_id, + type="CONTAINS" + ) + edge_ids.append(edge_id) + except Exception as e: + logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}") + return edge_ids + + async def create_document_edge(document_id: str, set_node_id: str) -> List[str]: + edge_ids = [] + try: + doc_to_nodeset_id = str(uuid.uuid4()) + graph_engine.graph.add_edge( + document_id, + set_node_id, + id=doc_to_nodeset_id, + type="HAS_NODESET" + ) + edge_ids.append(doc_to_nodeset_id) + logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}") + except Exception as e: + logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}") + return edge_ids + + # Finalize function for NetworkX + async def finalize() -> None: + await graph_engine.save_graph_to_file(graph_engine.filename) + + else: + # Generic graph engine strategy + async def create_node(node_value: str) -> str: + node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}" + set_node_props = { + "name": node_name, + "type": "NodeSet", + "node_id": node_value + } + return await graph_engine.create_node(set_node_props) + + async def create_value_edge(set_node_id: str, node_value: str) -> List[str]: + edge_ids = [] + try: + edge_id = await graph_engine.create_edge( + source_id=set_node_id, + target_id=node_value, + edge_type="CONTAINS" + ) + edge_ids.append(edge_id) + except Exception as e: + logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}") + return edge_ids + + async def create_document_edge(document_id: str, set_node_id: str) -> List[str]: + edge_ids = [] + try: + doc_to_nodeset_id = await graph_engine.create_edge( + source_id=document_id, + target_id=set_node_id, + edge_type="HAS_NODESET" + ) + edge_ids.append(doc_to_nodeset_id) + logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}") + except Exception as e: + logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}") + return edge_ids + + # Finalize function for generic engine (no-op) + async def finalize() -> None: + pass + + # Unified process for both engine types + for node_value in nodes_list: + try: + # Create the node + set_node_id = await create_node(node_value) + + # Create edges to the value + 
value_edge_ids = await create_value_edge(set_node_id, node_value) + all_edge_ids.extend(value_edge_ids) + + # Create edges to the document if provided + if document_id: + doc_edge_ids = await create_document_edge(document_id, set_node_id) + all_edge_ids.extend(doc_edge_ids) + + all_set_node_ids.append(set_node_id) + except Exception as e: + logger.error(f"Failed to create set node for value '{node_value}': {str(e)}") + + # Finalize the process + await finalize() + + # Return results + if all_set_node_ids: + logger.info(f"Created {len(all_set_node_ids)} individual set nodes with values: {nodes_list}") + return all_set_node_ids, all_edge_ids + else: + logger.error("Failed to create any set nodes") + return [], [] diff --git a/enhanced_nodeset_visualization.py b/enhanced_nodeset_visualization.py new file mode 100644 index 000000000..f4bace409 --- /dev/null +++ b/enhanced_nodeset_visualization.py @@ -0,0 +1,675 @@ +import asyncio +import os +import json +from datetime import datetime +from typing import Dict, List, Tuple, Any, Optional, Set +import webbrowser +from pathlib import Path + +import cognee +from cognee.shared.logging_utils import get_logger + +# Configure logger +logger = get_logger(name="enhanced_graph_visualization") + +# Type aliases for clarity +NodeData = Dict[str, Any] +EdgeData = Dict[str, Any] +GraphData = Tuple[List[Tuple[Any, Dict]], List[Tuple[Any, Any, Optional[str], Dict]]] + + +class DateTimeEncoder(json.JSONEncoder): + """Custom JSON encoder to handle datetime objects.""" + + def default(self, obj): + if isinstance(obj, datetime): + return obj.isoformat() + return super().default(obj) + + +class NodeSetVisualizer: + """Class to create enhanced visualizations for NodeSet data in knowledge graphs.""" + + # Color mapping for different node types + NODE_COLORS = { + "Entity": "#f47710", + "EntityType": "#6510f4", + "DocumentChunk": "#801212", + "TextDocument": "#a83232", # Darker red for documents + "TextSummary": "#1077f4", + "NodeSet": "#ff00ff", # Bright magenta for NodeSet nodes + "Unknown": "#999999", + "default": "#D3D3D3", + } + + # Size mapping for different node types + NODE_SIZES = { + "NodeSet": 20, # Larger size for NodeSet nodes + "TextDocument": 18, # Larger size for document nodes + "DocumentChunk": 18, # Larger size for document nodes + "TextSummary": 16, # Medium size for TextSummary nodes + "default": 13, # Default size + } + + def __init__(self): + """Initialize the visualizer.""" + self.graph_engine = None + self.nodes_data = [] + self.edges_data = [] + self.node_count = 0 + self.edge_count = 0 + self.nodeset_count = 0 + + async def get_graph_data(self) -> bool: + """Fetch graph data from the graph engine. + + Returns: + bool: True if data was successfully retrieved, False otherwise. + """ + self.graph_engine = await cognee.infrastructure.databases.graph.get_graph_engine() + + # Check if the graph exists and has nodes + self.node_count = len(self.graph_engine.graph.nodes()) + self.edge_count = len(self.graph_engine.graph.edges()) + logger.info(f"Graph contains {self.node_count} nodes and {self.edge_count} edges") + print(f"Graph contains {self.node_count} nodes and {self.edge_count} edges") + + if self.node_count == 0: + logger.error("The graph is empty! Please run a test script first to generate data.") + print("ERROR: The graph is empty! 
Please run a test script first to generate data.") + return False + + graph_data = await self.graph_engine.get_graph_data() + self.nodes_data, self.edges_data = graph_data + + # Count NodeSets for status display + self.nodeset_count = sum(1 for _, info in self.nodes_data if info.get("type") == "NodeSet") + + return True + + def prepare_node_data(self) -> List[NodeData]: + """Process raw node data to prepare for visualization. + + Returns: + List[NodeData]: List of prepared node data objects. + """ + nodes_list = [] + + # Create a lookup for node types for faster access + node_type_lookup = {str(node_id): node_info.get("type", "Unknown") + for node_id, node_info in self.nodes_data} + + for node_id, node_info in self.nodes_data: + # Create a clean copy to avoid modifying the original + processed_node = node_info.copy() + + # Remove fields that cause JSON serialization issues + self._clean_node_data(processed_node) + + # Add required visualization properties + processed_node["id"] = str(node_id) + node_type = processed_node.get("type", "default") + + # Apply visual styling based on node type + processed_node["color"] = self.NODE_COLORS.get(node_type, self.NODE_COLORS["default"]) + processed_node["size"] = self.NODE_SIZES.get(node_type, self.NODE_SIZES["default"]) + + # Create display names + self._format_node_display_name(processed_node, node_type) + + nodes_list.append(processed_node) + + return nodes_list + + @staticmethod + def _clean_node_data(node: NodeData) -> None: + """Remove fields that might cause JSON serialization issues. + + Args: + node: The node data to clean + """ + # Remove non-essential fields that might cause serialization issues + for key in ["created_at", "updated_at", "raw_data_location"]: + if key in node: + del node[key] + + @staticmethod + def _format_node_display_name(node: NodeData, node_type: str) -> None: + """Format the display name for a node. + + Args: + node: The node data to process + node_type: The type of the node + """ + # Set a default name if none exists + node["name"] = node.get("name", node.get("id", "Unknown")) + + # Special formatting for NodeSet nodes + if node_type == "NodeSet" and "node_id" in node: + node["display_name"] = f"NodeSet: {node['node_id']}" + else: + node["display_name"] = node["name"] + + # Truncate long display names + if len(node["display_name"]) > 30: + node["display_name"] = f"{node['display_name'][:27]}..." + + def prepare_edge_data(self, nodes_list: List[NodeData]) -> List[EdgeData]: + """Process raw edge data to prepare for visualization. + + Args: + nodes_list: The processed node data + + Returns: + List[EdgeData]: List of prepared edge data objects. 
+ """ + links_list = [] + + # Create a lookup for node types for faster access + node_type_lookup = {node["id"]: node.get("type", "Unknown") for node in nodes_list} + + for source, target, relation, edge_info in self.edges_data: + source_str = str(source) + target_str = str(target) + + # Skip if source or target not in node_type_lookup (should not happen) + if source_str not in node_type_lookup or target_str not in node_type_lookup: + continue + + # Get node types + source_type = node_type_lookup[source_str] + target_type = node_type_lookup[target_str] + + # Create edge data + link_data = { + "source": source_str, + "target": target_str, + "relation": relation or "UNKNOWN" + } + + # Categorize the edge for styling + link_data["connection_type"] = self._categorize_edge(source_type, target_type) + + links_list.append(link_data) + + return links_list + + @staticmethod + def _categorize_edge(source_type: str, target_type: str) -> str: + """Categorize an edge based on the connected node types. + + Args: + source_type: The type of the source node + target_type: The type of the target node + + Returns: + str: The category of the edge + """ + if source_type == "NodeSet" and target_type != "NodeSet": + return "nodeset_to_value" + elif (source_type in ["TextDocument", "DocumentChunk", "TextSummary"]) and target_type == "NodeSet": + return "document_to_nodeset" + elif target_type == "NodeSet": + return "to_nodeset" + elif source_type == "NodeSet": + return "from_nodeset" + else: + return "standard" + + def generate_html(self, nodes_list: List[NodeData], links_list: List[EdgeData]) -> str: + """Generate the HTML visualization with D3.js. + + Args: + nodes_list: The processed node data + links_list: The processed edge data + + Returns: + str: The HTML content for the visualization + """ + # Use embedded template directly - more reliable than file access + html_template = self._get_embedded_html_template() + + # Generate the HTML content with custom JSON encoder for datetime objects + html_content = html_template.replace("{nodes}", json.dumps(nodes_list, cls=DateTimeEncoder)) + html_content = html_content.replace("{links}", json.dumps(links_list, cls=DateTimeEncoder)) + html_content = html_content.replace("{node_count}", str(self.node_count)) + html_content = html_content.replace("{edge_count}", str(self.edge_count)) + html_content = html_content.replace("{nodeset_count}", str(self.nodeset_count)) + + return html_content + + def save_html(self, html_content: str) -> str: + """Save the HTML content to a file and open it in the browser. + + Args: + html_content: The HTML content to save + + Returns: + str: The path to the saved file + """ + # Create the output file path + output_path = Path.cwd() / "enhanced_nodeset_visualization.html" + + # Write the HTML content to the file + with open(output_path, "w") as f: + f.write(html_content) + + logger.info(f"Enhanced visualization saved to: {output_path}") + print(f"Enhanced visualization saved to: {output_path}") + + # Open the visualization in the default web browser + file_url = f"file://{output_path}" + logger.info(f"Opening visualization in browser: {file_url}") + print(f"Opening enhanced visualization in browser: {file_url}") + webbrowser.open(file_url) + + return str(output_path) + + @staticmethod + def _get_embedded_html_template() -> str: + """Get the embedded HTML template as a fallback. + + Returns: + str: The HTML template + """ + return """ + + + + + Cognee NodeSet Visualization + + + + + + + +
+        <body>
+        <!-- Legend (node types): NodeSet, TextDocument, DocumentChunk,
+             TextSummary, Entity, EntityType, Unknown -->
+        <!-- Legend (edge types): Document → NodeSet, NodeSet → Value,
+             Any → NodeSet, Standard Connection -->
+        <div id="controls"></div>
+        <div id="graph"></div>
+        <div id="stats">
+            Nodes: {node_count} | Edges: {edge_count} | NodeSets: {nodeset_count}
+        </div>
+        <script>
+            // Graph data injected by generate_html(); rendered as a D3.js
+            // force-directed layout.
+            const nodes = {nodes};
+            const links = {links};
+        </script>
+        </body>
+        </html>
+        """
+
+    async def create_visualization(self) -> Optional[str]:
+        """Main method to create the visualization.
+
+        Returns:
+            Optional[str]: Path to the saved visualization file, or None if unsuccessful
+        """
+        print("Creating enhanced NodeSet visualization...")
+
+        # Get graph data
+        if not await self.get_graph_data():
+            return None
+
+        try:
+            # Process nodes
+            nodes_list = self.prepare_node_data()
+
+            # Process edges
+            links_list = self.prepare_edge_data(nodes_list)
+
+            # Generate HTML
+            html_content = self.generate_html(nodes_list, links_list)
+
+            # Save to file and open in browser
+            return self.save_html(html_content)
+
+        except Exception as e:
+            logger.error(f"Error creating visualization: {str(e)}")
+            print(f"Error creating visualization: {str(e)}")
+            return None
+
+
+async def main():
+    """Main entry point for the script."""
+    visualizer = NodeSetVisualizer()
+    await visualizer.create_visualization()
+
+
+if __name__ == "__main__":
+    # Run the async main function
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.close()
\ No newline at end of file
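
After this patch, the end-to-end NodeSet flow can be exercised with a short script like the sketch below (a minimal illustration assuming a default cognee setup; the sample text, tag values, and the `demo` function name are made up):

    import asyncio

    import cognee
    from enhanced_nodeset_visualization import NodeSetVisualizer


    async def demo():
        # Tags passed here are stored on the relational Data record (Data.node_set).
        await cognee.add("Some text about NLP.", node_set=["nlp", "demo"])

        # The default cognify pipeline ends with apply_node_set, which copies the
        # stored tags onto the generated DataPoints and creates one NodeSet graph
        # node per tag value.
        await cognee.cognify()

        # Render the graph; NodeSet nodes are highlighted in the HTML output.
        await NodeSetVisualizer().create_visualization()


    asyncio.run(demo())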