Fix nodesets

parent 2355d1bfea
commit 95c12fbc1e

7 changed files with 1252 additions and 619 deletions
@@ -16,7 +16,7 @@ async def add(
     data: Union[BinaryIO, list[BinaryIO], str, list[str]],
     dataset_name: str = "main_dataset",
     user: User = None,
-    NodeSet: Optional[List[str]] = None,
+    node_set: Optional[List[str]] = None,
 ):
     # Create tables for databases
     await create_relational_db_and_tables()
@@ -37,7 +37,7 @@ async def add(
     if user is None:
         user = await get_default_user()

-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, NodeSet)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, node_set)]

     dataset_id = uuid5(NAMESPACE_OID, dataset_name)
     pipeline = run_tasks(
@@ -140,7 +140,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             task_config={"batch_size": 10},
         ),
         Task(add_data_points, task_config={"batch_size": 10}),
-        Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values to DataPoints
+        Task(apply_node_set, task_config={"batch_size": 10}),  # Apply NodeSet values and create set nodes
     ]

     return default_tasks
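With the public parameter renamed from NodeSet to node_set and threaded through to ingest_data and apply_node_set, tagging can be exercised end to end. A minimal usage sketch (the text and tag values below are illustrative, not part of this commit):

    import asyncio
    import cognee

    async def demo():
        # Tag the ingested text with a NodeSet; apply_node_set later copies
        # these values onto the resulting DataPoints and creates set nodes.
        await cognee.add("Example text about NLP.", node_set=["nlp", "demo"])
        await cognee.cognify()

    asyncio.run(demo())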
@@ -1,350 +0,0 @@
"""
Example demonstrating how to use LayeredKnowledgeGraphDP with database adapters.

This example shows how to:
1. Create a layered knowledge graph
2. Set a database adapter
3. Add nodes, edges, and layers with automatic persistence to the database
4. Retrieve graph data from the database
"""

import asyncio
import uuid
import logging
import json
from uuid import UUID

from cognee.modules.graph.datapoint_layered_graph import (
    GraphNode,
    GraphEdge,
    GraphLayer,
    LayeredKnowledgeGraphDP,
)
from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter
from cognee.infrastructure.databases.graph import get_graph_engine

# Set up logging
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


async def retrieve_graph_manually(graph_id, adapter):
    """Retrieve a graph manually from the NetworkX adapter."""
    graph_db = adapter._graph_db
    if not hasattr(graph_db, "graph") or not graph_db.graph:
        await graph_db.load_graph_from_file()

    graph_id_str = str(graph_id)
    logger.info(f"Looking for graph with ID: {graph_id_str}")

    if hasattr(graph_db, "graph") and graph_db.graph.has_node(graph_id_str):
        # Get the graph node data
        graph_data = graph_db.graph.nodes[graph_id_str]
        logger.info(f"Found graph node data: {graph_data}")

        # Create the graph instance
        graph = LayeredKnowledgeGraphDP(
            id=graph_id,
            name=graph_data.get("name", ""),
            description=graph_data.get("description", ""),
            metadata=graph_data.get("metadata", {}),
        )

        # Set the adapter
        graph.set_adapter(adapter)

        # Find and add all layers, nodes, and edges
        nx_graph = graph_db.graph

        # Find all layers for this graph
        logger.info("Finding layers connected to the graph")
        found_layers = set()
        for source, target, key in nx_graph.edges(graph_id_str, keys=True):
            if key == "CONTAINS_LAYER":
                # Found a layer
                layer_data = nx_graph.nodes[target]
                layer_id_str = target
                logger.info(f"Found layer: {layer_id_str} with data: {layer_data}")

                # Convert parent layers
                parent_layers = []
                if "parent_layers" in layer_data:
                    try:
                        if isinstance(layer_data["parent_layers"], str):
                            parent_layers = [
                                UUID(p) for p in json.loads(layer_data["parent_layers"])
                            ]
                        elif isinstance(layer_data["parent_layers"], list):
                            parent_layers = [UUID(str(p)) for p in layer_data["parent_layers"]]
                    except Exception as e:
                        logger.error(f"Error processing parent layers: {e}")
                        parent_layers = []

                # Create the layer
                try:
                    layer = GraphLayer(
                        id=UUID(layer_id_str),
                        name=layer_data.get("name", ""),
                        description=layer_data.get("description", ""),
                        layer_type=layer_data.get("layer_type", "default"),
                        parent_layers=parent_layers,
                        properties=layer_data.get("properties", {}),
                        metadata=layer_data.get("metadata", {}),
                    )
                    graph.layers[layer.id] = layer
                    found_layers.add(layer_id_str)
                except Exception as e:
                    logger.error(f"Error creating layer object: {e}")

        # Helper function to safely get a UUID
        def safe_uuid(value):
            if isinstance(value, UUID):
                return value
            try:
                return UUID(str(value))
            except Exception as e:
                logger.error(f"Error converting to UUID: {value} - {e}")
                return None

        # Find all nodes for this graph
        logger.info("Finding nodes in the layers")
        for node_id, node_data in nx_graph.nodes(data=True):
            # First check if this is a node by its metadata
            if node_data.get("metadata", {}).get("type") == "GraphNode":
                # Get the layer ID from the node data
                if "layer_id" in node_data:
                    layer_id_str = node_data["layer_id"]
                    # Check if this layer ID is in our found layers
                    try:
                        layer_id = safe_uuid(layer_id_str)
                        if layer_id and layer_id in graph.layers:
                            logger.info(f"Found node with ID {node_id} in layer {layer_id_str}")

                            # Create the node
                            node = GraphNode(
                                id=safe_uuid(node_id),
                                name=node_data.get("name", ""),
                                node_type=node_data.get("node_type", ""),
                                description=node_data.get("description", ""),
                                properties=node_data.get("properties", {}),
                                layer_id=layer_id,
                                metadata=node_data.get("metadata", {}),
                            )
                            graph.nodes[node.id] = node
                            graph.node_layer_map[node.id] = layer_id
                    except Exception as e:
                        logger.error(f"Error processing node {node_id}: {e}")

        # Find all edges for this graph
        logger.info("Finding edges in the layers")
        for node_id, node_data in nx_graph.nodes(data=True):
            # First check if this is an edge by its metadata
            if node_data.get("metadata", {}).get("type") == "GraphEdge":
                # Get the layer ID from the edge data
                if "layer_id" in node_data:
                    layer_id_str = node_data["layer_id"]
                    # Check if this layer ID is in our found layers
                    try:
                        layer_id = safe_uuid(layer_id_str)
                        if layer_id and layer_id in graph.layers:
                            source_id = safe_uuid(node_data.get("source_node_id"))
                            target_id = safe_uuid(node_data.get("target_node_id"))

                            if source_id and target_id:
                                logger.info(f"Found edge with ID {node_id} in layer {layer_id_str}")

                                # Create the edge
                                edge = GraphEdge(
                                    id=safe_uuid(node_id),
                                    source_node_id=source_id,
                                    target_node_id=target_id,
                                    relationship_name=node_data.get("relationship_name", ""),
                                    properties=node_data.get("properties", {}),
                                    layer_id=layer_id,
                                    metadata=node_data.get("metadata", {}),
                                )
                                graph.edges[edge.id] = edge
                                graph.edge_layer_map[edge.id] = layer_id
                    except Exception as e:
                        logger.error(f"Error processing edge {node_id}: {e}")

        logger.info(
            f"Manually retrieved graph with {len(graph.layers)} layers, {len(graph.nodes)} nodes, and {len(graph.edges)} edges"
        )
        return graph
    else:
        logger.error(f"Could not find graph with ID {graph_id}")
        return None


def get_nodes_and_edges_from_graph(graph, layer_id):
    """
    Get nodes and edges from a layer in the graph, avoiding database calls.

    Args:
        graph: The LayeredKnowledgeGraphDP instance
        layer_id: The UUID of the layer

    Returns:
        Tuple of (nodes, edges) lists
    """
    nodes = [node for node in graph.nodes.values() if node.layer_id == layer_id]
    edges = [edge for edge in graph.edges.values() if edge.layer_id == layer_id]
    return nodes, edges


async def main():
    logger.info("Starting layered graph database example")

    # Get the default graph engine (typically NetworkXAdapter)
    graph_db = await get_graph_engine()
    logger.info(f"Using graph database adapter: {type(graph_db).__name__}")

    # Create an adapter using the graph engine
    adapter = LayeredGraphDBAdapter(graph_db)

    # Create a new empty graph
    graph = LayeredKnowledgeGraphDP.create_empty(
        name="Example Database Graph",
        description="A graph that persists to the database",
        metadata={
            "type": "LayeredKnowledgeGraph",
            "index_fields": ["name"],
        },  # Ensure proper metadata
    )
    logger.info(f"Created graph with ID: {graph.id}")

    # Set the adapter for this graph
    graph.set_adapter(adapter)

    # Create and add a base layer
    base_layer = GraphLayer.create(
        name="Base Layer", description="The foundation layer of the graph", layer_type="base"
    )
    graph.add_layer(base_layer)
    logger.info(f"Added base layer with ID: {base_layer.id}")

    # Create and add a derived layer that extends the base layer
    derived_layer = GraphLayer.create(
        name="Derived Layer",
        description="A layer that extends the base layer",
        layer_type="derived",
        parent_layers=[base_layer.id],
    )
    graph.add_layer(derived_layer)
    logger.info(f"Added derived layer with ID: {derived_layer.id}")

    # Create and add nodes to the base layer
    node1 = GraphNode.create(
        name="Concept A", node_type="concept", description="A foundational concept"
    )
    graph.add_node(node1, base_layer.id)
    logger.info(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}")

    node2 = GraphNode.create(
        name="Concept B", node_type="concept", description="Another foundational concept"
    )
    graph.add_node(node2, base_layer.id)
    logger.info(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}")

    # Create and add a node to the derived layer
    node3 = GraphNode.create(
        name="Derived Concept",
        node_type="concept",
        description="A concept derived from foundational concepts",
    )
    graph.add_node(node3, derived_layer.id)
    logger.info(f"Added node3 with ID: {node3.id} to layer: {derived_layer.id}")

    # Create and add edges between nodes
    edge1 = GraphEdge.create(
        source_node_id=node1.id, target_node_id=node2.id, relationship_name="RELATES_TO"
    )
    graph.add_edge(edge1, base_layer.id)
    logger.info(f"Added edge1 with ID: {edge1.id} between {node1.id} and {node2.id}")

    edge2 = GraphEdge.create(
        source_node_id=node1.id, target_node_id=node3.id, relationship_name="EXPANDS_TO"
    )
    graph.add_edge(edge2, derived_layer.id)
    logger.info(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}")

    edge3 = GraphEdge.create(
        source_node_id=node2.id, target_node_id=node3.id, relationship_name="CONTRIBUTES_TO"
    )
    graph.add_edge(edge3, derived_layer.id)
    logger.info(f"Added edge3 with ID: {edge3.id} between {node2.id} and {node3.id}")

    # Save the graph state to a file for NetworkXAdapter
    if hasattr(graph_db, "save_graph_to_file"):
        logger.info(f"Saving graph to file: {getattr(graph_db, 'filename', 'unknown')}")
        await graph_db.save_graph_to_file()

    # Persist the entire graph to the database.
    # This is optional since the graph is already being persisted incrementally
    # when add_layer, add_node, and add_edge are called.
    logger.info("Persisting entire graph to database")
    graph_id = await graph.persist()
    logger.info(f"Graph persisted with ID: {graph_id}")

    # Check if the graph exists in the database
    if hasattr(graph_db, "graph"):
        logger.info(f"Checking if graph exists in memory: {graph_db.graph.has_node(str(graph.id))}")

        # Check the node data
        if graph_db.graph.has_node(str(graph.id)):
            node_data = graph_db.graph.nodes[str(graph.id)]
            logger.info(f"Graph node data: {node_data}")

        # List all nodes in the graph
        logger.info("Nodes in the graph:")
        for node_id, node_data in graph_db.graph.nodes(data=True):
            logger.info(
                f"  Node {node_id}: type={node_data.get('metadata', {}).get('type', 'unknown')}"
            )

    # Try to retrieve the graph using our from_database method
    retrieved_graph = None
    try:
        logger.info(f"Retrieving graph from database with ID: {graph.id}")
        retrieved_graph = await LayeredKnowledgeGraphDP.from_database(graph.id, adapter)
        logger.info(f"Retrieved graph: {retrieved_graph}")
        logger.info(
            f"Retrieved {len(retrieved_graph.layers)} layers, {len(retrieved_graph.nodes)} nodes, and {len(retrieved_graph.edges)} edges"
        )
    except ValueError as e:
        logger.error(f"Error retrieving graph using from_database: {str(e)}")

        # Try manual retrieval as a fallback
        logger.info("Trying manual retrieval as a fallback")
        retrieved_graph = await retrieve_graph_manually(graph.id, adapter)

        if retrieved_graph:
            logger.info(f"Successfully retrieved graph manually: {retrieved_graph}")
        else:
            logger.error("Failed to retrieve graph manually")
            return

    # Use our helper function to get nodes and edges
    if retrieved_graph:
        # Get nodes in the base layer
        base_nodes, base_edges = get_nodes_and_edges_from_graph(retrieved_graph, base_layer.id)
        logger.info(f"Nodes in base layer: {[node.name for node in base_nodes]}")
        logger.info(f"Edges in base layer: {[edge.relationship_name for edge in base_edges]}")

        # Get nodes in the derived layer
        derived_nodes, derived_edges = get_nodes_and_edges_from_graph(
            retrieved_graph, derived_layer.id
        )
        logger.info(f"Nodes in derived layer: {[node.name for node in derived_nodes]}")
        logger.info(f"Edges in derived layer: {[edge.relationship_name for edge in derived_edges]}")
    else:
        logger.error("No graph was retrieved, cannot display nodes and edges")


if __name__ == "__main__":
    asyncio.run(main())
cognee/examples/node_set_test.py (new file, 146 additions)

@@ -0,0 +1,146 @@
import os
import uuid
import logging
import json
from typing import List

import httpx
import asyncio
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.logging_utils import get_logger, ERROR

# Replaced incorrect import with the proper cognee.prune module
# from cognee.infrastructure.data_storage import reset_cognee

# Removed incorrect imports in favor of the high-level API
# from cognee.services.cognify import cognify
# from cognee.services.search import search
from cognee.tasks.node_set.apply_node_set import apply_node_set
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data

# Configure logging to see detailed output
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up an httpx client for debugging
httpx.Client(timeout=None)


async def generate_test_node_ids(count: int) -> List[str]:
    """Generate unique node IDs for testing."""
    return [str(uuid.uuid4()) for _ in range(count)]


async def add_text_with_nodeset(text: str, node_set: List[str], text_id: str = None) -> str:
    """Add text with a NodeSet to cognee."""
    # Generate a text ID if not provided - note that we can't directly set the ID
    if text_id is None:
        text_id = str(uuid.uuid4())

    # Print NodeSet details
    logger.info("Adding text with NodeSet")
    logger.info(f"NodeSet for this text: {node_set}")

    # Use the high-level cognee.add() with the correct parameters
    await cognee.add(text, node_set=node_set)
    logger.info("Saved text with NodeSet to database")

    return text_id  # Note: we can't control the actual ID generated by cognee.add()


async def check_data_records():
    """Check for data records in the database to verify NodeSet storage."""
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        from sqlalchemy import select

        query = select(Data)
        result = await session.execute(query)
        records = result.scalars().all()

        logger.info(f"Found {len(records)} records in the database")
        for record in records:
            logger.info(f"Record ID: {record.id}, name: {record.name}, node_set: {record.node_set}")
            # Print raw_data_location to see where the text content is stored
            logger.info(f"Raw data location: {record.raw_data_location}")


async def run_simple_node_set_test():
    """Run a simple test of the NodeSet functionality."""
    try:
        # Reset cognee data using the correct module
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        logger.info("Cognee data reset complete")

        # Generate test node IDs for two NodeSets
        # nodeset1 = await generate_test_node_ids(3)
        nodeset2 = await generate_test_node_ids(3)
        nodeset1 = ["test", "horse", "mamamama"]

        logger.info(f"Created test node IDs for NodeSet 1: {nodeset1}")
        logger.info(f"Created test node IDs for NodeSet 2: {nodeset2}")

        # Add two texts with NodeSets
        text1 = "Natural Language Processing (NLP) is a field of artificial intelligence focused on enabling computers to understand and process human language. It involves techniques for language analysis, translation, and generation."
        text2 = "Artificial Intelligence (AI) systems are designed to perform tasks that typically require human intelligence. These tasks include speech recognition, decision-making, and language translation."

        text1_id = await add_text_with_nodeset(text1, nodeset1)
        text2_id = await add_text_with_nodeset(text2, nodeset2)

        # Verify that the NodeSets were stored correctly
        await check_data_records()

        # Run the cognify process to create a knowledge graph
        pipeline_run_id = await cognee.cognify()
        logger.info(f"Cognify process completed with pipeline run ID: {pipeline_run_id}")

        # Skip graph search because the NetworkXAdapter doesn't support Cypher
        logger.info("Skipping graph search as NetworkXAdapter doesn't support Cypher")

        # Search for insights related to the added texts
        search_query = "NLP and AI"
        logger.info(f"Searching for insights with query: '{search_query}'")
        search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=search_query)

        # Extract NodeSet information from the search results
        logger.info(f"Found {len(search_results)} search results for '{search_query}':")
        for i, result in enumerate(search_results):
            logger.info(f"Result {i+1} text: {getattr(result, 'text', 'No text')}")

            # Check for NodeSet and SetNodeId
            node_set = getattr(result, "NodeSet", None)
            set_node_id = getattr(result, "SetNodeId", None)
            logger.info(f"Result {i+1} - NodeSet: {node_set}, SetNodeId: {set_node_id}")

            # Check id and type
            logger.info(f"Result {i+1} - ID: {getattr(result, 'id', 'No ID')}, Type: {getattr(result, 'type', 'No type')}")

            # Check if this is a document chunk with an is_part_of property
            if hasattr(result, "is_part_of") and result.is_part_of:
                logger.info(f"Result {i+1} is a document chunk with parent ID: {result.is_part_of.id}")

                # Check if the parent has a NodeSet
                parent_has_nodeset = hasattr(result.is_part_of, "NodeSet") and result.is_part_of.NodeSet
                logger.info(f"  Parent has NodeSet: {parent_has_nodeset}")
                if parent_has_nodeset:
                    logger.info(f"  Parent NodeSet: {result.is_part_of.NodeSet}")

            # Print all attributes of the result to see what's available
            logger.info(f"Result {i+1} - All attributes: {dir(result)}")

    except Exception as e:
        logger.error(f"Error in simple NodeSet test: {e}")
        raise


if __name__ == "__main__":
    logger = get_logger(level=ERROR)
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(run_simple_node_set_test())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
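check_data_records above prints every Data row; to inspect only the rows that actually carry tags, the node_set column can be filtered with the same is_not(None) predicate the apply_node_set task uses. A small sketch (the helper name is illustrative):

    from sqlalchemy import select
    from cognee.infrastructure.databases.relational import get_relational_engine
    from cognee.modules.data.models import Data

    async def records_with_node_sets():
        # Return only Data rows whose node_set column is populated.
        db_engine = get_relational_engine()
        async with db_engine.get_async_session() as session:
            result = await session.execute(select(Data).where(Data.node_set.is_not(None)))
            return result.scalars().all()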
@@ -1,201 +0,0 @@
"""
Example demonstrating how to use the simplified LayeredKnowledgeGraph with database adapters.

This example shows how to:
1. Create a layered knowledge graph
2. Add nodes, edges, and layers
3. Retrieve layer data
4. Work with cumulative layers
"""

import asyncio
import logging
import uuid
from uuid import UUID
import os

from cognee.modules.graph.simplified_layered_graph import (
    LayeredKnowledgeGraph,
    GraphNode,
    GraphEdge,
    GraphLayer,
)
from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.databases.graph import get_graph_engine

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


async def main():
    print("Starting simplified layered graph example")

    # Initialize the file path for the NetworkXAdapter
    db_dir = os.path.join(os.path.expanduser("~"), "cognee/cognee/.cognee_system/databases")
    os.makedirs(db_dir, exist_ok=True)
    db_file = os.path.join(db_dir, "cognee_graph.pkl")

    # Use NetworkXAdapter for the graph database
    adapter = NetworkXAdapter(filename=db_file)

    # Initialize the adapter by creating or loading the graph
    if not os.path.exists(db_file):
        await adapter.create_empty_graph(db_file)
    await adapter.load_graph_from_file()

    print(f"Using graph database adapter: {adapter.__class__.__name__}")

    # Create an empty graph
    graph = LayeredKnowledgeGraph.create_empty("Test Graph")
    graph.set_adapter(LayeredGraphDBAdapter(adapter))
    print(f"Created graph with ID: {graph.id}")

    # Create layers
    base_layer = await graph.add_layer(
        name="Base Layer", description="The foundation layer with base concepts", layer_type="base"
    )
    print(f"Added base layer with ID: {base_layer.id}")

    intermediate_layer = await graph.add_layer(
        name="Intermediate Layer",
        description="Layer that builds on base concepts",
        layer_type="intermediate",
        parent_layers=[base_layer.id],
    )
    print(f"Added intermediate layer with ID: {intermediate_layer.id}")

    derived_layer = await graph.add_layer(
        name="Derived Layer",
        description="Final layer with derived concepts",
        layer_type="derived",
        parent_layers=[intermediate_layer.id],
    )
    print(f"Added derived layer with ID: {derived_layer.id}")

    # Add nodes to layers
    node1 = await graph.add_node(
        name="Base Concept A",
        node_type="concept",
        properties={"importance": "high"},
        metadata={"source": "example"},
        layer_id=base_layer.id,
    )
    print(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}")

    node2 = await graph.add_node(
        name="Base Concept B",
        node_type="concept",
        properties={"importance": "medium"},
        metadata={"source": "example"},
        layer_id=base_layer.id,
    )
    print(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}")

    node3 = await graph.add_node(
        name="Intermediate Concept",
        node_type="concept",
        properties={"derived_from": ["Base Concept A"]},
        metadata={"source": "example"},
        layer_id=intermediate_layer.id,
    )
    print(f"Added node3 with ID: {node3.id} to layer: {intermediate_layer.id}")

    node4 = await graph.add_node(
        name="Derived Concept",
        node_type="concept",
        properties={"derived_from": ["Intermediate Concept"]},
        metadata={"source": "example"},
        layer_id=derived_layer.id,
    )
    print(f"Added node4 with ID: {node4.id} to layer: {derived_layer.id}")

    # Add edges between nodes
    edge1 = await graph.add_edge(
        source_id=node1.id,
        target_id=node2.id,
        edge_type="RELATES_TO",
        properties={"strength": "high"},
        metadata={"source": "example"},
        layer_id=base_layer.id,
    )
    print(f"Added edge1 with ID: {edge1.id} between {node1.id} and {node2.id}")

    edge2 = await graph.add_edge(
        source_id=node1.id,
        target_id=node3.id,
        edge_type="SUPPORTS",
        properties={"confidence": 0.9},
        metadata={"source": "example"},
        layer_id=intermediate_layer.id,
    )
    print(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}")

    edge3 = await graph.add_edge(
        source_id=node3.id,
        target_id=node4.id,
        edge_type="EXTENDS",
        properties={"confidence": 0.8},
        metadata={"source": "example"},
        layer_id=derived_layer.id,
    )
    print(f"Added edge3 with ID: {edge3.id} between {node3.id} and {node4.id}")

    # Save the graph to the database.
    # The graph is automatically saved when nodes and edges are added,
    # but for NetworkXAdapter we save the file explicitly.
    if hasattr(adapter, "save_graph_to_file"):
        await adapter.save_graph_to_file(adapter.filename)
        print(f"Saving graph to file: {adapter.filename}")

    # Retrieve all layers
    layers = await graph.get_layers()
    print(f"Retrieved {len(layers)} layers")

    # Load the graph from the database
    print(f"Loading graph with ID: {graph.id} from database")
    # Create a new graph instance from the database
    loaded_graph = LayeredKnowledgeGraph(id=graph.id, name="Test Graph")
    loaded_graph.set_adapter(LayeredGraphDBAdapter(adapter))
    # Load layers, which will also load nodes and edges
    loaded_layers = await loaded_graph.get_layers()
    print(f"Successfully loaded graph: {loaded_graph} with {len(loaded_layers)} layers")

    # Display the contents of each layer
    print("\n===== Individual Layer Contents =====")
    for layer in layers:
        # Get nodes and edges in the layer
        nodes = await graph.get_nodes_in_layer(layer.id)
        edges = await graph.get_edges_in_layer(layer.id)

        # Print a summary
        print(f"Nodes in {layer.name.lower()} layer: {[node.name for node in nodes]}")
        print(f"Edges in {layer.name.lower()} layer: {[edge.edge_type for edge in edges]}")

    # Display cumulative layer views
    print("\n===== Cumulative Layer Views =====")

    # Intermediate layer - should include base layer nodes/edges
    print("\nCumulative layer graph for intermediate layer:")
    int_nodes, int_edges = await graph.get_cumulative_layer_graph(intermediate_layer.id)
    print(f"Intermediate cumulative nodes: {[node.name for node in int_nodes]}")
    print(f"Intermediate cumulative edges: {[edge.edge_type for edge in int_edges]}")

    # Derived layer - should include all nodes/edges
    print("\nCumulative layer graph for derived layer:")
    derived_nodes, derived_edges = await graph.get_cumulative_layer_graph(derived_layer.id)
    print(f"Derived cumulative nodes: {[node.name for node in derived_nodes]}")
    print(f"Derived cumulative edges: {[edge.edge_type for edge in derived_edges]}")

    # Test helper methods
    print("\n===== Helper Method Results =====")
    base_nodes = await graph.get_nodes_in_layer(base_layer.id)
    base_edges = await graph.get_edges_in_layer(base_layer.id)
    print(f"Base layer contains {len(base_nodes)} nodes and {len(base_edges)} edges")

    print("Example complete")


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,88 +1,451 @@
 import uuid
 import json
 import logging
-from sqlalchemy import select
-from typing import List, Any
+from typing import List, Dict, Any, Optional, Union, Tuple, Sequence, Protocol, Callable
+from contextlib import asynccontextmanager
+
+from cognee.shared.logging_utils import get_logger
+from sqlalchemy.future import select
+from sqlalchemy import or_

 from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
 from cognee.modules.data.models import Data
 from cognee.infrastructure.engine.models.DataPoint import DataPoint

-logger = logging.getLogger(__name__)
+# Configure logger
+logger = get_logger(name="apply_node_set")


-async def apply_node_set(data_points: List[DataPoint]) -> List[DataPoint]:
-    """
-    Apply NodeSet values from the relational store to DataPoint instances.
-
-    This task fetches the NodeSet values from the Data model in the relational database
-    and applies them to the corresponding DataPoint instances.
-
+async def apply_node_set(data: Union[DataPoint, List[DataPoint]]) -> Union[DataPoint, List[DataPoint]]:
+    """Apply NodeSet values to DataPoint objects.
+
     Args:
-        data_points: List of DataPoint instances to process
+        data: Single DataPoint or list of DataPoints to process

     Returns:
-        List of updated DataPoint instances with NodeSet values applied
+        The processed DataPoint(s) with updated NodeSet values
     """
-    logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints")
-
-    if not data_points:
+    if not data:
+        logger.warning("No data provided to apply NodeSet values")
+        return data
+
+    # Convert single DataPoint to list for uniform processing
+    data_points = data if isinstance(data, list) else [data]
+    logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints")
+
+    # Process DataPoint objects to apply NodeSet values
+    updated_data_points = await _process_data_points(data_points)
+
+    # Create set nodes for each NodeSet
+    await _create_set_nodes(updated_data_points)
+
+    # Return data in the same format it was received
+    return data_points if isinstance(data, list) else data_points[0]
+
+
+async def _process_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
+    """Process DataPoint objects to apply NodeSet values from the database.
+
+    Args:
+        data_points: List of DataPoint objects to process
+
+    Returns:
+        The processed list of DataPoints with updated NodeSet values
+    """
+    try:
+        if not data_points:
+            return []
+
+        # Extract IDs and collect document relationships
+        data_point_ids, parent_doc_map = _collect_ids_and_relationships(data_points)
+
+        # Get NodeSet values from database
+        nodeset_map = await _fetch_nodesets_from_database(data_point_ids, parent_doc_map)
+
+        # Apply NodeSet values to DataPoints
+        _apply_nodesets_to_datapoints(data_points, nodeset_map, parent_doc_map)
+
+        return data_points
+
+    except Exception as e:
+        logger.error(f"Error processing DataPoints: {str(e)}")
+        return data_points

-    # Create a map of data_point IDs for efficient lookup
-    data_point_map = {str(dp.id): dp for dp in data_points}
-
-    # Get the database engine
-    db_engine = get_relational_engine()
-
-    # Get session (handles both sync and async cases for testing)
-    session = db_engine.get_async_session()
-
-    try:
-        # Handle both AsyncMock and actual async context manager for testing
-        if hasattr(session, "__aenter__"):
-            # It's a real async context manager
-            async with session as sess:
-                await _process_data_points(sess, data_point_map)
-        else:
-            # It's an AsyncMock in tests
-            await _process_data_points(session, data_point_map)
-
-    except Exception as e:
-        logger.error(f"Error applying NodeSet values: {e}")
-
-    return data_points
-
-
-async def _process_data_points(session, data_point_map):
-    """
-    Process data points with the given session.
-
-    This helper function handles the actual database query and NodeSet application.
-
+
+def _collect_ids_and_relationships(data_points: List[DataPoint]) -> Tuple[List[str], Dict[str, str]]:
+    """Extract DataPoint IDs and document relationships.
+
     Args:
-        session: Database session
-        data_point_map: Map of data point IDs to DataPoint objects
+        data_points: List of DataPoint objects

     Returns:
+        Tuple containing:
+        - List of DataPoint IDs
+        - Dictionary mapping DataPoint IDs to parent document IDs
     """
-    # Get all data points from the Data table that have node_set values
-    # and correspond to the data_points we're processing
-    data_ids = list(data_point_map.keys())
+    data_point_ids = []
+    parent_doc_ids = []
+    parent_doc_map = {}
+
+    # Collect all IDs to look up
+    for dp in data_points:
+        # Get the DataPoint ID
+        if hasattr(dp, "id"):
+            dp_id = str(dp.id)
+            data_point_ids.append(dp_id)
+
+            # Check if there's a parent document to get the NodeSet from
+            if (hasattr(dp, "made_from") and
+                    hasattr(dp.made_from, "is_part_of") and
+                    hasattr(dp.made_from.is_part_of, "id")):
+
+                parent_id = str(dp.made_from.is_part_of.id)
+                parent_doc_ids.append(parent_id)
+                parent_doc_map[dp_id] = parent_id
+
+    logger.info(f"Found {len(data_point_ids)} DataPoint IDs and {len(parent_doc_ids)} parent document IDs")
+
+    # Combine all IDs for database lookup
+    return data_point_ids + parent_doc_ids, parent_doc_map

-    query = select(Data).where(Data.id.in_(data_ids), Data.node_set.is_not(None))
-
-    result = await session.execute(query)
-    data_records = result.scalars().all()
+
+async def _fetch_nodesets_from_database(ids: List[str], parent_doc_map: Dict[str, str]) -> Dict[str, Any]:
+    """Fetch NodeSet values from the database for the given IDs.
+
+    Args:
+        ids: List of IDs to search for
+        parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
+
+    Returns:
+        Dictionary mapping document IDs to their NodeSet values
+    """
+    # Convert string IDs to UUIDs for database lookup
+    uuid_objects = _convert_ids_to_uuids(ids)
+    if not uuid_objects:
+        return {}
+
+    # Query the database for NodeSet values
+    nodeset_map = {}
+
+    db_engine = get_relational_engine()
+    async with db_engine.get_async_session() as sess:
+        # Query for records matching the IDs
+        query = select(Data).where(Data.id.in_(uuid_objects))
+        result = await sess.execute(query)
+        records = result.scalars().all()
+
+        logger.info(f"Found {len(records)} total records matching the IDs")
+
+        # Extract NodeSet values from records
+        for record in records:
+            if record.node_set is not None:
+                nodeset_map[str(record.id)] = record.node_set
+
+        logger.info(f"Found {len(nodeset_map)} records with non-NULL NodeSet values")
+
+    return nodeset_map

-    # Apply NodeSet values to corresponding DataPoint instances
-    for data_record in data_records:
-        data_point_id = str(data_record.id)
-        if data_point_id in data_point_map and data_record.node_set:
-            # Parse the JSON string to get the NodeSet
-            try:
-                node_set = json.loads(data_record.node_set)
-                data_point_map[data_point_id].NodeSet = node_set
-                logger.debug(f"Applied NodeSet {node_set} to DataPoint {data_point_id}")
-            except json.JSONDecodeError:
-                logger.warning(f"Failed to parse NodeSet JSON for DataPoint {data_point_id}")

+def _convert_ids_to_uuids(ids: List[str]) -> List[uuid.UUID]:
+    """Convert string IDs to UUID objects.
+
+    Args:
+        ids: List of string IDs to convert
+
+    Returns:
+        List of UUID objects
+    """
+    uuid_objects = []
+    for id_str in ids:
+        try:
+            uuid_objects.append(uuid.UUID(id_str))
+        except Exception as e:
+            logger.warning(f"Failed to convert ID {id_str} to UUID: {str(e)}")
+
+    logger.info(f"Converted {len(uuid_objects)} out of {len(ids)} IDs to UUID objects")
+    return uuid_objects
+
+
+def _apply_nodesets_to_datapoints(
+    data_points: List[DataPoint],
+    nodeset_map: Dict[str, Any],
+    parent_doc_map: Dict[str, str]
+) -> None:
+    """Apply NodeSet values to DataPoints.
+
+    Args:
+        data_points: List of DataPoint objects to update
+        nodeset_map: Dictionary mapping document IDs to their NodeSet values
+        parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
+    """
+    for dp in data_points:
+        dp_id = str(dp.id)
+
+        # Try a direct match first
+        if dp_id in nodeset_map:
+            nodeset = nodeset_map[dp_id]
+            logger.info(f"Found NodeSet for {dp_id}: {nodeset}")
+            dp.NodeSet = nodeset
+
+        # Then try the parent document
+        elif dp_id in parent_doc_map and parent_doc_map[dp_id] in nodeset_map:
+            parent_id = parent_doc_map[dp_id]
+            nodeset = nodeset_map[parent_id]
+            logger.info(f"Found NodeSet from parent document {parent_id} for {dp_id}: {nodeset}")
+            dp.NodeSet = nodeset
+
+
+async def _create_set_nodes(data_points: List[DataPoint]) -> None:
+    """Create set nodes for DataPoints with NodeSets.
+
+    Args:
+        data_points: List of DataPoint objects to process
+    """
+    try:
+        logger.info(f"Creating set nodes for {len(data_points)} DataPoints")
+
+        for dp in data_points:
+            if not hasattr(dp, "NodeSet") or not dp.NodeSet:
+                continue
+
+            try:
+                # Create set nodes for the NodeSet (one per value)
+                document_id = str(dp.id) if hasattr(dp, "id") else None
+                set_node_ids, edge_ids = await create_set_node(dp.NodeSet, document_id=document_id)
+
+                if set_node_ids and len(set_node_ids) > 0:
+                    logger.info(f"Created {len(set_node_ids)} set nodes for NodeSet with {len(dp.NodeSet)} values")
+
+                    # Store the set node IDs with the DataPoint if possible
+                    try:
+                        # Store as a JSON string if multiple IDs, or a single ID if only one
+                        if len(set_node_ids) > 1:
+                            dp.SetNodeIds = json.dumps(set_node_ids)
+                        else:
+                            dp.SetNodeId = set_node_ids[0]
+                    except Exception as e:
+                        logger.warning(f"Failed to store set node IDs for NodeSet: {str(e)}")
+                else:
+                    logger.warning("Failed to create set nodes for NodeSet")
+            except Exception as e:
+                logger.error(f"Error creating set nodes: {str(e)}")
+    except Exception as e:
+        logger.error(f"Error processing NodeSets: {str(e)}")

-    logger.info(f"Successfully applied NodeSet values to DataPoints")

+async def create_set_node(
+    nodes: Union[List[str], str],
+    name: Optional[str] = None,
+    document_id: Optional[str] = None
+) -> Tuple[Optional[List[str]], List[str]]:
+    """Create individual nodes for each value in the NodeSet.
+
+    Args:
+        nodes: List of node values or a JSON string representing node values
+        name: Base name for the NodeSet (optional)
+        document_id: ID of the document containing the NodeSet (optional)
+
+    Returns:
+        Tuple containing:
+        - List of created set node IDs (or None if creation failed)
+        - List of created edge IDs
+    """
+    try:
+        if not nodes:
+            logger.warning("No nodes provided to create set nodes")
+            return None, []
+
+        # Get the graph engine
+        graph_engine = await get_graph_engine()
+
+        # Parse nodes if provided as a JSON string
+        nodes_list = _parse_nodes_input(nodes)
+        if not nodes_list:
+            return None, []
+
+        # Base name for the set if not provided
+        base_name = name or f"NodeSet_{uuid.uuid4().hex[:8]}"
+        logger.info(f"Creating individual set nodes for {len(nodes_list)} values with base name '{base_name}'")
+
+        # Create set nodes using the appropriate strategy
+        return await _create_set_nodes_unified(graph_engine, nodes_list, base_name, document_id)
+
+    except Exception as e:
+        logger.error(f"Failed to create set nodes: {str(e)}")
+        return None, []
+
+
+def _parse_nodes_input(nodes: Union[List[str], str]) -> Optional[List[str]]:
+    """Parse the nodes input.
+
+    Args:
+        nodes: List of node values or a JSON string representing node values
+
+    Returns:
+        List of node values, or None if parsing failed
+    """
+    if isinstance(nodes, str):
+        try:
+            parsed_nodes = json.loads(nodes)
+            logger.info(f"Parsed nodes string into list with {len(parsed_nodes)} items")
+            return parsed_nodes
+        except Exception as e:
+            logger.error(f"Failed to parse nodes as JSON: {str(e)}")
+            return None
+    return nodes
+
+
+async def _create_set_nodes_unified(
+    graph_engine: Any,
+    nodes_list: List[str],
+    base_name: str,
+    document_id: Optional[str]
+) -> Tuple[List[str], List[str]]:
+    """Create set nodes using either the NetworkX or a generic graph engine.
+
+    Args:
+        graph_engine: The graph engine instance
+        nodes_list: List of node values
+        base_name: Base name for the NodeSet
+        document_id: ID of the document containing the NodeSet (optional)
+
+    Returns:
+        Tuple containing:
+        - List of created set node IDs
+        - List of created edge IDs
+    """
+    all_set_node_ids = []
+    all_edge_ids = []
+
+    # Define strategies for node and edge creation based on the graph engine type
+    if isinstance(graph_engine, NetworkXAdapter):
+        # NetworkX-specific strategy
+        async def create_node(node_value: str) -> str:
+            set_node_id = str(uuid.uuid4())
+            node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
+
+            graph_engine.graph.add_node(
+                set_node_id,
+                id=set_node_id,
+                type="NodeSet",
+                name=node_name,
+                node_id=node_value
+            )
+
+            # Validate node creation
+            if set_node_id in graph_engine.graph.nodes():
+                node_props = dict(graph_engine.graph.nodes[set_node_id])
+                logger.info(f"Created set node for value '{node_value}': {json.dumps(node_props)}")
+            else:
+                logger.warning(f"Node {set_node_id} not found in graph after adding")
+
+            return set_node_id
+
+        async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
+            edge_ids = []
+            try:
+                edge_id = str(uuid.uuid4())
+                graph_engine.graph.add_edge(
+                    set_node_id,
+                    node_value,
+                    id=edge_id,
+                    type="CONTAINS"
+                )
+                edge_ids.append(edge_id)
+            except Exception as e:
+                logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
+            return edge_ids
+
+        async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
+            edge_ids = []
+            try:
+                doc_to_nodeset_id = str(uuid.uuid4())
+                graph_engine.graph.add_edge(
+                    document_id,
+                    set_node_id,
+                    id=doc_to_nodeset_id,
+                    type="HAS_NODESET"
+                )
+                edge_ids.append(doc_to_nodeset_id)
+                logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
+            except Exception as e:
+                logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
+            return edge_ids
+
+        # Finalize function for NetworkX
+        async def finalize() -> None:
+            await graph_engine.save_graph_to_file(graph_engine.filename)
+
+    else:
+        # Generic graph engine strategy
+        async def create_node(node_value: str) -> str:
+            node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
+            set_node_props = {
+                "name": node_name,
+                "type": "NodeSet",
+                "node_id": node_value
+            }
+            return await graph_engine.create_node(set_node_props)
+
+        async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
+            edge_ids = []
+            try:
+                edge_id = await graph_engine.create_edge(
+                    source_id=set_node_id,
+                    target_id=node_value,
+                    edge_type="CONTAINS"
+                )
+                edge_ids.append(edge_id)
+            except Exception as e:
+                logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
+            return edge_ids
+
+        async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
+            edge_ids = []
+            try:
+                doc_to_nodeset_id = await graph_engine.create_edge(
+                    source_id=document_id,
+                    target_id=set_node_id,
+                    edge_type="HAS_NODESET"
+                )
+                edge_ids.append(doc_to_nodeset_id)
+                logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
+            except Exception as e:
+                logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
+            return edge_ids
+
+        # Finalize function for the generic engine (no-op)
+        async def finalize() -> None:
+            pass
+
+    # Unified process for both engine types
+    for node_value in nodes_list:
+        try:
+            # Create the node
+            set_node_id = await create_node(node_value)
+
+            # Create edges to the value
+            value_edge_ids = await create_value_edge(set_node_id, node_value)
+            all_edge_ids.extend(value_edge_ids)
+
+            # Create edges to the document if provided
+            if document_id:
+                doc_edge_ids = await create_document_edge(document_id, set_node_id)
+                all_edge_ids.extend(doc_edge_ids)
+
+            all_set_node_ids.append(set_node_id)
+        except Exception as e:
+            logger.error(f"Failed to create set node for value '{node_value}': {str(e)}")
+
+    # Finalize the process
+    await finalize()
+
+    # Return results
+    if all_set_node_ids:
+        logger.info(f"Created {len(all_set_node_ids)} individual set nodes with values: {nodes_list}")
+        return all_set_node_ids, all_edge_ids
+    else:
+        logger.error("Failed to create any set nodes")
+        return [], []
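create_set_node can also be called on its own to materialize set nodes for a list of tag values outside the pipeline. A hedged sketch (the tag values are illustrative; return values are as documented above):

    import asyncio
    from cognee.tasks.node_set.apply_node_set import create_set_node

    async def demo():
        # Creates one set node per value, plus CONTAINS (and, when a
        # document_id is given, HAS_NODESET) edges.
        set_node_ids, edge_ids = await create_set_node(["nlp", "demo"])
        print(set_node_ids, edge_ids)

    asyncio.run(demo())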
enhanced_nodeset_visualization.py (new file, 675 additions)

@@ -0,0 +1,675 @@
import asyncio
import os
import json
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional, Set
import webbrowser
from pathlib import Path

import cognee
from cognee.shared.logging_utils import get_logger

# Configure logger
logger = get_logger(name="enhanced_graph_visualization")

# Type aliases for clarity
NodeData = Dict[str, Any]
EdgeData = Dict[str, Any]
GraphData = Tuple[List[Tuple[Any, Dict]], List[Tuple[Any, Any, Optional[str], Dict]]]


class DateTimeEncoder(json.JSONEncoder):
    """Custom JSON encoder to handle datetime objects."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)

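# For example, json.dumps({"ts": datetime.now()}, cls=DateTimeEncoder) emits the
# timestamp as an ISO-8601 string instead of raising a TypeError.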

class NodeSetVisualizer:
    """Class to create enhanced visualizations for NodeSet data in knowledge graphs."""

    # Color mapping for different node types
    NODE_COLORS = {
        "Entity": "#f47710",
        "EntityType": "#6510f4",
        "DocumentChunk": "#801212",
        "TextDocument": "#a83232",  # Darker red for documents
        "TextSummary": "#1077f4",
        "NodeSet": "#ff00ff",  # Bright magenta for NodeSet nodes
        "Unknown": "#999999",
        "default": "#D3D3D3",
    }

    # Size mapping for different node types
    NODE_SIZES = {
        "NodeSet": 20,  # Larger size for NodeSet nodes
        "TextDocument": 18,  # Larger size for document nodes
        "DocumentChunk": 18,  # Larger size for document nodes
        "TextSummary": 16,  # Medium size for TextSummary nodes
        "default": 13,  # Default size
    }

    def __init__(self):
        """Initialize the visualizer."""
        self.graph_engine = None
        self.nodes_data = []
        self.edges_data = []
        self.node_count = 0
        self.edge_count = 0
        self.nodeset_count = 0

    async def get_graph_data(self) -> bool:
        """Fetch graph data from the graph engine.

        Returns:
            bool: True if data was successfully retrieved, False otherwise.
        """
        self.graph_engine = await cognee.infrastructure.databases.graph.get_graph_engine()

        # Check if the graph exists and has nodes
        self.node_count = len(self.graph_engine.graph.nodes())
        self.edge_count = len(self.graph_engine.graph.edges())
        logger.info(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
        print(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")

        if self.node_count == 0:
            logger.error("The graph is empty! Please run a test script first to generate data.")
            print("ERROR: The graph is empty! Please run a test script first to generate data.")
            return False

        graph_data = await self.graph_engine.get_graph_data()
        self.nodes_data, self.edges_data = graph_data

        # Count NodeSets for the status display
        self.nodeset_count = sum(1 for _, info in self.nodes_data if info.get("type") == "NodeSet")

        return True

    def prepare_node_data(self) -> List[NodeData]:
        """Process raw node data to prepare it for visualization.

        Returns:
            List[NodeData]: List of prepared node data objects.
        """
        nodes_list = []

        # Create a lookup for node types for faster access
        node_type_lookup = {str(node_id): node_info.get("type", "Unknown")
                            for node_id, node_info in self.nodes_data}

        for node_id, node_info in self.nodes_data:
            # Create a clean copy to avoid modifying the original
            processed_node = node_info.copy()

            # Remove fields that cause JSON serialization issues
            self._clean_node_data(processed_node)

            # Add the required visualization properties
            processed_node["id"] = str(node_id)
            node_type = processed_node.get("type", "default")

            # Apply visual styling based on the node type
            processed_node["color"] = self.NODE_COLORS.get(node_type, self.NODE_COLORS["default"])
            processed_node["size"] = self.NODE_SIZES.get(node_type, self.NODE_SIZES["default"])

            # Create display names
            self._format_node_display_name(processed_node, node_type)

            nodes_list.append(processed_node)

        return nodes_list

    @staticmethod
    def _clean_node_data(node: NodeData) -> None:
        """Remove fields that might cause JSON serialization issues.

        Args:
            node: The node data to clean
        """
        # Remove non-essential fields that might cause serialization issues
        for key in ["created_at", "updated_at", "raw_data_location"]:
            if key in node:
                del node[key]

    @staticmethod
    def _format_node_display_name(node: NodeData, node_type: str) -> None:
        """Format the display name for a node.

        Args:
            node: The node data to process
            node_type: The type of the node
        """
        # Set a default name if none exists
        node["name"] = node.get("name", node.get("id", "Unknown"))

        # Special formatting for NodeSet nodes
        if node_type == "NodeSet" and "node_id" in node:
            node["display_name"] = f"NodeSet: {node['node_id']}"
        else:
            node["display_name"] = node["name"]

        # Truncate long display names
        if len(node["display_name"]) > 30:
            node["display_name"] = f"{node['display_name'][:27]}..."

    def prepare_edge_data(self, nodes_list: List[NodeData]) -> List[EdgeData]:
        """Process raw edge data to prepare it for visualization.

        Args:
            nodes_list: The processed node data

        Returns:
            List[EdgeData]: List of prepared edge data objects.
        """
        links_list = []

        # Create a lookup for node types for faster access
        node_type_lookup = {node["id"]: node.get("type", "Unknown") for node in nodes_list}

        for source, target, relation, edge_info in self.edges_data:
            source_str = str(source)
            target_str = str(target)

            # Skip if source or target is not in node_type_lookup (should not happen)
            if source_str not in node_type_lookup or target_str not in node_type_lookup:
                continue

            # Get the node types
            source_type = node_type_lookup[source_str]
            target_type = node_type_lookup[target_str]

            # Create the edge data
            link_data = {
                "source": source_str,
                "target": target_str,
                "relation": relation or "UNKNOWN"
            }

            # Categorize the edge for styling
            link_data["connection_type"] = self._categorize_edge(source_type, target_type)

            links_list.append(link_data)

        return links_list

    @staticmethod
    def _categorize_edge(source_type: str, target_type: str) -> str:
        """Categorize an edge based on the connected node types.

        Args:
            source_type: The type of the source node
            target_type: The type of the target node

        Returns:
            str: The category of the edge
        """
        if source_type == "NodeSet" and target_type != "NodeSet":
            return "nodeset_to_value"
        elif (source_type in ["TextDocument", "DocumentChunk", "TextSummary"]) and target_type == "NodeSet":
            return "document_to_nodeset"
        elif target_type == "NodeSet":
            return "to_nodeset"
        elif source_type == "NodeSet":
            return "from_nodeset"
        else:
            return "standard"
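    # Example categorizations (illustrative): a "TextDocument" -> "NodeSet" edge
    # maps to "document_to_nodeset", "NodeSet" -> "Entity" maps to
    # "nodeset_to_value", and "Entity" -> "Entity" falls through to "standard".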
|
||||
    def generate_html(self, nodes_list: List[NodeData], links_list: List[EdgeData]) -> str:
        """Generate the HTML visualization with D3.js.

        Args:
            nodes_list: The processed node data
            links_list: The processed edge data

        Returns:
            str: The HTML content for the visualization
        """
        # Use embedded template directly - more reliable than file access
        html_template = self._get_embedded_html_template()

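        # str.replace is used instead of str.format because the template is
        # full of literal CSS/JS braces ("body { ... }", "${nodes.length}")
        # that str.format would misread as replacement fields.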
        # Generate the HTML content with custom JSON encoder for datetime objects
        html_content = html_template.replace("{nodes}", json.dumps(nodes_list, cls=DateTimeEncoder))
        html_content = html_content.replace("{links}", json.dumps(links_list, cls=DateTimeEncoder))
        html_content = html_content.replace("{node_count}", str(self.node_count))
        html_content = html_content.replace("{edge_count}", str(self.edge_count))
        html_content = html_content.replace("{nodeset_count}", str(self.nodeset_count))

        return html_content

    def save_html(self, html_content: str) -> str:
        """Save the HTML content to a file and open it in the browser.

        Args:
            html_content: The HTML content to save

        Returns:
            str: The path to the saved file
        """
        # Create the output file path
        output_path = Path.cwd() / "enhanced_nodeset_visualization.html"

        # Write the HTML content to the file. Force UTF-8: the template
        # declares a utf-8 charset and contains non-ASCII characters ("→"),
        # which would fail under a platform-default encoding on some systems.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(html_content)

        logger.info(f"Enhanced visualization saved to: {output_path}")
        print(f"Enhanced visualization saved to: {output_path}")

        # Open the visualization in the default web browser
        file_url = f"file://{output_path}"
        logger.info(f"Opening visualization in browser: {file_url}")
        print(f"Opening enhanced visualization in browser: {file_url}")
        webbrowser.open(file_url)

        return str(output_path)

    @staticmethod
    def _get_embedded_html_template() -> str:
        """Return the embedded HTML template for the visualization.

        Returns:
            str: The HTML template
        """
        return """
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Cognee NodeSet Visualization</title>
    <script src="https://d3js.org/d3.v5.min.js"></script>
    <style>
        body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', Arial, sans-serif; }

        svg { width: 100vw; height: 100vh; display: block; }
        .links line { stroke-width: 2px; }
        .nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
        .node-label { font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }
        .edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }

        /* NodeSet specific styles */
        .links line.nodeset_to_value { stroke: #ff00ff; stroke-width: 3px; stroke-dasharray: 5, 5; }
        .links line.document_to_nodeset { stroke: #fc0; stroke-width: 3px; }
        .links line.to_nodeset { stroke: #0cf; stroke-width: 2px; }
        .links line.from_nodeset { stroke: #0fc; stroke-width: 2px; }

        .nodes circle.nodeset { stroke: white; stroke-width: 2px; filter: drop-shadow(0 0 10px rgba(255,0,255,0.8)); }
        .nodes circle.document { stroke: white; stroke-width: 1.5px; filter: drop-shadow(0 0 8px rgba(255,255,0,0.6)); }
        .node-label.nodeset { font-size: 6px; font-weight: bold; fill: white; }
        .node-label.document { font-size: 5.5px; font-weight: bold; fill: white; }

        /* Legend */
        .legend { position: fixed; top: 10px; left: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
        .legend-item { display: flex; align-items: center; margin-bottom: 5px; }
        .legend-color { width: 15px; height: 15px; margin-right: 10px; border-radius: 50%; }
        .legend-label { font-size: 14px; }

        /* Edge legend */
        .edge-legend { position: fixed; top: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
        .edge-legend-item { display: flex; align-items: center; margin-bottom: 10px; }
        .edge-line { width: 30px; height: 3px; margin-right: 10px; }
        /* Scoped to the legend so it does not override the 3px SVG edge labels above */
        .edge-legend .edge-label { font-size: 14px; }

        /* Controls */
        .controls { position: fixed; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; display: flex; gap: 10px; }
        button { background: #333; color: white; border: none; padding: 5px 10px; border-radius: 4px; cursor: pointer; }
        button:hover { background: #555; }

        /* Status message */
        .status { position: fixed; bottom: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
    </style>
</head>
<body>
    <svg></svg>

    <!-- Node Legend -->
    <div class="legend">
        <h3 style="margin-top: 0;">Node Types</h3>
        <div class="legend-item">
            <div class="legend-color" style="background-color: #ff00ff;"></div>
            <div class="legend-label">NodeSet</div>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background-color: #a83232;"></div>
            <div class="legend-label">TextDocument</div>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background-color: #801212;"></div>
            <div class="legend-label">DocumentChunk</div>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background-color: #1077f4;"></div>
            <div class="legend-label">TextSummary</div>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background-color: #f47710;"></div>
            <div class="legend-label">Entity</div>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background-color: #6510f4;"></div>
            <div class="legend-label">EntityType</div>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background-color: #999999;"></div>
            <div class="legend-label">Unknown</div>
        </div>
    </div>

    <!-- Edge Legend -->
    <div class="edge-legend">
        <h3 style="margin-top: 0;">Edge Types</h3>
        <div class="edge-legend-item">
            <div class="edge-line" style="background-color: #fc0;"></div>
            <div class="edge-label">Document → NodeSet</div>
        </div>
        <div class="edge-legend-item">
            <div class="edge-line" style="background-color: #ff00ff; height: 3px; background: linear-gradient(to right, #ff00ff 50%, transparent 50%); background-size: 10px 3px; background-repeat: repeat-x;"></div>
            <div class="edge-label">NodeSet → Value</div>
        </div>
        <div class="edge-legend-item">
            <div class="edge-line" style="background-color: #0cf;"></div>
            <div class="edge-label">Any → NodeSet</div>
        </div>
        <div class="edge-legend-item">
            <div class="edge-line" style="background-color: rgba(255, 255, 255, 0.4);"></div>
            <div class="edge-label">Standard Connection</div>
        </div>
    </div>

    <!-- Controls -->
    <div class="controls">
        <button id="center-btn">Center Graph</button>
        <button id="highlight-nodesets">Highlight NodeSets</button>
        <button id="highlight-documents">Highlight Documents</button>
        <button id="reset-highlight">Reset Highlight</button>
    </div>

    <!-- Status -->
    <div class="status">
        <div>Nodes: {node_count}</div>
        <div>Edges: {edge_count}</div>
        <div>NodeSets: {nodeset_count}</div>
    </div>

    <script>
    var nodes = {nodes};
    var links = {links};

    var svg = d3.select("svg"),
        width = window.innerWidth,
        height = window.innerHeight;

    var container = svg.append("g");

    // Count NodeSets for status display
    const nodesetCount = nodes.filter(n => n.type === "NodeSet").length;
    document.querySelector('.status').innerHTML = `
        <div>Nodes: ${nodes.length}</div>
        <div>Edges: ${links.length}</div>
        <div>NodeSets: ${nodesetCount}</div>
    `;

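    // Force layout, sketched intent: weak link forces pull connected nodes
    // together so clusters stay readable, the many-body charge pushes all
    // nodes apart, and the center/x/y forces keep the whole graph anchored
    // around the middle of the viewport.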
    var simulation = d3.forceSimulation(nodes)
        .force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
        .force("charge", d3.forceManyBody().strength(-300))
        .force("center", d3.forceCenter(width / 2, height / 2))
        .force("x", d3.forceX().strength(0.1).x(width / 2))
        .force("y", d3.forceY().strength(0.1).y(height / 2));

    var link = container.append("g")
        .attr("class", "links")
        .selectAll("line")
        .data(links)
        .enter().append("line")
        .attr("stroke-width", d => {
            if (d.connection_type === 'document_to_nodeset' || d.connection_type === 'nodeset_to_value') {
                return 3;
            }
            return 2;
        })
        .attr("stroke", d => {
            switch (d.connection_type) {
                case 'document_to_nodeset': return "#fc0";
                case 'nodeset_to_value': return "#ff00ff";
                case 'to_nodeset': return "#0cf";
                case 'from_nodeset': return "#0fc";
                default: return "rgba(255, 255, 255, 0.4)";
            }
        })
        .attr("stroke-dasharray", d => d.connection_type === 'nodeset_to_value' ? "5,5" : null)
        .attr("class", d => d.connection_type);

    var edgeLabels = container.append("g")
        .attr("class", "edge-labels")
        .selectAll("text")
        .data(links)
        .enter().append("text")
        .attr("class", "edge-label")
        .text(d => d.relation);

    var nodeGroup = container.append("g")
        .attr("class", "nodes")
        .selectAll("g")
        .data(nodes)
        .enter().append("g");

    var node = nodeGroup.append("circle")
        .attr("r", d => d.size || 13)
        .attr("fill", d => d.color)
        .attr("class", d => {
            if (d.type === "NodeSet") return "nodeset";
            if (d.type === "TextDocument" || d.type === "DocumentChunk") return "document";
            return "";
        })
        .call(d3.drag()
            .on("start", dragstarted)
            .on("drag", dragged)
            .on("end", dragended));

    nodeGroup.append("text")
        .attr("class", d => {
            if (d.type === "NodeSet") return "node-label nodeset";
            if (d.type === "TextDocument" || d.type === "DocumentChunk") return "node-label document";
            return "node-label";
        })
        .attr("dy", 4)
        .attr("font-size", d => {
            if (d.type === "NodeSet") return "6px";
            if (d.type === "TextDocument" || d.type === "DocumentChunk") return "5.5px";
            return "5px";
        })
        .attr("text-anchor", "middle")
        .text(d => d.display_name || d.name);

    node.append("title").text(d => {
        // Create a formatted tooltip with node properties
        let props = Object.entries(d)
            .filter(([key]) => !["x", "y", "vx", "vy", "index", "fx", "fy", "color", "display_name"].includes(key))
            .map(([key, value]) => `${key}: ${value}`)
            .join("\\n");
        return props;
    });

    simulation.on("tick", function() {
        link.attr("x1", d => d.source.x)
            .attr("y1", d => d.source.y)
            .attr("x2", d => d.target.x)
            .attr("y2", d => d.target.y);

        edgeLabels
            .attr("x", d => (d.source.x + d.target.x) / 2)
            .attr("y", d => (d.source.y + d.target.y) / 2 - 5);

        node.attr("cx", d => d.x)
            .attr("cy", d => d.y);

        nodeGroup.select("text")
            .attr("x", d => d.x)
            .attr("y", d => d.y)
            .attr("dy", 4)
            .attr("text-anchor", "middle");
    });

    // Add zoom behavior
    const zoom = d3.zoom()
        .scaleExtent([0.1, 8])
        .on("zoom", function() {
            container.attr("transform", d3.event.transform);
        });

    svg.call(zoom);

    // Button controls
document.getElementById("center-btn").addEventListener("click", function() {
|
||||
svg.transition().duration(750).call(
|
||||
zoom.transform,
|
||||
d3.zoomIdentity.translate(width / 2, height / 2).scale(1)
|
||||
);
|
||||
});
|
||||
|
||||
document.getElementById("highlight-nodesets").addEventListener("click", function() {
|
||||
highlightNodes("NodeSet");
|
||||
});
|
||||
|
||||
document.getElementById("highlight-documents").addEventListener("click", function() {
|
||||
highlightNodes(["TextDocument", "DocumentChunk"]);
|
||||
});
|
||||
|
||||
document.getElementById("reset-highlight").addEventListener("click", function() {
|
||||
resetHighlight();
|
||||
});
|
||||
|
||||
function highlightNodes(typeToHighlight) {
|
||||
// Dim all nodes and links
|
||||
node.transition().duration(300)
|
||||
.attr("opacity", 0.2);
|
||||
link.transition().duration(300)
|
||||
.attr("opacity", 0.2);
|
||||
nodeGroup.selectAll("text").transition().duration(300)
|
||||
.attr("opacity", 0.2);
|
||||
|
||||
// Create arrays for types if a single string is provided
|
||||
const typesToHighlight = Array.isArray(typeToHighlight) ? typeToHighlight : [typeToHighlight];
|
||||
|
||||
// Highlight matching nodes and their connected nodes
|
||||
const highlightedNodeIds = new Set();
|
||||
|
||||
// First, find all nodes of the target type
|
||||
nodes.forEach(n => {
|
||||
if (typesToHighlight.includes(n.type)) {
|
||||
highlightedNodeIds.add(n.id);
|
||||
}
|
||||
});
|
||||
|
||||
// Find all connected nodes (both directions)
|
||||
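        // Note: once d3.forceLink has initialized, each link's source/target
        // is the node object itself rather than its id string, hence the
        // `l.source.id || l.source` fallback below.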
        links.forEach(l => {
            if (highlightedNodeIds.has(l.source.id || l.source)) {
                highlightedNodeIds.add(l.target.id || l.target);
            }
            if (highlightedNodeIds.has(l.target.id || l.target)) {
                highlightedNodeIds.add(l.source.id || l.source);
            }
        });

        // Highlight the nodes
        node.filter(d => highlightedNodeIds.has(d.id))
            .transition().duration(300)
            .attr("opacity", 1);

        // Highlight the labels
        nodeGroup.selectAll("text")
            .filter(d => highlightedNodeIds.has(d.id))
            .transition().duration(300)
            .attr("opacity", 1);

        // Highlight the links between highlighted nodes
        link.filter(d => {
            const sourceId = d.source.id || d.source;
            const targetId = d.target.id || d.target;
            return highlightedNodeIds.has(sourceId) && highlightedNodeIds.has(targetId);
        })
            .transition().duration(300)
            .attr("opacity", 1);
    }

    function resetHighlight() {
        node.transition().duration(300).attr("opacity", 1);
        link.transition().duration(300).attr("opacity", 1);
        nodeGroup.selectAll("text").transition().duration(300).attr("opacity", 1);
    }

    function dragstarted(d) {
        if (!d3.event.active) simulation.alphaTarget(0.3).restart();
        d.fx = d.x;
        d.fy = d.y;
    }

    function dragged(d) {
        d.fx = d3.event.x;
        d.fy = d3.event.y;
    }

    function dragended(d) {
        if (!d3.event.active) simulation.alphaTarget(0);
        d.fx = null;
        d.fy = null;
    }

    window.addEventListener("resize", function() {
        width = window.innerWidth;
        height = window.innerHeight;
        svg.attr("width", width).attr("height", height);
        simulation.force("center", d3.forceCenter(width / 2, height / 2));
        simulation.alpha(1).restart();
    });
    </script>
</body>
</html>
"""

    async def create_visualization(self) -> Optional[str]:
        """Main method to create the visualization.

        Returns:
            Optional[str]: Path to the saved visualization file, or None if unsuccessful
        """
        print("Creating enhanced NodeSet visualization...")

        # Get graph data
        if not await self.get_graph_data():
            return None

        try:
            # Process nodes
            nodes_list = self.prepare_node_data()

            # Process edges
            links_list = self.prepare_edge_data(nodes_list)

            # Generate HTML
            html_content = self.generate_html(nodes_list, links_list)

            # Save to file and open in browser
            return self.save_html(html_content)

        except Exception as e:
            logger.error(f"Error creating visualization: {str(e)}")
            print(f"Error creating visualization: {str(e)}")
            return None


async def main():
    """Main entry point for the script."""
    visualizer = NodeSetVisualizer()
    await visualizer.create_visualization()


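# Note: on Python 3.7+ this block could be reduced to asyncio.run(main());
# the explicit event-loop management below does the same job by hand.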
if __name__ == "__main__":
|
||||
# Run the async main function
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.close()
|
||||