Removed files

vasilije 2025-04-17 17:21:09 +02:00
parent 1c40a5081a
commit 69c090c91d
9 changed files with 1 addition and 1605 deletions

View file

@@ -13,7 +13,7 @@ from cognee.modules.data.models import Data, Dataset
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
@@ -25,7 +25,6 @@ from cognee.tasks.documents import (
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.tasks.node_set import apply_node_set
from cognee.modules.chunking.TextChunker import TextChunker
logger = get_logger("cognify")
@@ -140,7 +139,6 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(apply_node_set, task_config={"batch_size": 10}), # Apply NodeSet values and create set nodes
]
return default_tasks
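
A minimal usage sketch of the task list assembled above (an editorial assumption, not part of this commit): it presumes get_default_tasks can be awaited without arguments and that run_tasks accepts the task list plus input data and yields results asynchronously, mirroring the run_tasks call in the pipeline test further down.

default_tasks = await get_default_tasks()
pipeline = run_tasks(tasks=default_tasks, data=documents)  # `documents` is a placeholder input list
async for result in pipeline:  # inside an async context
    logger.info(result)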

View file

@@ -1,161 +0,0 @@
import os
import uuid
import logging
import json
from typing import List
import httpx
import asyncio
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.logging_utils import get_logger, ERROR
# Replace incorrect import with proper cognee.prune module
# from cognee.infrastructure.data_storage import reset_cognee
# Remove incorrect imports and use high-level API
# from cognee.services.cognify import cognify
# from cognee.services.search import search
from cognee.tasks.node_set.apply_node_set import apply_node_set
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data
# Configure logging to see detailed output
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set up httpx client for debugging
httpx.Client(timeout=None)
async def generate_test_node_ids(count: int) -> List[str]:
"""Generate unique node IDs for testing."""
return [str(uuid.uuid4()) for _ in range(count)]
async def add_text_with_nodeset(text: str, node_set: List[str], text_id: str = None) -> str:
"""Add text with NodeSet to cognee."""
# Generate text ID if not provided - but note we can't directly set the ID
if text_id is None:
text_id = str(uuid.uuid4())
# Print NodeSet details
logger.info("Adding text with NodeSet")
logger.info(f"NodeSet for this text: {node_set}")
# Use high-level cognee.add() with the correct parameters
await cognee.add(text, node_set=node_set)
logger.info("Saved text with NodeSet to database")
return text_id # Note: we can't control the actual ID generated by cognee.add()
async def check_data_records():
"""Check for data records in the database to verify NodeSet storage."""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
from sqlalchemy import select
query = select(Data)
result = await session.execute(query)
records = result.scalars().all()
logger.info(f"Found {len(records)} records in the database")
for record in records:
logger.info(f"Record ID: {record.id}, name: {record.name}, node_set: {record.node_set}")
# Print raw_data_location to see where the text content is stored
logger.info(f"Raw data location: {record.raw_data_location}")
async def run_simple_node_set_test():
"""Run a simple test of NodeSet functionality."""
try:
# Reset cognee data using correct module
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
logger.info("Cognee data reset complete")
# Generate test node IDs for two NodeSets
# nodeset1 = await generate_test_node_ids(3)
nodeset2 = await generate_test_node_ids(3)
nodeset1 = [
"my_accounting_data",
"my_set_of_manuals_about_horses",
"my_elon_musk_secret_file",
]
logger.info(f"Created test node IDs for NodeSet 1: {nodeset1}")
logger.info(f"Created test node IDs for NodeSet 2: {nodeset2}")
# Add two texts with NodeSets
text1 = "Natural Language Processing (NLP) is a field of artificial intelligence focused on enabling computers to understand and process human language. It involves techniques for language analysis, translation, and generation."
text2 = "Artificial Intelligence (AI) systems are designed to perform tasks that typically require human intelligence. These tasks include speech recognition, decision-making, and language translation."
text1_id = await add_text_with_nodeset(text1, nodeset1)
text2_id = await add_text_with_nodeset(text2, nodeset2)
logger.info(str(text1_id))
logger.info(str(text2_id))
# Verify that the NodeSets were stored correctly
await check_data_records()
# Run the cognify process to create a knowledge graph
pipeline_run_id = await cognee.cognify()
logger.info(f"Cognify process completed with pipeline run ID: {pipeline_run_id}")
# Skip graph search as the NetworkXAdapter doesn't support Cypher
logger.info("Skipping graph search as NetworkXAdapter doesn't support Cypher")
# Search for insights related to the added texts
search_query = "NLP and AI"
logger.info(f"Searching for insights with query: '{search_query}'")
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=search_query
)
# Extract NodeSet information from search results
logger.info(f"Found {len(search_results)} search results for '{search_query}':")
for i, result in enumerate(search_results):
logger.info(f"Result {i + 1} text: {getattr(result, 'text', 'No text')}")
# Check for NodeSet and SetNodeId
node_set = getattr(result, "NodeSet", None)
set_node_id = getattr(result, "SetNodeId", None)
logger.info(f"Result {i + 1} - NodeSet: {node_set}, SetNodeId: {set_node_id}")
# Check id and type
logger.info(
f"Result {i + 1} - ID: {getattr(result, 'id', 'No ID')}, Type: {getattr(result, 'type', 'No type')}"
)
# Check if this is a document chunk and has is_part_of property
if hasattr(result, "is_part_of") and result.is_part_of:
logger.info(
f"Result {i + 1} is a document chunk with parent ID: {result.is_part_of.id}"
)
# Check if the parent has a NodeSet
parent_has_nodeset = (
hasattr(result.is_part_of, "NodeSet") and result.is_part_of.NodeSet
)
logger.info(f" Parent has NodeSet: {parent_has_nodeset}")
if parent_has_nodeset:
logger.info(f" Parent NodeSet: {result.is_part_of.NodeSet}")
# Print all attributes of the result to see what's available
logger.info(f"Result {i + 1} - All attributes: {dir(result)}")
except Exception as e:
logger.error(f"Error in simple NodeSet test: {e}")
raise
if __name__ == "__main__":
logger = get_logger(level=ERROR)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(run_simple_node_set_test())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
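
A shorter equivalent of the event-loop boilerplate above, assuming Python 3.7+ and no need to reuse the loop afterwards (asyncio.run also shuts down async generators before closing the loop):

if __name__ == "__main__":
    logger = get_logger(level=ERROR)
    asyncio.run(run_simple_node_set_test())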

View file

@@ -1 +0,0 @@
from .apply_node_set import apply_node_set

View file

@@ -1,451 +0,0 @@
import uuid
import json
import logging
from typing import List, Dict, Any, Optional, Union, Tuple, Sequence, Protocol, Callable
from contextlib import asynccontextmanager
from cognee.shared.logging_utils import get_logger
from sqlalchemy.future import select
from sqlalchemy import or_
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.modules.data.models import Data
from cognee.infrastructure.engine.models.DataPoint import DataPoint
# Configure logger
logger = get_logger(name="apply_node_set")
async def apply_node_set(data: Union[DataPoint, List[DataPoint]]) -> Union[DataPoint, List[DataPoint]]:
"""Apply NodeSet values to DataPoint objects.
Args:
data: Single DataPoint or list of DataPoints to process
Returns:
The processed DataPoint(s) with updated NodeSet values
"""
if not data:
logger.warning("No data provided to apply NodeSet values")
return data
# Convert single DataPoint to list for uniform processing
data_points = data if isinstance(data, list) else [data]
logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints")
# Process DataPoint objects to apply NodeSet values
updated_data_points = await _process_data_points(data_points)
# Create set nodes for each NodeSet
await _create_set_nodes(updated_data_points)
# Return data in the same format it was received
return data_points if isinstance(data, list) else data_points[0]
async def _process_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
"""Process DataPoint objects to apply NodeSet values from the database.
Args:
data_points: List of DataPoint objects to process
Returns:
The processed list of DataPoints with updated NodeSet values
"""
try:
if not data_points:
return []
# Extract IDs and collect document relationships
data_point_ids, parent_doc_map = _collect_ids_and_relationships(data_points)
# Get NodeSet values from database
nodeset_map = await _fetch_nodesets_from_database(data_point_ids, parent_doc_map)
# Apply NodeSet values to DataPoints
_apply_nodesets_to_datapoints(data_points, nodeset_map, parent_doc_map)
return data_points
except Exception as e:
logger.error(f"Error processing DataPoints: {str(e)}")
return data_points
def _collect_ids_and_relationships(data_points: List[DataPoint]) -> Tuple[List[str], Dict[str, str]]:
"""Extract DataPoint IDs and document relationships.
Args:
data_points: List of DataPoint objects
Returns:
Tuple containing:
- List of DataPoint IDs
- Dictionary mapping DataPoint IDs to parent document IDs
"""
data_point_ids = []
parent_doc_ids = []
parent_doc_map = {}
# Collect all IDs to look up
for dp in data_points:
# Get the DataPoint ID
if hasattr(dp, "id"):
dp_id = str(dp.id)
data_point_ids.append(dp_id)
# Check if there's a parent document to get NodeSet from
if (hasattr(dp, "made_from") and
hasattr(dp.made_from, "is_part_of") and
hasattr(dp.made_from.is_part_of, "id")):
parent_id = str(dp.made_from.is_part_of.id)
parent_doc_ids.append(parent_id)
parent_doc_map[dp_id] = parent_id
logger.info(f"Found {len(data_point_ids)} DataPoint IDs and {len(parent_doc_ids)} parent document IDs")
# Combine all IDs for database lookup
return data_point_ids + parent_doc_ids, parent_doc_map
async def _fetch_nodesets_from_database(ids: List[str], parent_doc_map: Dict[str, str]) -> Dict[str, Any]:
"""Fetch NodeSet values from the database for the given IDs.
Args:
ids: List of IDs to search for
parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
Returns:
Dictionary mapping document IDs to their NodeSet values
"""
# Convert string IDs to UUIDs for database lookup
uuid_objects = _convert_ids_to_uuids(ids)
if not uuid_objects:
return {}
# Query the database for NodeSet values
nodeset_map = {}
db_engine = get_relational_engine()
async with db_engine.get_async_session() as sess:
# Query for records matching the IDs
query = select(Data).where(Data.id.in_(uuid_objects))
result = await sess.execute(query)
records = result.scalars().all()
logger.info(f"Found {len(records)} total records matching the IDs")
# Extract NodeSet values from records
for record in records:
if record.node_set is not None:
nodeset_map[str(record.id)] = record.node_set
logger.info(f"Found {len(nodeset_map)} records with non-NULL NodeSet values")
return nodeset_map
def _convert_ids_to_uuids(ids: List[str]) -> List[uuid.UUID]:
"""Convert string IDs to UUID objects.
Args:
ids: List of string IDs to convert
Returns:
List of UUID objects
"""
uuid_objects = []
for id_str in ids:
try:
uuid_objects.append(uuid.UUID(id_str))
except Exception as e:
logger.warning(f"Failed to convert ID {id_str} to UUID: {str(e)}")
logger.info(f"Converted {len(uuid_objects)} out of {len(ids)} IDs to UUID objects")
return uuid_objects
def _apply_nodesets_to_datapoints(
data_points: List[DataPoint],
nodeset_map: Dict[str, Any],
parent_doc_map: Dict[str, str]
) -> None:
"""Apply NodeSet values to DataPoints.
Args:
data_points: List of DataPoint objects to update
nodeset_map: Dictionary mapping document IDs to their NodeSet values
parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
"""
for dp in data_points:
dp_id = str(dp.id)
# Try direct match first
if dp_id in nodeset_map:
nodeset = nodeset_map[dp_id]
logger.info(f"Found NodeSet for {dp_id}: {nodeset}")
dp.NodeSet = nodeset
# Then try parent document
elif dp_id in parent_doc_map and parent_doc_map[dp_id] in nodeset_map:
parent_id = parent_doc_map[dp_id]
nodeset = nodeset_map[parent_id]
logger.info(f"Found NodeSet from parent document {parent_id} for {dp_id}: {nodeset}")
dp.NodeSet = nodeset
async def _create_set_nodes(data_points: List[DataPoint]) -> None:
"""Create set nodes for DataPoints with NodeSets.
Args:
data_points: List of DataPoint objects to process
"""
try:
logger.info(f"Creating set nodes for {len(data_points)} DataPoints")
for dp in data_points:
if not hasattr(dp, "NodeSet") or not dp.NodeSet:
continue
try:
# Create set nodes for the NodeSet (one per value)
document_id = str(dp.id) if hasattr(dp, "id") else None
set_node_ids, edge_ids = await create_set_node(dp.NodeSet, document_id=document_id)
if set_node_ids and len(set_node_ids) > 0:
logger.info(f"Created {len(set_node_ids)} set nodes for NodeSet with {len(dp.NodeSet)} values")
# Store the set node IDs with the DataPoint if possible
try:
# Store as JSON string if multiple IDs, or single ID if only one
if len(set_node_ids) > 1:
dp.SetNodeIds = json.dumps(set_node_ids)
else:
dp.SetNodeId = set_node_ids[0]
except Exception as e:
logger.warning(f"Failed to store set node IDs for NodeSet: {str(e)}")
else:
logger.warning("Failed to create set nodes for NodeSet")
except Exception as e:
logger.error(f"Error creating set nodes: {str(e)}")
except Exception as e:
logger.error(f"Error processing NodeSets: {str(e)}")
async def create_set_node(
nodes: Union[List[str], str],
name: Optional[str] = None,
document_id: Optional[str] = None
) -> Tuple[Optional[List[str]], List[str]]:
"""Create individual nodes for each value in the NodeSet.
Args:
nodes: List of node values or JSON string representing node values
name: Base name for the NodeSet (optional)
document_id: ID of the document containing the NodeSet (optional)
Returns:
Tuple containing:
- List of created set node IDs (or None if creation failed)
- List of created edge IDs
"""
try:
if not nodes:
logger.warning("No nodes provided to create set nodes")
return None, []
# Get the graph engine
graph_engine = await get_graph_engine()
# Parse nodes if provided as JSON string
nodes_list = _parse_nodes_input(nodes)
if not nodes_list:
return None, []
# Base name for the set if not provided
base_name = name or f"NodeSet_{uuid.uuid4().hex[:8]}"
logger.info(f"Creating individual set nodes for {len(nodes_list)} values with base name '{base_name}'")
# Create set nodes using the appropriate strategy
return await _create_set_nodes_unified(graph_engine, nodes_list, base_name, document_id)
except Exception as e:
logger.error(f"Failed to create set nodes: {str(e)}")
return None, []
def _parse_nodes_input(nodes: Union[List[str], str]) -> Optional[List[str]]:
"""Parse the nodes input.
Args:
nodes: List of node values or JSON string representing node values
Returns:
List of node values or None if parsing failed
"""
if isinstance(nodes, str):
try:
parsed_nodes = json.loads(nodes)
logger.info(f"Parsed nodes string into list with {len(parsed_nodes)} items")
return parsed_nodes
except Exception as e:
logger.error(f"Failed to parse nodes as JSON: {str(e)}")
return None
return nodes
async def _create_set_nodes_unified(
graph_engine: Any,
nodes_list: List[str],
base_name: str,
document_id: Optional[str]
) -> Tuple[List[str], List[str]]:
"""Create set nodes using either NetworkX or generic graph engine.
Args:
graph_engine: The graph engine instance
nodes_list: List of node values
base_name: Base name for the NodeSet
document_id: ID of the document containing the NodeSet (optional)
Returns:
Tuple containing:
- List of created set node IDs
- List of created edge IDs
"""
all_set_node_ids = []
all_edge_ids = []
# Define strategies for node and edge creation based on graph engine type
if isinstance(graph_engine, NetworkXAdapter):
# NetworkX-specific strategy
async def create_node(node_value: str) -> str:
set_node_id = str(uuid.uuid4())
node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
graph_engine.graph.add_node(
set_node_id,
id=set_node_id,
type="NodeSet",
name=node_name,
node_id=node_value
)
# Validate node creation
if set_node_id in graph_engine.graph.nodes():
node_props = dict(graph_engine.graph.nodes[set_node_id])
logger.info(f"Created set node for value '{node_value}': {json.dumps(node_props)}")
else:
logger.warning(f"Node {set_node_id} not found in graph after adding")
return set_node_id
async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
edge_ids = []
try:
edge_id = str(uuid.uuid4())
graph_engine.graph.add_edge(
set_node_id,
node_value,
id=edge_id,
type="CONTAINS"
)
edge_ids.append(edge_id)
except Exception as e:
logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
return edge_ids
async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
edge_ids = []
try:
doc_to_nodeset_id = str(uuid.uuid4())
graph_engine.graph.add_edge(
document_id,
set_node_id,
id=doc_to_nodeset_id,
type="HAS_NODESET"
)
edge_ids.append(doc_to_nodeset_id)
logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
except Exception as e:
logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
return edge_ids
# Finalize function for NetworkX
async def finalize() -> None:
await graph_engine.save_graph_to_file(graph_engine.filename)
else:
# Generic graph engine strategy
async def create_node(node_value: str) -> str:
node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
set_node_props = {
"name": node_name,
"type": "NodeSet",
"node_id": node_value
}
return await graph_engine.create_node(set_node_props)
async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
edge_ids = []
try:
edge_id = await graph_engine.create_edge(
source_id=set_node_id,
target_id=node_value,
edge_type="CONTAINS"
)
edge_ids.append(edge_id)
except Exception as e:
logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
return edge_ids
async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
edge_ids = []
try:
doc_to_nodeset_id = await graph_engine.create_edge(
source_id=document_id,
target_id=set_node_id,
edge_type="HAS_NODESET"
)
edge_ids.append(doc_to_nodeset_id)
logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
except Exception as e:
logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
return edge_ids
# Finalize function for generic engine (no-op)
async def finalize() -> None:
pass
# Unified process for both engine types
for node_value in nodes_list:
try:
# Create the node
set_node_id = await create_node(node_value)
# Create edges to the value
value_edge_ids = await create_value_edge(set_node_id, node_value)
all_edge_ids.extend(value_edge_ids)
# Create edges to the document if provided
if document_id:
doc_edge_ids = await create_document_edge(document_id, set_node_id)
all_edge_ids.extend(doc_edge_ids)
all_set_node_ids.append(set_node_id)
except Exception as e:
logger.error(f"Failed to create set node for value '{node_value}': {str(e)}")
# Finalize the process
await finalize()
# Return results
if all_set_node_ids:
logger.info(f"Created {len(all_set_node_ids)} individual set nodes with values: {nodes_list}")
return all_set_node_ids, all_edge_ids
else:
logger.error("Failed to create any set nodes")
return [], []
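
A hypothetical direct call to the create_set_node helper above, reusing label values from the test script earlier in this commit; the document_id here is an illustrative placeholder rather than a real record ID.

# inside an async context:
set_node_ids, edge_ids = await create_set_node(
    ["my_accounting_data", "my_set_of_manuals_about_horses"],
    document_id=str(uuid.uuid4()),  # placeholder; normally the ID of the owning document
)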

View file

@@ -1 +0,0 @@
# Node Set integration tests

View file

@@ -1,180 +0,0 @@
import os
import json
import asyncio
import pytest
from uuid import uuid4
from unittest.mock import patch, AsyncMock, MagicMock
from contextlib import asynccontextmanager
import cognee
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks
from cognee.tasks.node_set import apply_node_set
from cognee.infrastructure.databases.relational import create_db_and_tables
class TestDocument(DataPoint):
"""Test document model for NodeSet testing."""
content: str
metadata: dict = {"index_fields": ["content"]}
@pytest.mark.asyncio
async def test_node_set_add_to_cognify_workflow():
"""
Test the full NodeSet workflow from add to cognify.
This test verifies that:
1. NodeSet data can be added using cognee.add
2. The NodeSet data is stored in the relational database
3. The apply_node_set task can retrieve the NodeSet and apply it to DataPoints
"""
# Clean up any existing data
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Create test data
test_content = "This is test content"
test_node_set = ["node1", "node2", "node3"]
dataset_name = f"test_dataset_{uuid4()}"
# Create database tables
await create_db_and_tables()
# Mock functions to avoid external dependencies
mock_add_data_points = AsyncMock()
mock_add_data_points.return_value = []
# Create a temporary file for the test
temp_file_path = "/tmp/test_node_set.txt"
with open(temp_file_path, "w") as f:
f.write(test_content)
try:
# Mock ingest_data to capture and verify NodeSet
original_ingest_data = cognee.tasks.ingestion.ingest_data
async def mock_ingest_data(*args, **kwargs):
# Call the original function but capture the NodeSet parameter
assert len(args) >= 3, "Expected at least 3 arguments"
assert args[1] == dataset_name, f"Expected dataset name {dataset_name}"
assert kwargs.get("NodeSet") == test_node_set or args[3] == test_node_set, (
"NodeSet not passed correctly"
)
return await original_ingest_data(*args, **kwargs)
# Replace the ingest_data function temporarily
with patch("cognee.tasks.ingestion.ingest_data", side_effect=mock_ingest_data):
# Call the add function with NodeSet
await cognee.add(temp_file_path, dataset_name, NodeSet=test_node_set)
# Create test DataPoint for apply_node_set to process
test_document = TestDocument(content=test_content)
# Test the apply_node_set task
with patch(
"cognee.tasks.node_set.apply_node_set.get_relational_engine"
) as mock_get_engine:
# Setup mock engine and session
mock_session = AsyncMock()
mock_engine = AsyncMock()
# Properly mock the async context manager
@asynccontextmanager
async def mock_get_session():
try:
yield mock_session
finally:
pass
mock_engine.get_async_session.return_value = mock_get_session()
mock_get_engine.return_value = mock_engine
# Create a mock Data object with our NodeSet
class MockData:
def __init__(self, id, node_set):
self.id = id
self.node_set = node_set
mock_data = MockData(test_document.id, json.dumps(test_node_set))
# Setup the mock result
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_data]
mock_session.execute.return_value = mock_result
# Run the apply_node_set task
result = await apply_node_set([test_document])
# Verify the NodeSet was applied
assert len(result) == 1
assert result[0].NodeSet == test_node_set
# Verify the mock interactions
mock_get_engine.assert_called_once()
mock_engine.get_async_session.assert_called_once()
finally:
# Clean up the temporary file
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
# Clean up after the test
await cognee.prune.prune_data()
@pytest.mark.asyncio
async def test_node_set_in_cognify_pipeline():
"""
Test the integration of apply_node_set task in the cognify pipeline.
This test verifies that the apply_node_set task works correctly
when run as part of a pipeline with other tasks.
"""
# Create test data
test_documents = [TestDocument(content="Document 1"), TestDocument(content="Document 2")]
# Create a simple mock task that just passes data through
async def mock_task(data):
for item in data:
yield item
# Mock the apply_node_set function to verify it's called with the right data
original_apply_node_set = apply_node_set
apply_node_set_called = False
async def mock_apply_node_set(data_points):
nonlocal apply_node_set_called
apply_node_set_called = True
# Verify the input
assert len(data_points) == 2
assert all(isinstance(dp, TestDocument) for dp in data_points)
# Apply NodeSet to demonstrate it worked
for dp in data_points:
dp.NodeSet = ["test_node"]
return data_points
# Create a pipeline with our tasks
with patch("cognee.tasks.node_set.apply_node_set", side_effect=mock_apply_node_set):
pipeline = run_tasks(
tasks=[
Task(mock_task), # First task passes data through
Task(apply_node_set), # Second task applies NodeSet
],
data=test_documents,
)
# Process all results from the pipeline
results = []
async for result in pipeline:
results.extend(result)
# Verify results
assert apply_node_set_called, "apply_node_set was not called"
assert len(results) == 2
assert all(dp.NodeSet == ["test_node"] for dp in results)

View file

@@ -1 +0,0 @@
# Node Set task tests

View file

@@ -1,132 +0,0 @@
import pytest
import json
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import uuid4, UUID
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.tasks.node_set.apply_node_set import apply_node_set
class TestDataPoint(DataPoint):
"""Test DataPoint model for testing apply_node_set task."""
name: str
description: str
metadata: dict = {"index_fields": ["name"]}
@pytest.mark.asyncio
async def test_apply_node_set():
"""Test that apply_node_set applies NodeSet values from the relational store to DataPoint instances."""
# Create test data
dp_id1 = uuid4()
dp_id2 = uuid4()
# Create test DataPoint instances
data_points = [
TestDataPoint(id=dp_id1, name="Test 1", description="Description 1"),
TestDataPoint(id=dp_id2, name="Test 2", description="Description 2"),
]
# Create mock Data records that would be returned from the database
node_set1 = ["node1", "node2"]
node_set2 = ["node3", "node4", "node5"]
# Create a mock implementation of _process_data_points
async def mock_process_data_points(session, data_point_map):
# Apply NodeSet directly to the DataPoints
for dp_id, node_set in [(dp_id1, node_set1), (dp_id2, node_set2)]:
dp_id_str = str(dp_id)
if dp_id_str in data_point_map:
data_point_map[dp_id_str].NodeSet = node_set
# Patch the necessary functions
with (
patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
patch(
"cognee.tasks.node_set.apply_node_set._process_data_points",
side_effect=mock_process_data_points,
),
):
# Call the function being tested
result = await apply_node_set(data_points)
# Verify the results
assert len(result) == 2
# Check that NodeSet values were applied correctly
assert result[0].NodeSet == node_set1
assert result[1].NodeSet == node_set2
@pytest.mark.asyncio
async def test_apply_node_set_empty_list():
"""Test apply_node_set with an empty list of DataPoints."""
result = await apply_node_set([])
assert result == []
@pytest.mark.asyncio
async def test_apply_node_set_no_matching_data():
"""Test apply_node_set when there are no matching Data records."""
# Create test data
dp_id = uuid4()
# Create test DataPoint instances
data_points = [TestDataPoint(id=dp_id, name="Test", description="Description")]
# Create a mock implementation of _process_data_points that doesn't modify any DataPoints
async def mock_process_data_points(session, data_point_map):
# Don't modify anything - simulating no matching records
pass
# Patch the necessary functions
with (
patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
patch(
"cognee.tasks.node_set.apply_node_set._process_data_points",
side_effect=mock_process_data_points,
),
):
# Call the function being tested
result = await apply_node_set(data_points)
# Verify the results - NodeSet should remain None
assert len(result) == 1
assert result[0].NodeSet is None
@pytest.mark.asyncio
async def test_apply_node_set_invalid_json():
"""Test apply_node_set when there's invalid JSON in the node_set column."""
# Create test data
dp_id = uuid4()
# Create test DataPoint instances
data_points = [TestDataPoint(id=dp_id, name="Test", description="Description")]
# Create a mock implementation of _process_data_points that throws the appropriate error
async def mock_process_data_points(session, data_point_map):
# Simulate the JSONDecodeError by logging a warning
from cognee.tasks.node_set.apply_node_set import logger
logger.warning(f"Failed to parse NodeSet JSON for DataPoint {str(dp_id)}")
# Patch the necessary functions
with (
patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
patch(
"cognee.tasks.node_set.apply_node_set._process_data_points",
side_effect=mock_process_data_points,
),
patch("cognee.tasks.node_set.apply_node_set.logger") as mock_logger,
):
# Call the function being tested
result = await apply_node_set(data_points)
# Verify the results - NodeSet should remain None
assert len(result) == 1
assert result[0].NodeSet is None
# Verify logger warning was called
mock_logger.warning.assert_called_once()

View file

@@ -1,675 +0,0 @@
import asyncio
import os
import json
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional, Set
import webbrowser
from pathlib import Path
import cognee
from cognee.shared.logging_utils import get_logger
# Configure logger
logger = get_logger(name="enhanced_graph_visualization")
# Type aliases for clarity
NodeData = Dict[str, Any]
EdgeData = Dict[str, Any]
GraphData = Tuple[List[Tuple[Any, Dict]], List[Tuple[Any, Any, Optional[str], Dict]]]
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder to handle datetime objects."""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
class NodeSetVisualizer:
"""Class to create enhanced visualizations for NodeSet data in knowledge graphs."""
# Color mapping for different node types
NODE_COLORS = {
"Entity": "#f47710",
"EntityType": "#6510f4",
"DocumentChunk": "#801212",
"TextDocument": "#a83232", # Darker red for documents
"TextSummary": "#1077f4",
"NodeSet": "#ff00ff", # Bright magenta for NodeSet nodes
"Unknown": "#999999",
"default": "#D3D3D3",
}
# Size mapping for different node types
NODE_SIZES = {
"NodeSet": 20, # Larger size for NodeSet nodes
"TextDocument": 18, # Larger size for document nodes
"DocumentChunk": 18, # Larger size for document nodes
"TextSummary": 16, # Medium size for TextSummary nodes
"default": 13, # Default size
}
def __init__(self):
"""Initialize the visualizer."""
self.graph_engine = None
self.nodes_data = []
self.edges_data = []
self.node_count = 0
self.edge_count = 0
self.nodeset_count = 0
async def get_graph_data(self) -> bool:
"""Fetch graph data from the graph engine.
Returns:
bool: True if data was successfully retrieved, False otherwise.
"""
self.graph_engine = await cognee.infrastructure.databases.graph.get_graph_engine()
# Check if the graph exists and has nodes
self.node_count = len(self.graph_engine.graph.nodes())
self.edge_count = len(self.graph_engine.graph.edges())
logger.info(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
print(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
if self.node_count == 0:
logger.error("The graph is empty! Please run a test script first to generate data.")
print("ERROR: The graph is empty! Please run a test script first to generate data.")
return False
graph_data = await self.graph_engine.get_graph_data()
self.nodes_data, self.edges_data = graph_data
# Count NodeSets for status display
self.nodeset_count = sum(1 for _, info in self.nodes_data if info.get("type") == "NodeSet")
return True
def prepare_node_data(self) -> List[NodeData]:
"""Process raw node data to prepare for visualization.
Returns:
List[NodeData]: List of prepared node data objects.
"""
nodes_list = []
# Create a lookup for node types for faster access
node_type_lookup = {str(node_id): node_info.get("type", "Unknown")
for node_id, node_info in self.nodes_data}
for node_id, node_info in self.nodes_data:
# Create a clean copy to avoid modifying the original
processed_node = node_info.copy()
# Remove fields that cause JSON serialization issues
self._clean_node_data(processed_node)
# Add required visualization properties
processed_node["id"] = str(node_id)
node_type = processed_node.get("type", "default")
# Apply visual styling based on node type
processed_node["color"] = self.NODE_COLORS.get(node_type, self.NODE_COLORS["default"])
processed_node["size"] = self.NODE_SIZES.get(node_type, self.NODE_SIZES["default"])
# Create display names
self._format_node_display_name(processed_node, node_type)
nodes_list.append(processed_node)
return nodes_list
@staticmethod
def _clean_node_data(node: NodeData) -> None:
"""Remove fields that might cause JSON serialization issues.
Args:
node: The node data to clean
"""
# Remove non-essential fields that might cause serialization issues
for key in ["created_at", "updated_at", "raw_data_location"]:
if key in node:
del node[key]
@staticmethod
def _format_node_display_name(node: NodeData, node_type: str) -> None:
"""Format the display name for a node.
Args:
node: The node data to process
node_type: The type of the node
"""
# Set a default name if none exists
node["name"] = node.get("name", node.get("id", "Unknown"))
# Special formatting for NodeSet nodes
if node_type == "NodeSet" and "node_id" in node:
node["display_name"] = f"NodeSet: {node['node_id']}"
else:
node["display_name"] = node["name"]
# Truncate long display names
if len(node["display_name"]) > 30:
node["display_name"] = f"{node['display_name'][:27]}..."
def prepare_edge_data(self, nodes_list: List[NodeData]) -> List[EdgeData]:
"""Process raw edge data to prepare for visualization.
Args:
nodes_list: The processed node data
Returns:
List[EdgeData]: List of prepared edge data objects.
"""
links_list = []
# Create a lookup for node types for faster access
node_type_lookup = {node["id"]: node.get("type", "Unknown") for node in nodes_list}
for source, target, relation, edge_info in self.edges_data:
source_str = str(source)
target_str = str(target)
# Skip if source or target not in node_type_lookup (should not happen)
if source_str not in node_type_lookup or target_str not in node_type_lookup:
continue
# Get node types
source_type = node_type_lookup[source_str]
target_type = node_type_lookup[target_str]
# Create edge data
link_data = {
"source": source_str,
"target": target_str,
"relation": relation or "UNKNOWN"
}
# Categorize the edge for styling
link_data["connection_type"] = self._categorize_edge(source_type, target_type)
links_list.append(link_data)
return links_list
@staticmethod
def _categorize_edge(source_type: str, target_type: str) -> str:
"""Categorize an edge based on the connected node types.
Args:
source_type: The type of the source node
target_type: The type of the target node
Returns:
str: The category of the edge
"""
if source_type == "NodeSet" and target_type != "NodeSet":
return "nodeset_to_value"
elif (source_type in ["TextDocument", "DocumentChunk", "TextSummary"]) and target_type == "NodeSet":
return "document_to_nodeset"
elif target_type == "NodeSet":
return "to_nodeset"
elif source_type == "NodeSet":
return "from_nodeset"
else:
return "standard"
def generate_html(self, nodes_list: List[NodeData], links_list: List[EdgeData]) -> str:
"""Generate the HTML visualization with D3.js.
Args:
nodes_list: The processed node data
links_list: The processed edge data
Returns:
str: The HTML content for the visualization
"""
# Use embedded template directly - more reliable than file access
html_template = self._get_embedded_html_template()
# Generate the HTML content with custom JSON encoder for datetime objects
html_content = html_template.replace("{nodes}", json.dumps(nodes_list, cls=DateTimeEncoder))
html_content = html_content.replace("{links}", json.dumps(links_list, cls=DateTimeEncoder))
html_content = html_content.replace("{node_count}", str(self.node_count))
html_content = html_content.replace("{edge_count}", str(self.edge_count))
html_content = html_content.replace("{nodeset_count}", str(self.nodeset_count))
return html_content
def save_html(self, html_content: str) -> str:
"""Save the HTML content to a file and open it in the browser.
Args:
html_content: The HTML content to save
Returns:
str: The path to the saved file
"""
# Create the output file path
output_path = Path.cwd() / "enhanced_nodeset_visualization.html"
# Write the HTML content to the file
with open(output_path, "w") as f:
f.write(html_content)
logger.info(f"Enhanced visualization saved to: {output_path}")
print(f"Enhanced visualization saved to: {output_path}")
# Open the visualization in the default web browser
file_url = f"file://{output_path}"
logger.info(f"Opening visualization in browser: {file_url}")
print(f"Opening enhanced visualization in browser: {file_url}")
webbrowser.open(file_url)
return str(output_path)
@staticmethod
def _get_embedded_html_template() -> str:
"""Get the embedded HTML template as a fallback.
Returns:
str: The HTML template
"""
return """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Cognee NodeSet Visualization</title>
<script src="https://d3js.org/d3.v5.min.js"></script>
<style>
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', Arial, sans-serif; }
svg { width: 100vw; height: 100vh; display: block; }
.links line { stroke-width: 2px; }
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
.node-label { font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }
/* NodeSet specific styles */
.links line.nodeset_to_value { stroke: #ff00ff; stroke-width: 3px; stroke-dasharray: 5, 5; }
.links line.document_to_nodeset { stroke: #fc0; stroke-width: 3px; }
.links line.to_nodeset { stroke: #0cf; stroke-width: 2px; }
.links line.from_nodeset { stroke: #0fc; stroke-width: 2px; }
.nodes circle.nodeset { stroke: white; stroke-width: 2px; filter: drop-shadow(0 0 10px rgba(255,0,255,0.8)); }
.nodes circle.document { stroke: white; stroke-width: 1.5px; filter: drop-shadow(0 0 8px rgba(255,255,0,0.6)); }
.node-label.nodeset { font-size: 6px; font-weight: bold; fill: white; }
.node-label.document { font-size: 5.5px; font-weight: bold; fill: white; }
/* Legend */
.legend { position: fixed; top: 10px; left: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
.legend-item { display: flex; align-items: center; margin-bottom: 5px; }
.legend-color { width: 15px; height: 15px; margin-right: 10px; border-radius: 50%; }
.legend-label { font-size: 14px; }
/* Edge legend */
.edge-legend { position: fixed; top: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
.edge-legend-item { display: flex; align-items: center; margin-bottom: 10px; }
.edge-line { width: 30px; height: 3px; margin-right: 10px; }
.edge-label { font-size: 14px; }
/* Controls */
.controls { position: fixed; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; display: flex; gap: 10px; }
button { background: #333; color: white; border: none; padding: 5px 10px; border-radius: 4px; cursor: pointer; }
button:hover { background: #555; }
/* Status message */
.status { position: fixed; bottom: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
</style>
</head>
<body>
<svg></svg>
<!-- Node Legend -->
<div class="legend">
<h3 style="margin-top: 0;">Node Types</h3>
<div class="legend-item">
<div class="legend-color" style="background-color: #ff00ff;"></div>
<div class="legend-label">NodeSet</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #a83232;"></div>
<div class="legend-label">TextDocument</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #801212;"></div>
<div class="legend-label">DocumentChunk</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #1077f4;"></div>
<div class="legend-label">TextSummary</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #f47710;"></div>
<div class="legend-label">Entity</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #6510f4;"></div>
<div class="legend-label">EntityType</div>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #999999;"></div>
<div class="legend-label">Unknown</div>
</div>
</div>
<!-- Edge Legend -->
<div class="edge-legend">
<h3 style="margin-top: 0;">Edge Types</h3>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: #fc0;"></div>
<div class="edge-label">Document NodeSet</div>
</div>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: #ff00ff; height: 3px; background: linear-gradient(to right, #ff00ff 50%, transparent 50%); background-size: 10px 3px; background-repeat: repeat-x;"></div>
<div class="edge-label">NodeSet Value</div>
</div>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: #0cf;"></div>
<div class="edge-label">Any NodeSet</div>
</div>
<div class="edge-legend-item">
<div class="edge-line" style="background-color: rgba(255, 255, 255, 0.4);"></div>
<div class="edge-label">Standard Connection</div>
</div>
</div>
<!-- Controls -->
<div class="controls">
<button id="center-btn">Center Graph</button>
<button id="highlight-nodesets">Highlight NodeSets</button>
<button id="highlight-documents">Highlight Documents</button>
<button id="reset-highlight">Reset Highlight</button>
</div>
<!-- Status -->
<div class="status">
<div>Nodes: {node_count}</div>
<div>Edges: {edge_count}</div>
<div>NodeSets: {nodeset_count}</div>
</div>
<script>
var nodes = {nodes};
var links = {links};
var svg = d3.select("svg"),
width = window.innerWidth,
height = window.innerHeight;
var container = svg.append("g");
// Count NodeSets for status display
const nodesetCount = nodes.filter(n => n.type === "NodeSet").length;
document.querySelector('.status').innerHTML = `
<div>Nodes: ${nodes.length}</div>
<div>Edges: ${links.length}</div>
<div>NodeSets: ${nodesetCount}</div>
`;
var simulation = d3.forceSimulation(nodes)
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
.force("charge", d3.forceManyBody().strength(-300))
.force("center", d3.forceCenter(width / 2, height / 2))
.force("x", d3.forceX().strength(0.1).x(width / 2))
.force("y", d3.forceY().strength(0.1).y(height / 2));
var link = container.append("g")
.attr("class", "links")
.selectAll("line")
.data(links)
.enter().append("line")
.attr("stroke-width", d => {
if (d.connection_type === 'document_to_nodeset' || d.connection_type === 'nodeset_to_value') {
return 3;
}
return 2;
})
.attr("stroke", d => {
switch(d.connection_type) {
case 'document_to_nodeset': return "#fc0";
case 'nodeset_to_value': return "#ff00ff";
case 'to_nodeset': return "#0cf";
case 'from_nodeset': return "#0fc";
default: return "rgba(255, 255, 255, 0.4)";
}
})
.attr("stroke-dasharray", d => d.connection_type === 'nodeset_to_value' ? "5,5" : null)
.attr("class", d => d.connection_type);
var edgeLabels = container.append("g")
.attr("class", "edge-labels")
.selectAll("text")
.data(links)
.enter().append("text")
.attr("class", "edge-label")
.text(d => d.relation);
var nodeGroup = container.append("g")
.attr("class", "nodes")
.selectAll("g")
.data(nodes)
.enter().append("g");
var node = nodeGroup.append("circle")
.attr("r", d => d.size || 13)
.attr("fill", d => d.color)
.attr("class", d => {
if (d.type === "NodeSet") return "nodeset";
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "document";
return "";
})
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("end", dragended));
nodeGroup.append("text")
.attr("class", d => {
if (d.type === "NodeSet") return "node-label nodeset";
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "node-label document";
return "node-label";
})
.attr("dy", 4)
.attr("font-size", d => {
if (d.type === "NodeSet") return "6px";
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "5.5px";
return "5px";
})
.attr("text-anchor", "middle")
.text(d => d.display_name || d.name);
node.append("title").text(d => {
// Create a formatted tooltip with node properties
let props = Object.entries(d)
.filter(([key]) => !["x", "y", "vx", "vy", "index", "fx", "fy", "color", "display_name"].includes(key))
.map(([key, value]) => `${key}: ${value}`)
.join("\\n");
return props;
});
simulation.on("tick", function() {
link.attr("x1", d => d.source.x)
.attr("y1", d => d.source.y)
.attr("x2", d => d.target.x)
.attr("y2", d => d.target.y);
edgeLabels
.attr("x", d => (d.source.x + d.target.x) / 2)
.attr("y", d => (d.source.y + d.target.y) / 2 - 5);
node.attr("cx", d => d.x)
.attr("cy", d => d.y);
nodeGroup.select("text")
.attr("x", d => d.x)
.attr("y", d => d.y)
.attr("dy", 4)
.attr("text-anchor", "middle");
});
// Add zoom behavior
const zoom = d3.zoom()
.scaleExtent([0.1, 8])
.on("zoom", function() {
container.attr("transform", d3.event.transform);
});
svg.call(zoom);
// Button controls
document.getElementById("center-btn").addEventListener("click", function() {
svg.transition().duration(750).call(
zoom.transform,
d3.zoomIdentity.translate(width / 2, height / 2).scale(1)
);
});
document.getElementById("highlight-nodesets").addEventListener("click", function() {
highlightNodes("NodeSet");
});
document.getElementById("highlight-documents").addEventListener("click", function() {
highlightNodes(["TextDocument", "DocumentChunk"]);
});
document.getElementById("reset-highlight").addEventListener("click", function() {
resetHighlight();
});
function highlightNodes(typeToHighlight) {
// Dim all nodes and links
node.transition().duration(300)
.attr("opacity", 0.2);
link.transition().duration(300)
.attr("opacity", 0.2);
nodeGroup.selectAll("text").transition().duration(300)
.attr("opacity", 0.2);
// Create arrays for types if a single string is provided
const typesToHighlight = Array.isArray(typeToHighlight) ? typeToHighlight : [typeToHighlight];
// Highlight matching nodes and their connected nodes
const highlightedNodeIds = new Set();
// First, find all nodes of the target type
nodes.forEach(n => {
if (typesToHighlight.includes(n.type)) {
highlightedNodeIds.add(n.id);
}
});
// Find all connected nodes (both directions)
links.forEach(l => {
if (highlightedNodeIds.has(l.source.id || l.source)) {
highlightedNodeIds.add(l.target.id || l.target);
}
if (highlightedNodeIds.has(l.target.id || l.target)) {
highlightedNodeIds.add(l.source.id || l.source);
}
});
// Highlight the nodes
node.filter(d => highlightedNodeIds.has(d.id))
.transition().duration(300)
.attr("opacity", 1);
// Highlight the labels
nodeGroup.selectAll("text")
.filter(d => highlightedNodeIds.has(d.id))
.transition().duration(300)
.attr("opacity", 1);
// Highlight the links between highlighted nodes
link.filter(d => {
const sourceId = d.source.id || d.source;
const targetId = d.target.id || d.target;
return highlightedNodeIds.has(sourceId) && highlightedNodeIds.has(targetId);
})
.transition().duration(300)
.attr("opacity", 1);
}
function resetHighlight() {
node.transition().duration(300).attr("opacity", 1);
link.transition().duration(300).attr("opacity", 1);
nodeGroup.selectAll("text").transition().duration(300).attr("opacity", 1);
}
function dragstarted(d) {
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
}
function dragged(d) {
d.fx = d3.event.x;
d.fy = d3.event.y;
}
function dragended(d) {
if (!d3.event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
}
window.addEventListener("resize", function() {
width = window.innerWidth;
height = window.innerHeight;
svg.attr("width", width).attr("height", height);
simulation.force("center", d3.forceCenter(width / 2, height / 2));
simulation.alpha(1).restart();
});
</script>
</body>
</html>
"""
async def create_visualization(self) -> Optional[str]:
"""Main method to create the visualization.
Returns:
Optional[str]: Path to the saved visualization file, or None if unsuccessful
"""
print("Creating enhanced NodeSet visualization...")
# Get graph data
if not await self.get_graph_data():
return None
try:
# Process nodes
nodes_list = self.prepare_node_data()
# Process edges
links_list = self.prepare_edge_data(nodes_list)
# Generate HTML
html_content = self.generate_html(nodes_list, links_list)
# Save to file and open in browser
return self.save_html(html_content)
except Exception as e:
logger.error(f"Error creating visualization: {str(e)}")
print(f"Error creating visualization: {str(e)}")
return None
async def main():
"""Main entry point for the script."""
visualizer = NodeSetVisualizer()
await visualizer.create_visualization()
if __name__ == "__main__":
# Run the async main function
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.close()