Fix nodesets

vasilije 2025-04-07 20:33:34 +02:00
parent 2355d1bfea
commit 95c12fbc1e
7 changed files with 1252 additions and 619 deletions

View file

@@ -16,7 +16,7 @@ async def add(
     data: Union[BinaryIO, list[BinaryIO], str, list[str]],
     dataset_name: str = "main_dataset",
     user: User = None,
-    NodeSet: Optional[List[str]] = None,
+    node_set: Optional[List[str]] = None,
 ):
     # Create tables for databases
     await create_relational_db_and_tables()
@@ -37,7 +37,7 @@ async def add(
     if user is None:
         user = await get_default_user()
 
-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, node_set)]
 
     dataset_id = uuid5(NAMESPACE_OID, dataset_name)
     pipeline = run_tasks(
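
With this rename, callers pass node sets through the snake_case keyword. A minimal sketch of the updated call, using only the public API shown above (the text and set values are illustrative):

import asyncio
import cognee

async def main():
    # node_set is persisted on the Data record at ingestion time and later
    # applied to the corresponding DataPoints by the apply_node_set task.
    await cognee.add(
        "Natural Language Processing is a field of AI.",
        dataset_name="main_dataset",
        node_set=["nlp", "research"],
    )

asyncio.run(main())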

View file

@@ -140,7 +140,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             task_config={"batch_size": 10},
         ),
         Task(add_data_points, task_config={"batch_size": 10}),
-        Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values to DataPoints
+        Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values and create set nodes
     ]
 
     return default_tasks
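
Because the task is registered with task_config={"batch_size": 10}, apply_node_set receives DataPoints in chunks of up to ten rather than the whole collection at once. A minimal sketch of that batching contract, assuming a task that takes and returns a list (run_batched is illustrative, not a cognee API):

from typing import Awaitable, Callable, List, TypeVar

T = TypeVar("T")

async def run_batched(
    task: Callable[[List[T]], Awaitable[List[T]]],
    items: List[T],
    batch_size: int = 10,
) -> List[T]:
    # One call per chunk, mirroring how the pipeline feeds a Task configured
    # with task_config={"batch_size": 10}; output order matches input order.
    results: List[T] = []
    for start in range(0, len(items), batch_size):
        results.extend(await task(items[start : start + batch_size]))
    return results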

View file

@@ -1,350 +0,0 @@
"""
Example demonstrating how to use LayeredKnowledgeGraphDP with database adapters.
This example shows how to:
1. Create a layered knowledge graph
2. Set a database adapter
3. Add nodes, edges, and layers with automatic persistence to the database
4. Retrieve graph data from the database
"""
import asyncio
import uuid
import logging
import json
from uuid import UUID
from cognee.modules.graph.datapoint_layered_graph import (
GraphNode,
GraphEdge,
GraphLayer,
LayeredKnowledgeGraphDP,
)
from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter
from cognee.infrastructure.databases.graph import get_graph_engine
# Set up logging
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
async def retrieve_graph_manually(graph_id, adapter):
"""Retrieve a graph manually from the NetworkX adapter."""
graph_db = adapter._graph_db
if not hasattr(graph_db, "graph") or not graph_db.graph:
await graph_db.load_graph_from_file()
graph_id_str = str(graph_id)
logger.info(f"Looking for graph with ID: {graph_id_str}")
if hasattr(graph_db, "graph") and graph_db.graph.has_node(graph_id_str):
# Get the graph node data
graph_data = graph_db.graph.nodes[graph_id_str]
logger.info(f"Found graph node data: {graph_data}")
# Create the graph instance
graph = LayeredKnowledgeGraphDP(
id=graph_id,
name=graph_data.get("name", ""),
description=graph_data.get("description", ""),
metadata=graph_data.get("metadata", {}),
)
# Set the adapter
graph.set_adapter(adapter)
# Find and add all layers, nodes, and edges
nx_graph = graph_db.graph
# Find all layers for this graph
logger.info("Finding layers connected to the graph")
found_layers = set()
for source, target, key in nx_graph.edges(graph_id_str, keys=True):
if key == "CONTAINS_LAYER":
# Found layer
layer_data = nx_graph.nodes[target]
layer_id_str = target
logger.info(f"Found layer: {layer_id_str} with data: {layer_data}")
# Convert parent layers
parent_layers = []
if "parent_layers" in layer_data:
try:
if isinstance(layer_data["parent_layers"], str):
import json
parent_layers = [
UUID(p) for p in json.loads(layer_data["parent_layers"])
]
elif isinstance(layer_data["parent_layers"], list):
parent_layers = [UUID(str(p)) for p in layer_data["parent_layers"]]
except Exception as e:
logger.error(f"Error processing parent layers: {e}")
parent_layers = []
# Create layer
try:
layer = GraphLayer(
id=UUID(layer_id_str),
name=layer_data.get("name", ""),
description=layer_data.get("description", ""),
layer_type=layer_data.get("layer_type", "default"),
parent_layers=parent_layers,
properties=layer_data.get("properties", {}),
metadata=layer_data.get("metadata", {}),
)
graph.layers[layer.id] = layer
found_layers.add(layer_id_str)
except Exception as e:
logger.error(f"Error creating layer object: {e}")
# Helper function to safely get UUID
def safe_uuid(value):
if isinstance(value, UUID):
return value
try:
return UUID(str(value))
except Exception as e:
logger.error(f"Error converting to UUID: {value} - {e}")
return None
# Find all nodes for this graph
logger.info("Finding nodes in the layers")
for node_id, node_data in nx_graph.nodes(data=True):
# First check if this is a node by its metadata
if node_data.get("metadata", {}).get("type") == "GraphNode":
# Get the layer ID from the node data
if "layer_id" in node_data:
layer_id_str = node_data["layer_id"]
# Check if this layer ID is in our found layers
try:
layer_id = safe_uuid(layer_id_str)
if layer_id and layer_id in graph.layers:
logger.info(f"Found node with ID {node_id} in layer {layer_id_str}")
# Create the node
node = GraphNode(
id=safe_uuid(node_id),
name=node_data.get("name", ""),
node_type=node_data.get("node_type", ""),
description=node_data.get("description", ""),
properties=node_data.get("properties", {}),
layer_id=layer_id,
metadata=node_data.get("metadata", {}),
)
graph.nodes[node.id] = node
graph.node_layer_map[node.id] = layer_id
except Exception as e:
logger.error(f"Error processing node {node_id}: {e}")
# Find all edges for this graph
logger.info("Finding edges in the layers")
for node_id, node_data in nx_graph.nodes(data=True):
# First check if this is an edge by its metadata
if node_data.get("metadata", {}).get("type") == "GraphEdge":
# Get the layer ID from the edge data
if "layer_id" in node_data:
layer_id_str = node_data["layer_id"]
# Check if this layer ID is in our found layers
try:
layer_id = safe_uuid(layer_id_str)
if layer_id and layer_id in graph.layers:
source_id = safe_uuid(node_data.get("source_node_id"))
target_id = safe_uuid(node_data.get("target_node_id"))
if source_id and target_id:
logger.info(f"Found edge with ID {node_id} in layer {layer_id_str}")
# Create the edge
edge = GraphEdge(
id=safe_uuid(node_id),
source_node_id=source_id,
target_node_id=target_id,
relationship_name=node_data.get("relationship_name", ""),
properties=node_data.get("properties", {}),
layer_id=layer_id,
metadata=node_data.get("metadata", {}),
)
graph.edges[edge.id] = edge
graph.edge_layer_map[edge.id] = layer_id
except Exception as e:
logger.error(f"Error processing edge {node_id}: {e}")
logger.info(
f"Manually retrieved graph with {len(graph.layers)} layers, {len(graph.nodes)} nodes, and {len(graph.edges)} edges"
)
return graph
else:
logger.error(f"Could not find graph with ID {graph_id}")
return None
def get_nodes_and_edges_from_graph(graph, layer_id):
"""
Get nodes and edges from a layer in the graph, avoiding database calls.
Args:
graph: The LayeredKnowledgeGraphDP instance
layer_id: The UUID of the layer
Returns:
Tuple of (nodes, edges) lists
"""
nodes = [node for node in graph.nodes.values() if node.layer_id == layer_id]
edges = [edge for edge in graph.edges.values() if edge.layer_id == layer_id]
return nodes, edges
async def main():
logger.info("Starting layered graph database example")
# Get the default graph engine (typically NetworkXAdapter)
graph_db = await get_graph_engine()
logger.info(f"Using graph database adapter: {type(graph_db).__name__}")
# Create an adapter using the graph engine
adapter = LayeredGraphDBAdapter(graph_db)
# Create a new empty graph
graph = LayeredKnowledgeGraphDP.create_empty(
name="Example Database Graph",
description="A graph that persists to the database",
metadata={
"type": "LayeredKnowledgeGraph",
"index_fields": ["name"],
}, # Ensure proper metadata
)
logger.info(f"Created graph with ID: {graph.id}")
# Set the adapter for this graph
graph.set_adapter(adapter)
# Create and add a base layer
base_layer = GraphLayer.create(
name="Base Layer", description="The foundation layer of the graph", layer_type="base"
)
graph.add_layer(base_layer)
logger.info(f"Added base layer with ID: {base_layer.id}")
# Create and add a derived layer that extends the base layer
derived_layer = GraphLayer.create(
name="Derived Layer",
description="A layer that extends the base layer",
layer_type="derived",
parent_layers=[base_layer.id],
)
graph.add_layer(derived_layer)
logger.info(f"Added derived layer with ID: {derived_layer.id}")
# Create and add nodes to the base layer
node1 = GraphNode.create(
name="Concept A", node_type="concept", description="A foundational concept"
)
graph.add_node(node1, base_layer.id)
logger.info(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}")
node2 = GraphNode.create(
name="Concept B", node_type="concept", description="Another foundational concept"
)
graph.add_node(node2, base_layer.id)
logger.info(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}")
# Create and add a node to the derived layer
node3 = GraphNode.create(
name="Derived Concept",
node_type="concept",
description="A concept derived from foundational concepts",
)
graph.add_node(node3, derived_layer.id)
logger.info(f"Added node3 with ID: {node3.id} to layer: {derived_layer.id}")
# Create and add edges between nodes
edge1 = GraphEdge.create(
source_node_id=node1.id, target_node_id=node2.id, relationship_name="RELATES_TO"
)
graph.add_edge(edge1, base_layer.id)
logger.info(f"Added edge1 with ID: {edge1.id} between {node1.id} and {node2.id}")
edge2 = GraphEdge.create(
source_node_id=node1.id, target_node_id=node3.id, relationship_name="EXPANDS_TO"
)
graph.add_edge(edge2, derived_layer.id)
logger.info(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}")
edge3 = GraphEdge.create(
source_node_id=node2.id, target_node_id=node3.id, relationship_name="CONTRIBUTES_TO"
)
graph.add_edge(edge3, derived_layer.id)
logger.info(f"Added edge3 with ID: {edge3.id} between {node2.id} and {node3.id}")
# Save the graph state to a file for NetworkXAdapter
if hasattr(graph_db, "save_graph_to_file"):
logger.info(f"Saving graph to file: {getattr(graph_db, 'filename', 'unknown')}")
await graph_db.save_graph_to_file()
# Persist the entire graph to the database
# This is optional since the graph is already being persisted incrementally
# when add_layer, add_node, and add_edge are called
logger.info("Persisting entire graph to database")
graph_id = await graph.persist()
logger.info(f"Graph persisted with ID: {graph_id}")
# Check if the graph exists in the database
if hasattr(graph_db, "graph"):
logger.info(f"Checking if graph exists in memory: {graph_db.graph.has_node(str(graph.id))}")
# Check the node data
if graph_db.graph.has_node(str(graph.id)):
node_data = graph_db.graph.nodes[str(graph.id)]
logger.info(f"Graph node data: {node_data}")
# List all nodes in the graph
logger.info("Nodes in the graph:")
for node_id, node_data in graph_db.graph.nodes(data=True):
logger.info(
f" Node {node_id}: type={node_data.get('metadata', {}).get('type', 'unknown')}"
)
# Try to retrieve the graph using our from_database method
retrieved_graph = None
try:
logger.info(f"Retrieving graph from database with ID: {graph.id}")
retrieved_graph = await LayeredKnowledgeGraphDP.from_database(graph.id, adapter)
logger.info(f"Retrieved graph: {retrieved_graph}")
logger.info(
f"Retrieved {len(retrieved_graph.layers)} layers, {len(retrieved_graph.nodes)} nodes, and {len(retrieved_graph.edges)} edges"
)
except ValueError as e:
logger.error(f"Error retrieving graph using from_database: {str(e)}")
# Try using manual retrieval as a fallback
logger.info("Trying manual retrieval as a fallback")
retrieved_graph = await retrieve_graph_manually(graph.id, adapter)
if retrieved_graph:
logger.info(f"Successfully retrieved graph manually: {retrieved_graph}")
else:
logger.error("Failed to retrieve graph manually")
return
# Use our helper function to get nodes and edges
if retrieved_graph:
# Get nodes in the base layer
base_nodes, base_edges = get_nodes_and_edges_from_graph(retrieved_graph, base_layer.id)
logger.info(f"Nodes in base layer: {[node.name for node in base_nodes]}")
logger.info(f"Edges in base layer: {[edge.relationship_name for edge in base_edges]}")
# Get nodes in the derived layer
derived_nodes, derived_edges = get_nodes_and_edges_from_graph(
retrieved_graph, derived_layer.id
)
logger.info(f"Nodes in derived layer: {[node.name for node in derived_nodes]}")
logger.info(f"Edges in derived layer: {[edge.relationship_name for edge in derived_edges]}")
else:
logger.error("No graph was retrieved, cannot display nodes and edges")
if __name__ == "__main__":
asyncio.run(main())

View file

@@ -0,0 +1,146 @@
import os
import uuid
import logging
import json
from typing import List
import httpx
import asyncio
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.logging_utils import get_logger, ERROR
# Replace incorrect import with proper cognee.prune module
# from cognee.infrastructure.data_storage import reset_cognee
# Remove incorrect imports and use high-level API
# from cognee.services.cognify import cognify
# from cognee.services.search import search
from cognee.tasks.node_set.apply_node_set import apply_node_set
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data
# Configure logging to see detailed output
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set up an httpx client with no timeout for debugging slow requests
http_client = httpx.Client(timeout=None)
async def generate_test_node_ids(count: int) -> List[str]:
"""Generate unique node IDs for testing."""
return [str(uuid.uuid4()) for _ in range(count)]
async def add_text_with_nodeset(text: str, node_set: List[str], text_id: str = None) -> str:
"""Add text with NodeSet to cognee."""
# Generate text ID if not provided - but note we can't directly set the ID
if text_id is None:
text_id = str(uuid.uuid4())
# Print NodeSet details
logger.info(f"Adding text with NodeSet")
logger.info(f"NodeSet for this text: {node_set}")
# Use high-level cognee.add() with the correct parameters
await cognee.add(text, node_set=node_set)
logger.info(f"Saved text with NodeSet to database")
return text_id # Note: we can't control the actual ID generated by cognee.add()
async def check_data_records():
"""Check for data records in the database to verify NodeSet storage."""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
from sqlalchemy import select
query = select(Data)
result = await session.execute(query)
records = result.scalars().all()
logger.info(f"Found {len(records)} records in the database")
for record in records:
logger.info(f"Record ID: {record.id}, name: {record.name}, node_set: {record.node_set}")
# Print raw_data_location to see where the text content is stored
logger.info(f"Raw data location: {record.raw_data_location}")
async def run_simple_node_set_test():
"""Run a simple test of NodeSet functionality."""
try:
# Reset cognee data using correct module
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
logger.info("Cognee data reset complete")
# Generate test node IDs for two NodeSets
# nodeset1 = await generate_test_node_ids(3)
nodeset1 = ["test", "horse", "mamamama"]
nodeset2 = await generate_test_node_ids(3)
logger.info(f"Created test node IDs for NodeSet 1: {nodeset1}")
logger.info(f"Created test node IDs for NodeSet 2: {nodeset2}")
# Add two texts with NodeSets
text1 = "Natural Language Processing (NLP) is a field of artificial intelligence focused on enabling computers to understand and process human language. It involves techniques for language analysis, translation, and generation."
text2 = "Artificial Intelligence (AI) systems are designed to perform tasks that typically require human intelligence. These tasks include speech recognition, decision-making, and language translation."
text1_id = await add_text_with_nodeset(text1, nodeset1)
text2_id = await add_text_with_nodeset(text2, nodeset2)
# Verify that the NodeSets were stored correctly
await check_data_records()
# Run the cognify process to create a knowledge graph
pipeline_run_id = await cognee.cognify()
logger.info(f"Cognify process completed with pipeline run ID: {pipeline_run_id}")
# Skip graph search as the NetworkXAdapter doesn't support Cypher
logger.info("Skipping graph search as NetworkXAdapter doesn't support Cypher")
# Search for insights related to the added texts
search_query = "NLP and AI"
logger.info(f"Searching for insights with query: '{search_query}'")
search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=search_query)
# Extract NodeSet information from search results
logger.info(f"Found {len(search_results)} search results for '{search_query}':")
for i, result in enumerate(search_results):
logger.info(f"Result {i+1} text: {getattr(result, 'text', 'No text')}")
# Check for NodeSet and SetNodeId
node_set = getattr(result, "NodeSet", None)
set_node_id = getattr(result, "SetNodeId", None)
logger.info(f"Result {i+1} - NodeSet: {node_set}, SetNodeId: {set_node_id}")
# Check id and type
logger.info(f"Result {i+1} - ID: {getattr(result, 'id', 'No ID')}, Type: {getattr(result, 'type', 'No type')}")
# Check if this is a document chunk and has is_part_of property
if hasattr(result, "is_part_of") and result.is_part_of:
logger.info(f"Result {i+1} is a document chunk with parent ID: {result.is_part_of.id}")
# Check if the parent has a NodeSet
parent_has_nodeset = hasattr(result.is_part_of, "NodeSet") and result.is_part_of.NodeSet
logger.info(f" Parent has NodeSet: {parent_has_nodeset}")
if parent_has_nodeset:
logger.info(f" Parent NodeSet: {result.is_part_of.NodeSet}")
# Print all attributes of the result to see what's available
logger.info(f"Result {i+1} - All attributes: {dir(result)}")
except Exception as e:
logger.error(f"Error in simple NodeSet test: {e}")
raise
if __name__ == "__main__":
logger = get_logger(level=ERROR)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(run_simple_node_set_test())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())

View file

@@ -1,201 +0,0 @@
"""
Example demonstrating how to use the simplified LayeredKnowledgeGraph with database adapters.
This example shows how to:
1. Create a layered knowledge graph
2. Add nodes, edges, and layers
3. Retrieve layer data
4. Work with cumulative layers
"""
import asyncio
import logging
import uuid
from uuid import UUID
import os
from cognee.modules.graph.simplified_layered_graph import (
LayeredKnowledgeGraph,
GraphNode,
GraphEdge,
GraphLayer,
)
from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.databases.graph import get_graph_engine
# Set up logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
async def main():
print("Starting simplified layered graph example")
# Initialize file path for the NetworkXAdapter
db_dir = os.path.join(os.path.expanduser("~"), "cognee/cognee/.cognee_system/databases")
os.makedirs(db_dir, exist_ok=True)
db_file = os.path.join(db_dir, "cognee_graph.pkl")
# Use NetworkXAdapter for the graph database
adapter = NetworkXAdapter(filename=db_file)
# Initialize the adapter by creating or loading the graph
if not os.path.exists(db_file):
await adapter.create_empty_graph(db_file)
await adapter.load_graph_from_file()
print(f"Using graph database adapter: {adapter.__class__.__name__}")
# Create an empty graph
graph = LayeredKnowledgeGraph.create_empty("Test Graph")
graph.set_adapter(LayeredGraphDBAdapter(adapter))
print(f"Created graph with ID: {graph.id}")
# Create layers
base_layer = await graph.add_layer(
name="Base Layer", description="The foundation layer with base concepts", layer_type="base"
)
print(f"Added base layer with ID: {base_layer.id}")
intermediate_layer = await graph.add_layer(
name="Intermediate Layer",
description="Layer that builds on base concepts",
layer_type="intermediate",
parent_layers=[base_layer.id],
)
print(f"Added intermediate layer with ID: {intermediate_layer.id}")
derived_layer = await graph.add_layer(
name="Derived Layer",
description="Final layer with derived concepts",
layer_type="derived",
parent_layers=[intermediate_layer.id],
)
print(f"Added derived layer with ID: {derived_layer.id}")
# Add nodes to layers
node1 = await graph.add_node(
name="Base Concept A",
node_type="concept",
properties={"importance": "high"},
metadata={"source": "example"},
layer_id=base_layer.id,
)
print(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}")
node2 = await graph.add_node(
name="Base Concept B",
node_type="concept",
properties={"importance": "medium"},
metadata={"source": "example"},
layer_id=base_layer.id,
)
print(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}")
node3 = await graph.add_node(
name="Intermediate Concept",
node_type="concept",
properties={"derived_from": ["Base Concept A"]},
metadata={"source": "example"},
layer_id=intermediate_layer.id,
)
print(f"Added node3 with ID: {node3.id} to layer: {intermediate_layer.id}")
node4 = await graph.add_node(
name="Derived Concept",
node_type="concept",
properties={"derived_from": ["Intermediate Concept"]},
metadata={"source": "example"},
layer_id=derived_layer.id,
)
print(f"Added node4 with ID: {node4.id} to layer: {derived_layer.id}")
# Add edges between nodes
edge1 = await graph.add_edge(
source_id=node1.id,
target_id=node2.id,
edge_type="RELATES_TO",
properties={"strength": "high"},
metadata={"source": "example"},
layer_id=base_layer.id,
)
print(f"Added edge1 with ID: {edge1.id} between {node1.id} and {node2.id}")
edge2 = await graph.add_edge(
source_id=node1.id,
target_id=node3.id,
edge_type="SUPPORTS",
properties={"confidence": 0.9},
metadata={"source": "example"},
layer_id=intermediate_layer.id,
)
print(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}")
edge3 = await graph.add_edge(
source_id=node3.id,
target_id=node4.id,
edge_type="EXTENDS",
properties={"confidence": 0.8},
metadata={"source": "example"},
layer_id=derived_layer.id,
)
print(f"Added edge3 with ID: {edge3.id} between {node3.id} and {node4.id}")
# Save the graph to the database
# The graph is automatically saved when nodes and edges are added,
# but for NetworkXAdapter we'll save the file explicitly
if hasattr(adapter, "save_graph_to_file"):
await adapter.save_graph_to_file(adapter.filename)
print(f"Saving graph to file: {adapter.filename}")
# Retrieve all layers
layers = await graph.get_layers()
print(f"Retrieved {len(layers)} layers")
# Load the graph from the database
print(f"Loading graph with ID: {graph.id} from database")
# Create a new graph instance from the database
loaded_graph = LayeredKnowledgeGraph(id=graph.id, name="Test Graph")
loaded_graph.set_adapter(LayeredGraphDBAdapter(adapter))
# Load layers, which will also load nodes and edges
loaded_layers = await loaded_graph.get_layers()
print(f"Successfully loaded graph: {loaded_graph} with {len(loaded_layers)} layers")
# Display contents of each layer
print("\n===== Individual Layer Contents =====")
for layer in layers:
# Get nodes and edges in the layer
nodes = await graph.get_nodes_in_layer(layer.id)
edges = await graph.get_edges_in_layer(layer.id)
# Print summary
print(f"Nodes in {layer.name.lower()} layer: {[node.name for node in nodes]}")
print(f"Edges in {layer.name.lower()} layer: {[edge.edge_type for edge in edges]}")
# Display cumulative layer views
print("\n===== Cumulative Layer Views =====")
# Intermediate layer - should include base layer nodes/edges
print("\nCumulative layer graph for intermediate layer:")
int_nodes, int_edges = await graph.get_cumulative_layer_graph(intermediate_layer.id)
print(f"Intermediate cumulative nodes: {[node.name for node in int_nodes]}")
print(f"Intermediate cumulative edges: {[edge.edge_type for edge in int_edges]}")
# Derived layer - should include all nodes/edges
print("\nCumulative layer graph for derived layer:")
derived_nodes, derived_edges = await graph.get_cumulative_layer_graph(derived_layer.id)
print(f"Derived cumulative nodes: {[node.name for node in derived_nodes]}")
print(f"Derived cumulative edges: {[edge.edge_type for edge in derived_edges]}")
# Test helper methods
print("\n===== Helper Method Results =====")
base_nodes = await graph.get_nodes_in_layer(base_layer.id)
base_edges = await graph.get_edges_in_layer(base_layer.id)
print(f"Base layer contains {len(base_nodes)} nodes and {len(base_edges)} edges")
print("Example complete")
if __name__ == "__main__":
asyncio.run(main())

View file

@@ -1,88 +1,451 @@
import uuid
import json
import logging
from sqlalchemy import select
from typing import List, Any
from typing import List, Dict, Any, Optional, Union, Tuple, Sequence, Protocol, Callable
from contextlib import asynccontextmanager
from cognee.shared.logging_utils import get_logger
from sqlalchemy.future import select
from sqlalchemy import or_
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.modules.data.models import Data
from cognee.infrastructure.engine.models.DataPoint import DataPoint
logger = logging.getLogger(__name__)
# Configure logger
logger = get_logger(name="apply_node_set")
async def apply_node_set(data_points: List[DataPoint]) -> List[DataPoint]:
"""
Apply NodeSet values from the relational store to DataPoint instances.
This task fetches the NodeSet values from the Data model in the relational database
and applies them to the corresponding DataPoint instances.
async def apply_node_set(data: Union[DataPoint, List[DataPoint]]) -> Union[DataPoint, List[DataPoint]]:
"""Apply NodeSet values to DataPoint objects.
Args:
data_points: List of DataPoint instances to process
data: Single DataPoint or list of DataPoints to process
Returns:
List of updated DataPoint instances with NodeSet values applied
The processed DataPoint(s) with updated NodeSet values
"""
logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints")
if not data:
logger.warning("No data provided to apply NodeSet values")
return data
if not data_points:
# Convert single DataPoint to list for uniform processing
data_points = data if isinstance(data, list) else [data]
logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints")
# Process DataPoint objects to apply NodeSet values
updated_data_points = await _process_data_points(data_points)
# Create set nodes for each NodeSet
await _create_set_nodes(updated_data_points)
# Return data in the same format it was received
return data_points if isinstance(data, list) else data_points[0]
async def _process_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
"""Process DataPoint objects to apply NodeSet values from the database.
Args:
data_points: List of DataPoint objects to process
Returns:
The processed list of DataPoints with updated NodeSet values
"""
try:
if not data_points:
return []
# Extract IDs and collect document relationships
data_point_ids, parent_doc_map = _collect_ids_and_relationships(data_points)
# Get NodeSet values from database
nodeset_map = await _fetch_nodesets_from_database(data_point_ids, parent_doc_map)
# Apply NodeSet values to DataPoints
_apply_nodesets_to_datapoints(data_points, nodeset_map, parent_doc_map)
return data_points
except Exception as e:
logger.error(f"Error processing DataPoints: {str(e)}")
return data_points
# Create a map of data_point IDs for efficient lookup
data_point_map = {str(dp.id): dp for dp in data_points}
# Get the database engine
db_engine = get_relational_engine()
# Get session (handles both sync and async cases for testing)
session = db_engine.get_async_session()
try:
# Handle both AsyncMock and actual async context manager for testing
if hasattr(session, "__aenter__"):
# It's a real async context manager
async with session as sess:
await _process_data_points(sess, data_point_map)
else:
# It's an AsyncMock in tests
await _process_data_points(session, data_point_map)
except Exception as e:
logger.error(f"Error applying NodeSet values: {e}")
return data_points
async def _process_data_points(session, data_point_map):
"""
Process data points with the given session.
This helper function handles the actual database query and NodeSet application.
def _collect_ids_and_relationships(data_points: List[DataPoint]) -> Tuple[List[str], Dict[str, str]]:
"""Extract DataPoint IDs and document relationships.
Args:
session: Database session
data_point_map: Map of data point IDs to DataPoint objects
data_points: List of DataPoint objects
Returns:
Tuple containing:
- List of DataPoint IDs
- Dictionary mapping DataPoint IDs to parent document IDs
"""
# Get all data points from the Data table that have node_set values
# and correspond to the data_points we're processing
data_ids = list(data_point_map.keys())
data_point_ids = []
parent_doc_ids = []
parent_doc_map = {}
# Collect all IDs to look up
for dp in data_points:
# Get the DataPoint ID
if hasattr(dp, "id"):
dp_id = str(dp.id)
data_point_ids.append(dp_id)
# Check if there's a parent document to get NodeSet from
if (hasattr(dp, "made_from") and
hasattr(dp.made_from, "is_part_of") and
hasattr(dp.made_from.is_part_of, "id")):
parent_id = str(dp.made_from.is_part_of.id)
parent_doc_ids.append(parent_id)
parent_doc_map[dp_id] = parent_id
logger.info(f"Found {len(data_point_ids)} DataPoint IDs and {len(parent_doc_ids)} parent document IDs")
# Combine all IDs for database lookup
return data_point_ids + parent_doc_ids, parent_doc_map
query = select(Data).where(Data.id.in_(data_ids), Data.node_set.is_not(None))
result = await session.execute(query)
data_records = result.scalars().all()
async def _fetch_nodesets_from_database(ids: List[str], parent_doc_map: Dict[str, str]) -> Dict[str, Any]:
"""Fetch NodeSet values from the database for the given IDs.
Args:
ids: List of IDs to search for
parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
Returns:
Dictionary mapping document IDs to their NodeSet values
"""
# Convert string IDs to UUIDs for database lookup
uuid_objects = _convert_ids_to_uuids(ids)
if not uuid_objects:
return {}
# Query the database for NodeSet values
nodeset_map = {}
db_engine = get_relational_engine()
async with db_engine.get_async_session() as sess:
# Query for records matching the IDs
query = select(Data).where(Data.id.in_(uuid_objects))
result = await sess.execute(query)
records = result.scalars().all()
logger.info(f"Found {len(records)} total records matching the IDs")
# Extract NodeSet values from records
for record in records:
if record.node_set is not None:
nodeset_map[str(record.id)] = record.node_set
logger.info(f"Found {len(nodeset_map)} records with non-NULL NodeSet values")
return nodeset_map
# Apply NodeSet values to corresponding DataPoint instances
for data_record in data_records:
data_point_id = str(data_record.id)
if data_point_id in data_point_map and data_record.node_set:
# Parse the JSON string to get the NodeSet
try:
node_set = json.loads(data_record.node_set)
data_point_map[data_point_id].NodeSet = node_set
logger.debug(f"Applied NodeSet {node_set} to DataPoint {data_point_id}")
except json.JSONDecodeError:
logger.warning(f"Failed to parse NodeSet JSON for DataPoint {data_point_id}")
def _convert_ids_to_uuids(ids: List[str]) -> List[uuid.UUID]:
"""Convert string IDs to UUID objects.
Args:
ids: List of string IDs to convert
Returns:
List of UUID objects
"""
uuid_objects = []
for id_str in ids:
try:
uuid_objects.append(uuid.UUID(id_str))
except Exception as e:
logger.warning(f"Failed to convert ID {id_str} to UUID: {str(e)}")
logger.info(f"Converted {len(uuid_objects)} out of {len(ids)} IDs to UUID objects")
return uuid_objects
def _apply_nodesets_to_datapoints(
data_points: List[DataPoint],
nodeset_map: Dict[str, Any],
parent_doc_map: Dict[str, str]
) -> None:
"""Apply NodeSet values to DataPoints.
Args:
data_points: List of DataPoint objects to update
nodeset_map: Dictionary mapping document IDs to their NodeSet values
parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
"""
for dp in data_points:
dp_id = str(dp.id)
# Try direct match first
if dp_id in nodeset_map:
nodeset = nodeset_map[dp_id]
logger.info(f"Found NodeSet for {dp_id}: {nodeset}")
dp.NodeSet = nodeset
# Then try parent document
elif dp_id in parent_doc_map and parent_doc_map[dp_id] in nodeset_map:
parent_id = parent_doc_map[dp_id]
nodeset = nodeset_map[parent_id]
logger.info(f"Found NodeSet from parent document {parent_id} for {dp_id}: {nodeset}")
dp.NodeSet = nodeset
async def _create_set_nodes(data_points: List[DataPoint]) -> None:
"""Create set nodes for DataPoints with NodeSets.
Args:
data_points: List of DataPoint objects to process
"""
try:
logger.info(f"Creating set nodes for {len(data_points)} DataPoints")
for dp in data_points:
if not hasattr(dp, "NodeSet") or not dp.NodeSet:
continue
try:
# Create set nodes for the NodeSet (one per value)
document_id = str(dp.id) if hasattr(dp, "id") else None
set_node_ids, edge_ids = await create_set_node(dp.NodeSet, document_id=document_id)
if set_node_ids and len(set_node_ids) > 0:
logger.info(f"Created {len(set_node_ids)} set nodes for NodeSet with {len(dp.NodeSet)} values")
# Store the set node IDs with the DataPoint if possible
try:
# Store as JSON string if multiple IDs, or single ID if only one
if len(set_node_ids) > 1:
dp.SetNodeIds = json.dumps(set_node_ids)
else:
dp.SetNodeId = set_node_ids[0]
except Exception as e:
logger.warning(f"Failed to store set node IDs for NodeSet: {str(e)}")
else:
logger.warning("Failed to create set nodes for NodeSet")
except Exception as e:
logger.error(f"Error creating set nodes: {str(e)}")
except Exception as e:
logger.error(f"Error processing NodeSets: {str(e)}")
logger.info(f"Successfully applied NodeSet values to DataPoints")
async def create_set_node(
nodes: Union[List[str], str],
name: Optional[str] = None,
document_id: Optional[str] = None
) -> Tuple[Optional[List[str]], List[str]]:
"""Create individual nodes for each value in the NodeSet.
Args:
nodes: List of node values or JSON string representing node values
name: Base name for the NodeSet (optional)
document_id: ID of the document containing the NodeSet (optional)
Returns:
Tuple containing:
- List of created set node IDs (or None if creation failed)
- List of created edge IDs
"""
try:
if not nodes:
logger.warning("No nodes provided to create set nodes")
return None, []
# Get the graph engine
graph_engine = await get_graph_engine()
# Parse nodes if provided as JSON string
nodes_list = _parse_nodes_input(nodes)
if not nodes_list:
return None, []
# Base name for the set if not provided
base_name = name or f"NodeSet_{uuid.uuid4().hex[:8]}"
logger.info(f"Creating individual set nodes for {len(nodes_list)} values with base name '{base_name}'")
# Create set nodes using the appropriate strategy
return await _create_set_nodes_unified(graph_engine, nodes_list, base_name, document_id)
except Exception as e:
logger.error(f"Failed to create set nodes: {str(e)}")
return None, []
def _parse_nodes_input(nodes: Union[List[str], str]) -> Optional[List[str]]:
"""Parse the nodes input.
Args:
nodes: List of node values or JSON string representing node values
Returns:
List of node values or None if parsing failed
"""
if isinstance(nodes, str):
try:
parsed_nodes = json.loads(nodes)
logger.info(f"Parsed nodes string into list with {len(parsed_nodes)} items")
return parsed_nodes
except Exception as e:
logger.error(f"Failed to parse nodes as JSON: {str(e)}")
return None
return nodes
async def _create_set_nodes_unified(
graph_engine: Any,
nodes_list: List[str],
base_name: str,
document_id: Optional[str]
) -> Tuple[List[str], List[str]]:
"""Create set nodes using either NetworkX or generic graph engine.
Args:
graph_engine: The graph engine instance
nodes_list: List of node values
base_name: Base name for the NodeSet
document_id: ID of the document containing the NodeSet (optional)
Returns:
Tuple containing:
- List of created set node IDs
- List of created edge IDs
"""
all_set_node_ids = []
all_edge_ids = []
# Define strategies for node and edge creation based on graph engine type
if isinstance(graph_engine, NetworkXAdapter):
# NetworkX-specific strategy
async def create_node(node_value: str) -> str:
set_node_id = str(uuid.uuid4())
node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
graph_engine.graph.add_node(
set_node_id,
id=set_node_id,
type="NodeSet",
name=node_name,
node_id=node_value
)
# Validate node creation
if set_node_id in graph_engine.graph.nodes():
node_props = dict(graph_engine.graph.nodes[set_node_id])
logger.info(f"Created set node for value '{node_value}': {json.dumps(node_props)}")
else:
logger.warning(f"Node {set_node_id} not found in graph after adding")
return set_node_id
async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
edge_ids = []
try:
edge_id = str(uuid.uuid4())
graph_engine.graph.add_edge(
set_node_id,
node_value,
id=edge_id,
type="CONTAINS"
)
edge_ids.append(edge_id)
except Exception as e:
logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
return edge_ids
async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
edge_ids = []
try:
doc_to_nodeset_id = str(uuid.uuid4())
graph_engine.graph.add_edge(
document_id,
set_node_id,
id=doc_to_nodeset_id,
type="HAS_NODESET"
)
edge_ids.append(doc_to_nodeset_id)
logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
except Exception as e:
logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
return edge_ids
# Finalize function for NetworkX
async def finalize() -> None:
await graph_engine.save_graph_to_file(graph_engine.filename)
else:
# Generic graph engine strategy
async def create_node(node_value: str) -> str:
node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
set_node_props = {
"name": node_name,
"type": "NodeSet",
"node_id": node_value
}
return await graph_engine.create_node(set_node_props)
async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
edge_ids = []
try:
edge_id = await graph_engine.create_edge(
source_id=set_node_id,
target_id=node_value,
edge_type="CONTAINS"
)
edge_ids.append(edge_id)
except Exception as e:
logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
return edge_ids
async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
edge_ids = []
try:
doc_to_nodeset_id = await graph_engine.create_edge(
source_id=document_id,
target_id=set_node_id,
edge_type="HAS_NODESET"
)
edge_ids.append(doc_to_nodeset_id)
logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
except Exception as e:
logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
return edge_ids
# Finalize function for generic engine (no-op)
async def finalize() -> None:
pass
# Unified process for both engine types
for node_value in nodes_list:
try:
# Create the node
set_node_id = await create_node(node_value)
# Create edges to the value
value_edge_ids = await create_value_edge(set_node_id, node_value)
all_edge_ids.extend(value_edge_ids)
# Create edges to the document if provided
if document_id:
doc_edge_ids = await create_document_edge(document_id, set_node_id)
all_edge_ids.extend(doc_edge_ids)
all_set_node_ids.append(set_node_id)
except Exception as e:
logger.error(f"Failed to create set node for value '{node_value}': {str(e)}")
# Finalize the process
await finalize()
# Return results
if all_set_node_ids:
logger.info(f"Created {len(all_set_node_ids)} individual set nodes with values: {nodes_list}")
return all_set_node_ids, all_edge_ids
else:
logger.error("Failed to create any set nodes")
return [], []
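
For reference, one tagged document ends up with the following shape in the graph: an individual NodeSet node per value, a CONTAINS edge from each set node to its value, and a HAS_NODESET edge from the document to each set node. A standalone NetworkX sketch of that structure, with illustrative IDs and values:

import uuid
import networkx as nx

graph = nx.MultiDiGraph()
document_id = str(uuid.uuid4())
graph.add_node(document_id, type="TextDocument", name="example document")

for value in ["ai", "nlp"]:
    set_node_id = str(uuid.uuid4())
    # One set node per NodeSet value, as in _create_set_nodes_unified
    graph.add_node(set_node_id, type="NodeSet", node_id=value)
    graph.add_edge(set_node_id, value, type="CONTAINS")  # set node -> value
    graph.add_edge(document_id, set_node_id, type="HAS_NODESET")  # document -> set node

print(graph.number_of_nodes(), graph.number_of_edges())  # 5 nodes, 4 edges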

View file

@@ -0,0 +1,675 @@
import asyncio
import os
import json
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional, Set
import webbrowser
from pathlib import Path
import cognee
from cognee.shared.logging_utils import get_logger
# Configure logger
logger = get_logger(name="enhanced_graph_visualization")
# Type aliases for clarity
NodeData = Dict[str, Any]
EdgeData = Dict[str, Any]
GraphData = Tuple[List[Tuple[Any, Dict]], List[Tuple[Any, Any, Optional[str], Dict]]]
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder to handle datetime objects."""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
class NodeSetVisualizer:
"""Class to create enhanced visualizations for NodeSet data in knowledge graphs."""
# Color mapping for different node types
NODE_COLORS = {
"Entity": "#f47710",
"EntityType": "#6510f4",
"DocumentChunk": "#801212",
"TextDocument": "#a83232", # Darker red for documents
"TextSummary": "#1077f4",
"NodeSet": "#ff00ff", # Bright magenta for NodeSet nodes
"Unknown": "#999999",
"default": "#D3D3D3",
}
# Size mapping for different node types
NODE_SIZES = {
"NodeSet": 20, # Larger size for NodeSet nodes
"TextDocument": 18, # Larger size for document nodes
"DocumentChunk": 18, # Larger size for document nodes
"TextSummary": 16, # Medium size for TextSummary nodes
"default": 13, # Default size
}
def __init__(self):
"""Initialize the visualizer."""
self.graph_engine = None
self.nodes_data = []
self.edges_data = []
self.node_count = 0
self.edge_count = 0
self.nodeset_count = 0
async def get_graph_data(self) -> bool:
"""Fetch graph data from the graph engine.
Returns:
bool: True if data was successfully retrieved, False otherwise.
"""
self.graph_engine = await cognee.infrastructure.databases.graph.get_graph_engine()
# Check if the graph exists and has nodes
self.node_count = len(self.graph_engine.graph.nodes())
self.edge_count = len(self.graph_engine.graph.edges())
logger.info(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
print(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
if self.node_count == 0:
logger.error("The graph is empty! Please run a test script first to generate data.")
print("ERROR: The graph is empty! Please run a test script first to generate data.")
return False
graph_data = await self.graph_engine.get_graph_data()
self.nodes_data, self.edges_data = graph_data
# Count NodeSets for status display
self.nodeset_count = sum(1 for _, info in self.nodes_data if info.get("type") == "NodeSet")
return True
def prepare_node_data(self) -> List[NodeData]:
"""Process raw node data to prepare for visualization.
Returns:
List[NodeData]: List of prepared node data objects.
"""
nodes_list = []
for node_id, node_info in self.nodes_data:
# Create a clean copy to avoid modifying the original
processed_node = node_info.copy()
# Remove fields that cause JSON serialization issues
self._clean_node_data(processed_node)
# Add required visualization properties
processed_node["id"] = str(node_id)
node_type = processed_node.get("type", "default")
# Apply visual styling based on node type
processed_node["color"] = self.NODE_COLORS.get(node_type, self.NODE_COLORS["default"])
processed_node["size"] = self.NODE_SIZES.get(node_type, self.NODE_SIZES["default"])
# Create display names
self._format_node_display_name(processed_node, node_type)
nodes_list.append(processed_node)
return nodes_list
@staticmethod
def _clean_node_data(node: NodeData) -> None:
"""Remove fields that might cause JSON serialization issues.
Args:
node: The node data to clean
"""
# Remove non-essential fields that might cause serialization issues
for key in ["created_at", "updated_at", "raw_data_location"]:
if key in node:
del node[key]
@staticmethod
def _format_node_display_name(node: NodeData, node_type: str) -> None:
"""Format the display name for a node.
Args:
node: The node data to process
node_type: The type of the node
"""
# Set a default name if none exists
node["name"] = node.get("name", node.get("id", "Unknown"))
# Special formatting for NodeSet nodes
if node_type == "NodeSet" and "node_id" in node:
node["display_name"] = f"NodeSet: {node['node_id']}"
else:
node["display_name"] = node["name"]
# Truncate long display names
if len(node["display_name"]) > 30:
node["display_name"] = f"{node['display_name'][:27]}..."
def prepare_edge_data(self, nodes_list: List[NodeData]) -> List[EdgeData]:
"""Process raw edge data to prepare for visualization.
Args:
nodes_list: The processed node data
Returns:
List[EdgeData]: List of prepared edge data objects.
"""
links_list = []
# Create a lookup for node types for faster access
node_type_lookup = {node["id"]: node.get("type", "Unknown") for node in nodes_list}
for source, target, relation, edge_info in self.edges_data:
source_str = str(source)
target_str = str(target)
# Skip if source or target not in node_type_lookup (should not happen)
if source_str not in node_type_lookup or target_str not in node_type_lookup:
continue
# Get node types
source_type = node_type_lookup[source_str]
target_type = node_type_lookup[target_str]
# Create edge data
link_data = {
"source": source_str,
"target": target_str,
"relation": relation or "UNKNOWN"
}
# Categorize the edge for styling
link_data["connection_type"] = self._categorize_edge(source_type, target_type)
links_list.append(link_data)
return links_list
@staticmethod
def _categorize_edge(source_type: str, target_type: str) -> str:
"""Categorize an edge based on the connected node types.
Args:
source_type: The type of the source node
target_type: The type of the target node
Returns:
str: The category of the edge
"""
if source_type == "NodeSet" and target_type != "NodeSet":
return "nodeset_to_value"
elif (source_type in ["TextDocument", "DocumentChunk", "TextSummary"]) and target_type == "NodeSet":
return "document_to_nodeset"
elif target_type == "NodeSet":
return "to_nodeset"
elif source_type == "NodeSet":
return "from_nodeset"
else:
return "standard"
def generate_html(self, nodes_list: List[NodeData], links_list: List[EdgeData]) -> str:
"""Generate the HTML visualization with D3.js.
Args:
nodes_list: The processed node data
links_list: The processed edge data
Returns:
str: The HTML content for the visualization
"""
# Use embedded template directly - more reliable than file access
html_template = self._get_embedded_html_template()
# Generate the HTML content with custom JSON encoder for datetime objects
html_content = html_template.replace("{nodes}", json.dumps(nodes_list, cls=DateTimeEncoder))
html_content = html_content.replace("{links}", json.dumps(links_list, cls=DateTimeEncoder))
html_content = html_content.replace("{node_count}", str(self.node_count))
html_content = html_content.replace("{edge_count}", str(self.edge_count))
html_content = html_content.replace("{nodeset_count}", str(self.nodeset_count))
return html_content
def save_html(self, html_content: str) -> str:
"""Save the HTML content to a file and open it in the browser.
Args:
html_content: The HTML content to save
Returns:
str: The path to the saved file
"""
# Create the output file path
output_path = Path.cwd() / "enhanced_nodeset_visualization.html"
# Write the HTML content to the file
with open(output_path, "w") as f:
f.write(html_content)
logger.info(f"Enhanced visualization saved to: {output_path}")
print(f"Enhanced visualization saved to: {output_path}")
# Open the visualization in the default web browser
file_url = f"file://{output_path}"
logger.info(f"Opening visualization in browser: {file_url}")
print(f"Opening enhanced visualization in browser: {file_url}")
webbrowser.open(file_url)
return str(output_path)
@staticmethod
def _get_embedded_html_template() -> str:
"""Get the embedded HTML template as a fallback.
Returns:
str: The HTML template
"""
return """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Cognee NodeSet Visualization</title>
<script src="https://d3js.org/d3.v5.min.js"></script>
<style>
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', Arial, sans-serif; }
svg { width: 100vw; height: 100vh; display: block; }
.links line { stroke-width: 2px; }
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
.node-label { font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }
/* NodeSet specific styles */
.links line.nodeset_to_value { stroke: #ff00ff; stroke-width: 3px; stroke-dasharray: 5, 5; }
.links line.document_to_nodeset { stroke: #fc0; stroke-width: 3px; }
.links line.to_nodeset { stroke: #0cf; stroke-width: 2px; }
.links line.from_nodeset { stroke: #0fc; stroke-width: 2px; }
.nodes circle.nodeset { stroke: white; stroke-width: 2px; filter: drop-shadow(0 0 10px rgba(255,0,255,0.8)); }
.nodes circle.document { stroke: white; stroke-width: 1.5px; filter: drop-shadow(0 0 8px rgba(255,255,0,0.6)); }
.node-label.nodeset { font-size: 6px; font-weight: bold; fill: white; }
.node-label.document { font-size: 5.5px; font-weight: bold; fill: white; }
/* Legend */
.legend { position: fixed; top: 10px; left: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
.legend-item { display: flex; align-items: center; margin-bottom: 5px; }
.legend-color { width: 15px; height: 15px; margin-right: 10px; border-radius: 50%; }
.legend-label { font-size: 14px; }
/* Edge legend */
.edge-legend { position: fixed; top: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
.edge-legend-item { display: flex; align-items: center; margin-bottom: 10px; }
.edge-line { width: 30px; height: 3px; margin-right: 10px; }
.edge-legend .edge-label { font-size: 14px; }
/* Controls */
.controls { position: fixed; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; display: flex; gap: 10px; }
button { background: #333; color: white; border: none; padding: 5px 10px; border-radius: 4px; cursor: pointer; }
button:hover { background: #555; }
/* Status message */
.status { position: fixed; bottom: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
</style>
</head>
<body>
<svg></svg>
<!-- Node Legend -->
<div class="legend">
<h3 style="margin-top: 0;">Node Types</h3>
<div class="legend-item">
<div class="legend-color" style="background-color: #ff00ff;"></div>
<div class="legend-label">NodeSet</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #a83232;"></div>
<div class="legend-label">TextDocument</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #801212;"></div>
<div class="legend-label">DocumentChunk</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #1077f4;"></div>
<div class="legend-label">TextSummary</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #f47710;"></div>
<div class="legend-label">Entity</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #6510f4;"></div>
<div class="legend-label">EntityType</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #999999;"></div>
<div class="legend-label">Unknown</div>
</div>
</div>
<!-- Edge Legend -->
<div class="edge-legend">
<h3 style="margin-top: 0;">Edge Types</h3>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: #fc0;"></div>
<div class="edge-label">Document NodeSet</div>
</div>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: #ff00ff; height: 3px; background: linear-gradient(to right, #ff00ff 50%, transparent 50%); background-size: 10px 3px; background-repeat: repeat-x;"></div>
<div class="edge-label">NodeSet Value</div>
</div>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: #0cf;"></div>
<div class="edge-label">Any NodeSet</div>
</div>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: rgba(255, 255, 255, 0.4);"></div>
<div class="edge-label">Standard Connection</div>
</div>
</div>
<!-- Controls -->
<div class="controls">
<button id="center-btn">Center Graph</button>
<button id="highlight-nodesets">Highlight NodeSets</button>
<button id="highlight-documents">Highlight Documents</button>
<button id="reset-highlight">Reset Highlight</button>
</div>
<!-- Status -->
<div class="status">
<div>Nodes: {node_count}</div>
<div>Edges: {edge_count}</div>
<div>NodeSets: {nodeset_count}</div>
</div>
<script>
var nodes = {nodes};
var links = {links};
var svg = d3.select("svg"),
width = window.innerWidth,
height = window.innerHeight;
var container = svg.append("g");
// Count NodeSets for status display
const nodesetCount = nodes.filter(n => n.type === "NodeSet").length;
document.querySelector('.status').innerHTML = `
<div>Nodes: ${nodes.length}</div>
<div>Edges: ${links.length}</div>
<div>NodeSets: ${nodesetCount}</div>
`;
var simulation = d3.forceSimulation(nodes)
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
.force("charge", d3.forceManyBody().strength(-300))
.force("center", d3.forceCenter(width / 2, height / 2))
.force("x", d3.forceX().strength(0.1).x(width / 2))
.force("y", d3.forceY().strength(0.1).y(height / 2));
var link = container.append("g")
.attr("class", "links")
.selectAll("line")
.data(links)
.enter().append("line")
.attr("stroke-width", d => {
if (d.connection_type === 'document_to_nodeset' || d.connection_type === 'nodeset_to_value') {
return 3;
}
return 2;
})
.attr("stroke", d => {
switch(d.connection_type) {
case 'document_to_nodeset': return "#fc0";
case 'nodeset_to_value': return "#ff00ff";
case 'to_nodeset': return "#0cf";
case 'from_nodeset': return "#0fc";
default: return "rgba(255, 255, 255, 0.4)";
}
})
.attr("stroke-dasharray", d => d.connection_type === 'nodeset_to_value' ? "5,5" : null)
.attr("class", d => d.connection_type);
var edgeLabels = container.append("g")
.attr("class", "edge-labels")
.selectAll("text")
.data(links)
.enter().append("text")
.attr("class", "edge-label")
.text(d => d.relation);
var nodeGroup = container.append("g")
.attr("class", "nodes")
.selectAll("g")
.data(nodes)
.enter().append("g");
var node = nodeGroup.append("circle")
.attr("r", d => d.size || 13)
.attr("fill", d => d.color)
.attr("class", d => {
if (d.type === "NodeSet") return "nodeset";
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "document";
return "";
})
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("end", dragended));
nodeGroup.append("text")
.attr("class", d => {
if (d.type === "NodeSet") return "node-label nodeset";
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "node-label document";
return "node-label";
})
.attr("dy", 4)
.attr("font-size", d => {
if (d.type === "NodeSet") return "6px";
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "5.5px";
return "5px";
})
.attr("text-anchor", "middle")
.text(d => d.display_name || d.name);
node.append("title").text(d => {
// Create a formatted tooltip with node properties
let props = Object.entries(d)
.filter(([key]) => !["x", "y", "vx", "vy", "index", "fx", "fy", "color", "display_name"].includes(key))
.map(([key, value]) => `${key}: ${value}`)
.join("\\n");
return props;
});
simulation.on("tick", function() {
link.attr("x1", d => d.source.x)
.attr("y1", d => d.source.y)
.attr("x2", d => d.target.x)
.attr("y2", d => d.target.y);
edgeLabels
.attr("x", d => (d.source.x + d.target.x) / 2)
.attr("y", d => (d.source.y + d.target.y) / 2 - 5);
node.attr("cx", d => d.x)
.attr("cy", d => d.y);
nodeGroup.select("text")
.attr("x", d => d.x)
.attr("y", d => d.y)
.attr("dy", 4)
.attr("text-anchor", "middle");
});
// Add zoom behavior
const zoom = d3.zoom()
.scaleExtent([0.1, 8])
.on("zoom", function() {
container.attr("transform", d3.event.transform);
});
svg.call(zoom);
// Button controls
document.getElementById("center-btn").addEventListener("click", function() {
svg.transition().duration(750).call(
zoom.transform,
d3.zoomIdentity.translate(width / 2, height / 2).scale(1)
);
});
document.getElementById("highlight-nodesets").addEventListener("click", function() {
highlightNodes("NodeSet");
});
document.getElementById("highlight-documents").addEventListener("click", function() {
highlightNodes(["TextDocument", "DocumentChunk"]);
});
document.getElementById("reset-highlight").addEventListener("click", function() {
resetHighlight();
});
function highlightNodes(typeToHighlight) {
// Dim all nodes and links
node.transition().duration(300)
.attr("opacity", 0.2);
link.transition().duration(300)
.attr("opacity", 0.2);
nodeGroup.selectAll("text").transition().duration(300)
.attr("opacity", 0.2);
// Create arrays for types if a single string is provided
const typesToHighlight = Array.isArray(typeToHighlight) ? typeToHighlight : [typeToHighlight];
// Highlight matching nodes and their connected nodes
const highlightedNodeIds = new Set();
// First, find all nodes of the target type
nodes.forEach(n => {
if (typesToHighlight.includes(n.type)) {
highlightedNodeIds.add(n.id);
}
});
// Find all connected nodes (both directions)
links.forEach(l => {
if (highlightedNodeIds.has(l.source.id || l.source)) {
highlightedNodeIds.add(l.target.id || l.target);
}
if (highlightedNodeIds.has(l.target.id || l.target)) {
highlightedNodeIds.add(l.source.id || l.source);
}
});
// Highlight the nodes
node.filter(d => highlightedNodeIds.has(d.id))
.transition().duration(300)
.attr("opacity", 1);
// Highlight the labels
nodeGroup.selectAll("text")
.filter(d => highlightedNodeIds.has(d.id))
.transition().duration(300)
.attr("opacity", 1);
// Highlight the links between highlighted nodes
link.filter(d => {
const sourceId = d.source.id || d.source;
const targetId = d.target.id || d.target;
return highlightedNodeIds.has(sourceId) && highlightedNodeIds.has(targetId);
})
.transition().duration(300)
.attr("opacity", 1);
}
function resetHighlight() {
node.transition().duration(300).attr("opacity", 1);
link.transition().duration(300).attr("opacity", 1);
nodeGroup.selectAll("text").transition().duration(300).attr("opacity", 1);
}
function dragstarted(d) {
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
}
function dragged(d) {
d.fx = d3.event.x;
d.fy = d3.event.y;
}
function dragended(d) {
if (!d3.event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
}
window.addEventListener("resize", function() {
width = window.innerWidth;
height = window.innerHeight;
svg.attr("width", width).attr("height", height);
simulation.force("center", d3.forceCenter(width / 2, height / 2));
simulation.alpha(1).restart();
});
</script>
</body>
</html>
"""
async def create_visualization(self) -> Optional[str]:
"""Main method to create the visualization.
Returns:
Optional[str]: Path to the saved visualization file, or None if unsuccessful
"""
print("Creating enhanced NodeSet visualization...")
# Get graph data
if not await self.get_graph_data():
return None
try:
# Process nodes
nodes_list = self.prepare_node_data()
# Process edges
links_list = self.prepare_edge_data(nodes_list)
# Generate HTML
html_content = self.generate_html(nodes_list, links_list)
# Save to file and open in browser
return self.save_html(html_content)
except Exception as e:
logger.error(f"Error creating visualization: {str(e)}")
print(f"Error creating visualization: {str(e)}")
return None
async def main():
"""Main entry point for the script."""
visualizer = NodeSetVisualizer()
await visualizer.create_visualization()
if __name__ == "__main__":
# Run the async main function
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.close()