Removed files
This commit is contained in:
parent 1c40a5081a
commit 69c090c91d
9 changed files with 1 addition and 1605 deletions
|
|
@@ -13,7 +13,7 @@ from cognee.modules.data.models import Data, Dataset
|
|||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
|
@@ -25,7 +25,6 @@ from cognee.tasks.documents import (
|
|||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
from cognee.tasks.node_set import apply_node_set
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
|
||||
logger = get_logger("cognify")
|
||||
|
|
@@ -140,7 +139,6 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
|||
task_config={"batch_size": 10},
|
||||
),
|
||||
Task(add_data_points, task_config={"batch_size": 10}),
|
||||
Task(apply_node_set, task_config={"batch_size": 10}), # Apply NodeSet values and create set nodes
|
||||
]
|
||||
|
||||
return default_tasks
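For context, run_tasks (imported at the top of this module) consumes a task list shaped like the one above. A minimal, self-contained sketch of that pattern follows; the run_tasks/Task signatures are inferred from the tests removed later in this commit, so treat them as assumptions rather than documented API:

import asyncio
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task

async def passthrough(data):
    # Hypothetical task: hand each item straight to the next task in the pipeline.
    for item in data:
        yield item

async def run_demo_pipeline(documents):
    pipeline = run_tasks(
        tasks=[Task(passthrough, task_config={"batch_size": 10})],
        data=documents,
    )
    results = []
    async for batch in pipeline:
        results.extend(batch)
    return results

# asyncio.run(run_demo_pipeline(["doc 1", "doc 2"]))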
|
||||
|
|
|
|||
|
|
@@ -1,161 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import httpx
|
||||
import asyncio
|
||||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
|
||||
# Replace incorrect import with proper cognee.prune module
|
||||
# from cognee.infrastructure.data_storage import reset_cognee
|
||||
|
||||
# Remove incorrect imports and use high-level API
|
||||
# from cognee.services.cognify import cognify
|
||||
# from cognee.services.search import search
|
||||
from cognee.tasks.node_set.apply_node_set import apply_node_set
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Data
|
||||
|
||||
# Configure logging to see detailed output
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set up httpx client for debugging
|
||||
httpx.Client(timeout=None)
|
||||
|
||||
|
||||
async def generate_test_node_ids(count: int) -> List[str]:
|
||||
"""Generate unique node IDs for testing."""
|
||||
return [str(uuid.uuid4()) for _ in range(count)]
|
||||
|
||||
|
||||
async def add_text_with_nodeset(text: str, node_set: List[str], text_id: str = None) -> str:
|
||||
"""Add text with NodeSet to cognee."""
|
||||
# Generate text ID if not provided - but note we can't directly set the ID
|
||||
if text_id is None:
|
||||
text_id = str(uuid.uuid4())
|
||||
|
||||
# Print NodeSet details
|
||||
logger.info("Adding text with NodeSet")
|
||||
logger.info(f"NodeSet for this text: {node_set}")
|
||||
|
||||
# Use high-level cognee.add() with the correct parameters
|
||||
await cognee.add(text, node_set=node_set)
|
||||
logger.info("Saved text with NodeSet to database")
|
||||
|
||||
return text_id # Note: we can't control the actual ID generated by cognee.add()
|
||||
|
||||
|
||||
async def check_data_records():
|
||||
"""Check for data records in the database to verify NodeSet storage."""
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
query = select(Data)
|
||||
result = await session.execute(query)
|
||||
records = result.scalars().all()
|
||||
|
||||
logger.info(f"Found {len(records)} records in the database")
|
||||
for record in records:
|
||||
logger.info(f"Record ID: {record.id}, name: {record.name}, node_set: {record.node_set}")
|
||||
# Print raw_data_location to see where the text content is stored
|
||||
logger.info(f"Raw data location: {record.raw_data_location}")
|
||||
|
||||
|
||||
async def run_simple_node_set_test():
|
||||
"""Run a simple test of NodeSet functionality."""
|
||||
try:
|
||||
# Reset cognee data using correct module
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
logger.info("Cognee data reset complete")
|
||||
|
||||
# Generate test node IDs for two NodeSets
|
||||
# nodeset1 = await generate_test_node_ids(3)
|
||||
nodeset2 = await generate_test_node_ids(3)
|
||||
nodeset1 = [
|
||||
"my_accounting_data",
|
||||
"my_set_of_manuals_about_horses",
|
||||
"my_elon_musk_secret_file",
|
||||
]
|
||||
|
||||
logger.info(f"Created test node IDs for NodeSet 1: {nodeset1}")
|
||||
logger.info(f"Created test node IDs for NodeSet 2: {nodeset2}")
|
||||
|
||||
# Add two texts with NodeSets
|
||||
text1 = "Natural Language Processing (NLP) is a field of artificial intelligence focused on enabling computers to understand and process human language. It involves techniques for language analysis, translation, and generation."
|
||||
text2 = "Artificial Intelligence (AI) systems are designed to perform tasks that typically require human intelligence. These tasks include speech recognition, decision-making, and language translation."
|
||||
|
||||
text1_id = await add_text_with_nodeset(text1, nodeset1)
|
||||
text2_id = await add_text_with_nodeset(text2, nodeset2)
|
||||
logger.info(str(text1_id))
|
||||
logger.info(str(text2_id))
|
||||
|
||||
# Verify that the NodeSets were stored correctly
|
||||
await check_data_records()
|
||||
|
||||
# Run the cognify process to create a knowledge graph
|
||||
pipeline_run_id = await cognee.cognify()
|
||||
logger.info(f"Cognify process completed with pipeline run ID: {pipeline_run_id}")
|
||||
|
||||
# Skip graph search as the NetworkXAdapter doesn't support Cypher
|
||||
logger.info("Skipping graph search as NetworkXAdapter doesn't support Cypher")
|
||||
|
||||
# Search for insights related to the added texts
|
||||
search_query = "NLP and AI"
|
||||
logger.info(f"Searching for insights with query: '{search_query}'")
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.INSIGHTS, query_text=search_query
|
||||
)
|
||||
|
||||
# Extract NodeSet information from search results
|
||||
logger.info(f"Found {len(search_results)} search results for '{search_query}':")
|
||||
for i, result in enumerate(search_results):
|
||||
logger.info(f"Result {i + 1} text: {getattr(result, 'text', 'No text')}")
|
||||
|
||||
# Check for NodeSet and SetNodeId
|
||||
node_set = getattr(result, "NodeSet", None)
|
||||
set_node_id = getattr(result, "SetNodeId", None)
|
||||
logger.info(f"Result {i + 1} - NodeSet: {node_set}, SetNodeId: {set_node_id}")
|
||||
|
||||
# Check id and type
|
||||
logger.info(
|
||||
f"Result {i + 1} - ID: {getattr(result, 'id', 'No ID')}, Type: {getattr(result, 'type', 'No type')}"
|
||||
)
|
||||
|
||||
# Check if this is a document chunk and has is_part_of property
|
||||
if hasattr(result, "is_part_of") and result.is_part_of:
|
||||
logger.info(
|
||||
f"Result {i + 1} is a document chunk with parent ID: {result.is_part_of.id}"
|
||||
)
|
||||
|
||||
# Check if the parent has a NodeSet
|
||||
parent_has_nodeset = (
|
||||
hasattr(result.is_part_of, "NodeSet") and result.is_part_of.NodeSet
|
||||
)
|
||||
logger.info(f" Parent has NodeSet: {parent_has_nodeset}")
|
||||
if parent_has_nodeset:
|
||||
logger.info(f" Parent NodeSet: {result.is_part_of.NodeSet}")
|
||||
|
||||
# Print all attributes of the result to see what's available
|
||||
logger.info(f"Result {i + 1} - All attributes: {dir(result)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in simple NodeSet test: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = get_logger(level=ERROR)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(run_simple_node_set_test())
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
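Stripped of logging and verification, the script above boils down to four high-level calls. A minimal sketch of the same sequence, using only the APIs already exercised in this file:

import asyncio
import cognee
from cognee.api.v1.search import SearchType

async def minimal_node_set_flow():
    # Start clean, tag a text with a named NodeSet, build the graph, then query it.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await cognee.add("NLP enables computers to process human language.", node_set=["my_accounting_data"])
    await cognee.cognify()
    return await cognee.search(query_type=SearchType.INSIGHTS, query_text="NLP")

# asyncio.run(minimal_node_set_flow())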
|
||||
|
|
@@ -1 +0,0 @@
|
|||
from .apply_node_set import apply_node_set
|
||||
|
|
@@ -1,451 +0,0 @@
|
|||
import uuid
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Union, Tuple, Sequence, Protocol, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import or_
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
# Configure logger
|
||||
logger = get_logger(name="apply_node_set")
|
||||
|
||||
|
||||
async def apply_node_set(data: Union[DataPoint, List[DataPoint]]) -> Union[DataPoint, List[DataPoint]]:
|
||||
"""Apply NodeSet values to DataPoint objects.
|
||||
|
||||
Args:
|
||||
data: Single DataPoint or list of DataPoints to process
|
||||
|
||||
Returns:
|
||||
The processed DataPoint(s) with updated NodeSet values
|
||||
"""
|
||||
if not data:
|
||||
logger.warning("No data provided to apply NodeSet values")
|
||||
return data
|
||||
|
||||
# Convert single DataPoint to list for uniform processing
|
||||
data_points = data if isinstance(data, list) else [data]
|
||||
logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints")
|
||||
|
||||
# Process DataPoint objects to apply NodeSet values
|
||||
updated_data_points = await _process_data_points(data_points)
|
||||
|
||||
# Create set nodes for each NodeSet
|
||||
await _create_set_nodes(updated_data_points)
|
||||
|
||||
# Return data in the same format it was received
|
||||
return data_points if isinstance(data, list) else data_points[0]
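As a quick illustration of the contract documented above (DataPoint in, DataPoint out, with NodeSet filled in only when the relational store has a match), a direct call could look like the sketch below; DemoPoint is a hypothetical model mirroring the test DataPoints defined further down in this commit:

import asyncio
from uuid import uuid4
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.tasks.node_set.apply_node_set import apply_node_set

class DemoPoint(DataPoint):
    # Hypothetical DataPoint subclass, shaped like the ones used in the tests.
    name: str
    metadata: dict = {"index_fields": ["name"]}

async def demo():
    points = [DemoPoint(id=uuid4(), name="example")]
    updated = await apply_node_set(points)
    # NodeSet stays unset unless a matching Data record carries a node_set value.
    return [getattr(dp, "NodeSet", None) for dp in updated]

# asyncio.run(demo())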
|
||||
|
||||
|
||||
async def _process_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
||||
"""Process DataPoint objects to apply NodeSet values from the database.
|
||||
|
||||
Args:
|
||||
data_points: List of DataPoint objects to process
|
||||
|
||||
Returns:
|
||||
The processed list of DataPoints with updated NodeSet values
|
||||
"""
|
||||
try:
|
||||
if not data_points:
|
||||
return []
|
||||
|
||||
# Extract IDs and collect document relationships
|
||||
data_point_ids, parent_doc_map = _collect_ids_and_relationships(data_points)
|
||||
|
||||
# Get NodeSet values from database
|
||||
nodeset_map = await _fetch_nodesets_from_database(data_point_ids, parent_doc_map)
|
||||
|
||||
# Apply NodeSet values to DataPoints
|
||||
_apply_nodesets_to_datapoints(data_points, nodeset_map, parent_doc_map)
|
||||
|
||||
return data_points
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing DataPoints: {str(e)}")
|
||||
return data_points
|
||||
|
||||
|
||||
def _collect_ids_and_relationships(data_points: List[DataPoint]) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""Extract DataPoint IDs and document relationships.
|
||||
|
||||
Args:
|
||||
data_points: List of DataPoint objects
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List of DataPoint IDs
|
||||
- Dictionary mapping DataPoint IDs to parent document IDs
|
||||
"""
|
||||
data_point_ids = []
|
||||
parent_doc_ids = []
|
||||
parent_doc_map = {}
|
||||
|
||||
# Collect all IDs to look up
|
||||
for dp in data_points:
|
||||
# Get the DataPoint ID
|
||||
if hasattr(dp, "id"):
|
||||
dp_id = str(dp.id)
|
||||
data_point_ids.append(dp_id)
|
||||
|
||||
# Check if there's a parent document to get NodeSet from
|
||||
if (hasattr(dp, "made_from") and
|
||||
hasattr(dp.made_from, "is_part_of") and
|
||||
hasattr(dp.made_from.is_part_of, "id")):
|
||||
|
||||
parent_id = str(dp.made_from.is_part_of.id)
|
||||
parent_doc_ids.append(parent_id)
|
||||
parent_doc_map[dp_id] = parent_id
|
||||
|
||||
logger.info(f"Found {len(data_point_ids)} DataPoint IDs and {len(parent_doc_ids)} parent document IDs")
|
||||
|
||||
# Combine all IDs for database lookup
|
||||
return data_point_ids + parent_doc_ids, parent_doc_map
|
||||
|
||||
|
||||
async def _fetch_nodesets_from_database(ids: List[str], parent_doc_map: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Fetch NodeSet values from the database for the given IDs.
|
||||
|
||||
Args:
|
||||
ids: List of IDs to search for
|
||||
parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
|
||||
|
||||
Returns:
|
||||
Dictionary mapping document IDs to their NodeSet values
|
||||
"""
|
||||
# Convert string IDs to UUIDs for database lookup
|
||||
uuid_objects = _convert_ids_to_uuids(ids)
|
||||
if not uuid_objects:
|
||||
return {}
|
||||
|
||||
# Query the database for NodeSet values
|
||||
nodeset_map = {}
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as sess:
|
||||
# Query for records matching the IDs
|
||||
query = select(Data).where(Data.id.in_(uuid_objects))
|
||||
result = await sess.execute(query)
|
||||
records = result.scalars().all()
|
||||
|
||||
logger.info(f"Found {len(records)} total records matching the IDs")
|
||||
|
||||
# Extract NodeSet values from records
|
||||
for record in records:
|
||||
if record.node_set is not None:
|
||||
nodeset_map[str(record.id)] = record.node_set
|
||||
|
||||
logger.info(f"Found {len(nodeset_map)} records with non-NULL NodeSet values")
|
||||
|
||||
return nodeset_map
|
||||
|
||||
|
||||
def _convert_ids_to_uuids(ids: List[str]) -> List[uuid.UUID]:
|
||||
"""Convert string IDs to UUID objects.
|
||||
|
||||
Args:
|
||||
ids: List of string IDs to convert
|
||||
|
||||
Returns:
|
||||
List of UUID objects
|
||||
"""
|
||||
uuid_objects = []
|
||||
for id_str in ids:
|
||||
try:
|
||||
uuid_objects.append(uuid.UUID(id_str))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert ID {id_str} to UUID: {str(e)}")
|
||||
|
||||
logger.info(f"Converted {len(uuid_objects)} out of {len(ids)} IDs to UUID objects")
|
||||
return uuid_objects
|
||||
|
||||
|
||||
def _apply_nodesets_to_datapoints(
|
||||
data_points: List[DataPoint],
|
||||
nodeset_map: Dict[str, Any],
|
||||
parent_doc_map: Dict[str, str]
|
||||
) -> None:
|
||||
"""Apply NodeSet values to DataPoints.
|
||||
|
||||
Args:
|
||||
data_points: List of DataPoint objects to update
|
||||
nodeset_map: Dictionary mapping document IDs to their NodeSet values
|
||||
parent_doc_map: Dictionary mapping DataPoint IDs to parent document IDs
|
||||
"""
|
||||
for dp in data_points:
|
||||
dp_id = str(dp.id)
|
||||
|
||||
# Try direct match first
|
||||
if dp_id in nodeset_map:
|
||||
nodeset = nodeset_map[dp_id]
|
||||
logger.info(f"Found NodeSet for {dp_id}: {nodeset}")
|
||||
dp.NodeSet = nodeset
|
||||
|
||||
# Then try parent document
|
||||
elif dp_id in parent_doc_map and parent_doc_map[dp_id] in nodeset_map:
|
||||
parent_id = parent_doc_map[dp_id]
|
||||
nodeset = nodeset_map[parent_id]
|
||||
logger.info(f"Found NodeSet from parent document {parent_id} for {dp_id}: {nodeset}")
|
||||
dp.NodeSet = nodeset
|
||||
|
||||
|
||||
async def _create_set_nodes(data_points: List[DataPoint]) -> None:
|
||||
"""Create set nodes for DataPoints with NodeSets.
|
||||
|
||||
Args:
|
||||
data_points: List of DataPoint objects to process
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Creating set nodes for {len(data_points)} DataPoints")
|
||||
|
||||
for dp in data_points:
|
||||
if not hasattr(dp, "NodeSet") or not dp.NodeSet:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Create set nodes for the NodeSet (one per value)
|
||||
document_id = str(dp.id) if hasattr(dp, "id") else None
|
||||
set_node_ids, edge_ids = await create_set_node(dp.NodeSet, document_id=document_id)
|
||||
|
||||
if set_node_ids and len(set_node_ids) > 0:
|
||||
logger.info(f"Created {len(set_node_ids)} set nodes for NodeSet with {len(dp.NodeSet)} values")
|
||||
|
||||
# Store the set node IDs with the DataPoint if possible
|
||||
try:
|
||||
# Store as JSON string if multiple IDs, or single ID if only one
|
||||
if len(set_node_ids) > 1:
|
||||
dp.SetNodeIds = json.dumps(set_node_ids)
|
||||
else:
|
||||
dp.SetNodeId = set_node_ids[0]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store set node IDs for NodeSet: {str(e)}")
|
||||
else:
|
||||
logger.warning("Failed to create set nodes for NodeSet")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating set nodes: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing NodeSets: {str(e)}")
|
||||
|
||||
|
||||
async def create_set_node(
|
||||
nodes: Union[List[str], str],
|
||||
name: Optional[str] = None,
|
||||
document_id: Optional[str] = None
|
||||
) -> Tuple[Optional[List[str]], List[str]]:
|
||||
"""Create individual nodes for each value in the NodeSet.
|
||||
|
||||
Args:
|
||||
nodes: List of node values or JSON string representing node values
|
||||
name: Base name for the NodeSet (optional)
|
||||
document_id: ID of the document containing the NodeSet (optional)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List of created set node IDs (or None if creation failed)
|
||||
- List of created edge IDs
|
||||
"""
|
||||
try:
|
||||
if not nodes:
|
||||
logger.warning("No nodes provided to create set nodes")
|
||||
return None, []
|
||||
|
||||
# Get the graph engine
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
# Parse nodes if provided as JSON string
|
||||
nodes_list = _parse_nodes_input(nodes)
|
||||
if not nodes_list:
|
||||
return None, []
|
||||
|
||||
# Base name for the set if not provided
|
||||
base_name = name or f"NodeSet_{uuid.uuid4().hex[:8]}"
|
||||
logger.info(f"Creating individual set nodes for {len(nodes_list)} values with base name '{base_name}'")
|
||||
|
||||
# Create set nodes using the appropriate strategy
|
||||
return await _create_set_nodes_unified(graph_engine, nodes_list, base_name, document_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create set nodes: {str(e)}")
|
||||
return None, []
|
||||
|
||||
|
||||
def _parse_nodes_input(nodes: Union[List[str], str]) -> Optional[List[str]]:
|
||||
"""Parse the nodes input.
|
||||
|
||||
Args:
|
||||
nodes: List of node values or JSON string representing node values
|
||||
|
||||
Returns:
|
||||
List of node values or None if parsing failed
|
||||
"""
|
||||
if isinstance(nodes, str):
|
||||
try:
|
||||
parsed_nodes = json.loads(nodes)
|
||||
logger.info(f"Parsed nodes string into list with {len(parsed_nodes)} items")
|
||||
return parsed_nodes
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse nodes as JSON: {str(e)}")
|
||||
return None
|
||||
return nodes
|
||||
|
||||
|
||||
async def _create_set_nodes_unified(
|
||||
graph_engine: Any,
|
||||
nodes_list: List[str],
|
||||
base_name: str,
|
||||
document_id: Optional[str]
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Create set nodes using either NetworkX or generic graph engine.
|
||||
|
||||
Args:
|
||||
graph_engine: The graph engine instance
|
||||
nodes_list: List of node values
|
||||
base_name: Base name for the NodeSet
|
||||
document_id: ID of the document containing the NodeSet (optional)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List of created set node IDs
|
||||
- List of created edge IDs
|
||||
"""
|
||||
all_set_node_ids = []
|
||||
all_edge_ids = []
|
||||
|
||||
# Define strategies for node and edge creation based on graph engine type
|
||||
if isinstance(graph_engine, NetworkXAdapter):
|
||||
# NetworkX-specific strategy
|
||||
async def create_node(node_value: str) -> str:
|
||||
set_node_id = str(uuid.uuid4())
|
||||
node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
graph_engine.graph.add_node(
|
||||
set_node_id,
|
||||
id=set_node_id,
|
||||
type="NodeSet",
|
||||
name=node_name,
|
||||
node_id=node_value
|
||||
)
|
||||
|
||||
# Validate node creation
|
||||
if set_node_id in graph_engine.graph.nodes():
|
||||
node_props = dict(graph_engine.graph.nodes[set_node_id])
|
||||
logger.info(f"Created set node for value '{node_value}': {json.dumps(node_props)}")
|
||||
else:
|
||||
logger.warning(f"Node {set_node_id} not found in graph after adding")
|
||||
|
||||
return set_node_id
|
||||
|
||||
async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
|
||||
edge_ids = []
|
||||
try:
|
||||
edge_id = str(uuid.uuid4())
|
||||
graph_engine.graph.add_edge(
|
||||
set_node_id,
|
||||
node_value,
|
||||
id=edge_id,
|
||||
type="CONTAINS"
|
||||
)
|
||||
edge_ids.append(edge_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
|
||||
return edge_ids
|
||||
|
||||
async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
|
||||
edge_ids = []
|
||||
try:
|
||||
doc_to_nodeset_id = str(uuid.uuid4())
|
||||
graph_engine.graph.add_edge(
|
||||
document_id,
|
||||
set_node_id,
|
||||
id=doc_to_nodeset_id,
|
||||
type="HAS_NODESET"
|
||||
)
|
||||
edge_ids.append(doc_to_nodeset_id)
|
||||
logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
|
||||
return edge_ids
|
||||
|
||||
# Finalize function for NetworkX
|
||||
async def finalize() -> None:
|
||||
await graph_engine.save_graph_to_file(graph_engine.filename)
|
||||
|
||||
else:
|
||||
# Generic graph engine strategy
|
||||
async def create_node(node_value: str) -> str:
|
||||
node_name = f"NodeSet_{node_value}_{uuid.uuid4().hex[:8]}"
|
||||
set_node_props = {
|
||||
"name": node_name,
|
||||
"type": "NodeSet",
|
||||
"node_id": node_value
|
||||
}
|
||||
return await graph_engine.create_node(set_node_props)
|
||||
|
||||
async def create_value_edge(set_node_id: str, node_value: str) -> List[str]:
|
||||
edge_ids = []
|
||||
try:
|
||||
edge_id = await graph_engine.create_edge(
|
||||
source_id=set_node_id,
|
||||
target_id=node_value,
|
||||
edge_type="CONTAINS"
|
||||
)
|
||||
edge_ids.append(edge_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create edge from set node to node {node_value}: {str(e)}")
|
||||
return edge_ids
|
||||
|
||||
async def create_document_edge(document_id: str, set_node_id: str) -> List[str]:
|
||||
edge_ids = []
|
||||
try:
|
||||
doc_to_nodeset_id = await graph_engine.create_edge(
|
||||
source_id=document_id,
|
||||
target_id=set_node_id,
|
||||
edge_type="HAS_NODESET"
|
||||
)
|
||||
edge_ids.append(doc_to_nodeset_id)
|
||||
logger.info(f"Created edge from document {document_id} to NodeSet {set_node_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create edge from document to NodeSet: {str(e)}")
|
||||
return edge_ids
|
||||
|
||||
# Finalize function for generic engine (no-op)
|
||||
async def finalize() -> None:
|
||||
pass
|
||||
|
||||
# Unified process for both engine types
|
||||
for node_value in nodes_list:
|
||||
try:
|
||||
# Create the node
|
||||
set_node_id = await create_node(node_value)
|
||||
|
||||
# Create edges to the value
|
||||
value_edge_ids = await create_value_edge(set_node_id, node_value)
|
||||
all_edge_ids.extend(value_edge_ids)
|
||||
|
||||
# Create edges to the document if provided
|
||||
if document_id:
|
||||
doc_edge_ids = await create_document_edge(document_id, set_node_id)
|
||||
all_edge_ids.extend(doc_edge_ids)
|
||||
|
||||
all_set_node_ids.append(set_node_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create set node for value '{node_value}': {str(e)}")
|
||||
|
||||
# Finalize the process
|
||||
await finalize()
|
||||
|
||||
# Return results
|
||||
if all_set_node_ids:
|
||||
logger.info(f"Created {len(all_set_node_ids)} individual set nodes with values: {nodes_list}")
|
||||
return all_set_node_ids, all_edge_ids
|
||||
else:
|
||||
logger.error("Failed to create any set nodes")
|
||||
return [], []
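The helper above picks one set of async closures per graph engine and then drives a single loop over them; the same dispatch shape in isolation, with placeholder backends, looks roughly like this:

import asyncio

async def process(items, use_networkx_style: bool):
    # Choose a strategy once, up front, based on the backend in use.
    if use_networkx_style:
        async def create(item):
            return f"networkx:{item}"
    else:
        async def create(item):
            return f"generic:{item}"

    # One unified loop works for either strategy.
    created = []
    for item in items:
        created.append(await create(item))
    return created

# asyncio.run(process(["a", "b"], use_networkx_style=True))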
|
||||
|
|
@ -1 +0,0 @@
|
|||
# Node Set integration tests
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import cognee
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.tasks.node_set import apply_node_set
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
|
||||
class TestDocument(DataPoint):
|
||||
"""Test document model for NodeSet testing."""
|
||||
|
||||
content: str
|
||||
metadata: dict = {"index_fields": ["content"]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_set_add_to_cognify_workflow():
|
||||
"""
|
||||
Test the full NodeSet workflow from add to cognify.
|
||||
|
||||
This test verifies that:
|
||||
1. NodeSet data can be added using cognee.add
|
||||
2. The NodeSet data is stored in the relational database
|
||||
3. The apply_node_set task can retrieve the NodeSet and apply it to DataPoints
|
||||
"""
|
||||
# Clean up any existing data
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
# Create test data
|
||||
test_content = "This is test content"
|
||||
test_node_set = ["node1", "node2", "node3"]
|
||||
dataset_name = f"test_dataset_{uuid4()}"
|
||||
|
||||
# Create database tables
|
||||
await create_db_and_tables()
|
||||
|
||||
# Mock functions to avoid external dependencies
|
||||
mock_add_data_points = AsyncMock()
|
||||
mock_add_data_points.return_value = []
|
||||
|
||||
# Create a temporary file for the test
|
||||
temp_file_path = "/tmp/test_node_set.txt"
|
||||
with open(temp_file_path, "w") as f:
|
||||
f.write(test_content)
|
||||
|
||||
try:
|
||||
# Mock ingest_data to capture and verify NodeSet
|
||||
original_ingest_data = cognee.tasks.ingestion.ingest_data
|
||||
|
||||
async def mock_ingest_data(*args, **kwargs):
|
||||
# Call the original function but capture the NodeSet parameter
|
||||
assert len(args) >= 3, "Expected at least 3 arguments"
|
||||
assert args[1] == dataset_name, f"Expected dataset name {dataset_name}"
|
||||
assert kwargs.get("NodeSet") == test_node_set or args[3] == test_node_set, (
|
||||
"NodeSet not passed correctly"
|
||||
)
|
||||
return await original_ingest_data(*args, **kwargs)
|
||||
|
||||
# Replace the ingest_data function temporarily
|
||||
with patch("cognee.tasks.ingestion.ingest_data", side_effect=mock_ingest_data):
|
||||
# Call the add function with NodeSet
|
||||
await cognee.add(temp_file_path, dataset_name, NodeSet=test_node_set)
|
||||
|
||||
# Create test DataPoint for apply_node_set to process
|
||||
test_document = TestDocument(content=test_content)
|
||||
|
||||
# Test the apply_node_set task
|
||||
with patch(
|
||||
"cognee.tasks.node_set.apply_node_set.get_relational_engine"
|
||||
) as mock_get_engine:
|
||||
# Setup mock engine and session
|
||||
mock_session = AsyncMock()
|
||||
mock_engine = AsyncMock()
|
||||
|
||||
# Properly mock the async context manager
|
||||
@asynccontextmanager
|
||||
async def mock_get_session():
|
||||
try:
|
||||
yield mock_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
mock_engine.get_async_session.return_value = mock_get_session()
|
||||
mock_get_engine.return_value = mock_engine
|
||||
|
||||
# Create a mock Data object with our NodeSet
|
||||
class MockData:
|
||||
def __init__(self, id, node_set):
|
||||
self.id = id
|
||||
self.node_set = node_set
|
||||
|
||||
mock_data = MockData(test_document.id, json.dumps(test_node_set))
|
||||
|
||||
# Setup the mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [mock_data]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
# Run the apply_node_set task
|
||||
result = await apply_node_set([test_document])
|
||||
|
||||
# Verify the NodeSet was applied
|
||||
assert len(result) == 1
|
||||
assert result[0].NodeSet == test_node_set
|
||||
|
||||
# Verify the mock interactions
|
||||
mock_get_engine.assert_called_once()
|
||||
mock_engine.get_async_session.assert_called_once()
|
||||
finally:
|
||||
# Clean up the temporary file
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
# Clean up after the test
|
||||
await cognee.prune.prune_data()
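The fiddly part of the test above is faking the async session context manager. The same pattern on its own, without the cognee specifics (the engine here is a plain MagicMock so that get_async_session() returns the context manager directly), is sketched below:

import asyncio
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock

async def demo_async_session_mock():
    mock_session = AsyncMock()
    engine = MagicMock()

    @asynccontextmanager
    async def fake_get_session():
        # Behaves like `async with engine.get_async_session() as session:`
        yield mock_session

    engine.get_async_session.return_value = fake_get_session()

    async with engine.get_async_session() as session:
        await session.execute("SELECT 1")

    mock_session.execute.assert_awaited_once_with("SELECT 1")

# asyncio.run(demo_async_session_mock())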
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_set_in_cognify_pipeline():
|
||||
"""
|
||||
Test the integration of apply_node_set task in the cognify pipeline.
|
||||
|
||||
This test verifies that the apply_node_set task works correctly
|
||||
when run as part of a pipeline with other tasks.
|
||||
"""
|
||||
# Create test data
|
||||
test_documents = [TestDocument(content="Document 1"), TestDocument(content="Document 2")]
|
||||
|
||||
# Create a simple mock task that just passes data through
|
||||
async def mock_task(data):
|
||||
for item in data:
|
||||
yield item
|
||||
|
||||
# Mock the apply_node_set function to verify it's called with the right data
|
||||
original_apply_node_set = apply_node_set
|
||||
|
||||
apply_node_set_called = False
|
||||
|
||||
async def mock_apply_node_set(data_points):
|
||||
nonlocal apply_node_set_called
|
||||
apply_node_set_called = True
|
||||
|
||||
# Verify the input
|
||||
assert len(data_points) == 2
|
||||
assert all(isinstance(dp, TestDocument) for dp in data_points)
|
||||
|
||||
# Apply NodeSet to demonstrate it worked
|
||||
for dp in data_points:
|
||||
dp.NodeSet = ["test_node"]
|
||||
|
||||
return data_points
|
||||
|
||||
# Create a pipeline with our tasks
|
||||
with patch("cognee.tasks.node_set.apply_node_set", side_effect=mock_apply_node_set):
|
||||
pipeline = run_tasks(
|
||||
tasks=[
|
||||
Task(mock_task), # First task passes data through
|
||||
Task(apply_node_set), # Second task applies NodeSet
|
||||
],
|
||||
data=test_documents,
|
||||
)
|
||||
|
||||
# Process all results from the pipeline
|
||||
results = []
|
||||
async for result in pipeline:
|
||||
results.extend(result)
|
||||
|
||||
# Verify results
|
||||
assert apply_node_set_called, "apply_node_set was not called"
|
||||
assert len(results) == 2
|
||||
assert all(dp.NodeSet == ["test_node"] for dp in results)
|
||||
|
|
@@ -1 +0,0 @@
|
|||
# Node Set task tests
|
||||
|
|
@@ -1,132 +0,0 @@
|
|||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.tasks.node_set.apply_node_set import apply_node_set
|
||||
|
||||
|
||||
class TestDataPoint(DataPoint):
|
||||
"""Test DataPoint model for testing apply_node_set task."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_node_set():
|
||||
"""Test that apply_node_set applies NodeSet values from the relational store to DataPoint instances."""
|
||||
# Create test data
|
||||
dp_id1 = uuid4()
|
||||
dp_id2 = uuid4()
|
||||
|
||||
# Create test DataPoint instances
|
||||
data_points = [
|
||||
TestDataPoint(id=dp_id1, name="Test 1", description="Description 1"),
|
||||
TestDataPoint(id=dp_id2, name="Test 2", description="Description 2"),
|
||||
]
|
||||
|
||||
# Create mock Data records that would be returned from the database
|
||||
node_set1 = ["node1", "node2"]
|
||||
node_set2 = ["node3", "node4", "node5"]
|
||||
|
||||
# Create a mock implementation of _process_data_points
|
||||
async def mock_process_data_points(session, data_point_map):
|
||||
# Apply NodeSet directly to the DataPoints
|
||||
for dp_id, node_set in [(dp_id1, node_set1), (dp_id2, node_set2)]:
|
||||
dp_id_str = str(dp_id)
|
||||
if dp_id_str in data_point_map:
|
||||
data_point_map[dp_id_str].NodeSet = node_set
|
||||
|
||||
# Patch the necessary functions
|
||||
with (
|
||||
patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
|
||||
patch(
|
||||
"cognee.tasks.node_set.apply_node_set._process_data_points",
|
||||
side_effect=mock_process_data_points,
|
||||
),
|
||||
):
|
||||
# Call the function being tested
|
||||
result = await apply_node_set(data_points)
|
||||
|
||||
# Verify the results
|
||||
assert len(result) == 2
|
||||
|
||||
# Check that NodeSet values were applied correctly
|
||||
assert result[0].NodeSet == node_set1
|
||||
assert result[1].NodeSet == node_set2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_node_set_empty_list():
|
||||
"""Test apply_node_set with an empty list of DataPoints."""
|
||||
result = await apply_node_set([])
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_node_set_no_matching_data():
|
||||
"""Test apply_node_set when there are no matching Data records."""
|
||||
# Create test data
|
||||
dp_id = uuid4()
|
||||
|
||||
# Create test DataPoint instances
|
||||
data_points = [TestDataPoint(id=dp_id, name="Test", description="Description")]
|
||||
|
||||
# Create a mock implementation of _process_data_points that doesn't modify any DataPoints
|
||||
async def mock_process_data_points(session, data_point_map):
|
||||
# Don't modify anything - simulating no matching records
|
||||
pass
|
||||
|
||||
# Patch the necessary functions
|
||||
with (
|
||||
patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
|
||||
patch(
|
||||
"cognee.tasks.node_set.apply_node_set._process_data_points",
|
||||
side_effect=mock_process_data_points,
|
||||
),
|
||||
):
|
||||
# Call the function being tested
|
||||
result = await apply_node_set(data_points)
|
||||
|
||||
# Verify the results - NodeSet should remain None
|
||||
assert len(result) == 1
|
||||
assert result[0].NodeSet is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_node_set_invalid_json():
|
||||
"""Test apply_node_set when there's invalid JSON in the node_set column."""
|
||||
# Create test data
|
||||
dp_id = uuid4()
|
||||
|
||||
# Create test DataPoint instances
|
||||
data_points = [TestDataPoint(id=dp_id, name="Test", description="Description")]
|
||||
|
||||
# Create a mock implementation of _process_data_points that throws the appropriate error
|
||||
async def mock_process_data_points(session, data_point_map):
|
||||
# Simulate the JSONDecodeError by logging a warning
|
||||
from cognee.tasks.node_set.apply_node_set import logger
|
||||
|
||||
logger.warning(f"Failed to parse NodeSet JSON for DataPoint {str(dp_id)}")
|
||||
|
||||
# Patch the necessary functions
|
||||
with (
|
||||
patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
|
||||
patch(
|
||||
"cognee.tasks.node_set.apply_node_set._process_data_points",
|
||||
side_effect=mock_process_data_points,
|
||||
),
|
||||
patch("cognee.tasks.node_set.apply_node_set.logger") as mock_logger,
|
||||
):
|
||||
# Call the function being tested
|
||||
result = await apply_node_set(data_points)
|
||||
|
||||
# Verify the results - NodeSet should remain None
|
||||
assert len(result) == 1
|
||||
assert result[0].NodeSet is None
|
||||
|
||||
# Verify logger warning was called
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
|
@@ -1,675 +0,0 @@
|
|||
import asyncio
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Tuple, Any, Optional, Set
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
|
||||
import cognee
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
# Configure logger
|
||||
logger = get_logger(name="enhanced_graph_visualization")
|
||||
|
||||
# Type aliases for clarity
|
||||
NodeData = Dict[str, Any]
|
||||
EdgeData = Dict[str, Any]
|
||||
GraphData = Tuple[List[Tuple[Any, Dict]], List[Tuple[Any, Any, Optional[str], Dict]]]
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder to handle datetime objects."""
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return super().default(obj)
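For reference, the encoder above is what lets datetime-bearing node properties pass through json.dumps later in this module; a short usage example (assuming DateTimeEncoder as defined here):

import json
from datetime import datetime

payload = json.dumps({"created_at": datetime(2024, 1, 1, 12, 0)}, cls=DateTimeEncoder)
# payload == '{"created_at": "2024-01-01T12:00:00"}'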
|
||||
|
||||
|
||||
class NodeSetVisualizer:
|
||||
"""Class to create enhanced visualizations for NodeSet data in knowledge graphs."""
|
||||
|
||||
# Color mapping for different node types
|
||||
NODE_COLORS = {
|
||||
"Entity": "#f47710",
|
||||
"EntityType": "#6510f4",
|
||||
"DocumentChunk": "#801212",
|
||||
"TextDocument": "#a83232", # Darker red for documents
|
||||
"TextSummary": "#1077f4",
|
||||
"NodeSet": "#ff00ff", # Bright magenta for NodeSet nodes
|
||||
"Unknown": "#999999",
|
||||
"default": "#D3D3D3",
|
||||
}
|
||||
|
||||
# Size mapping for different node types
|
||||
NODE_SIZES = {
|
||||
"NodeSet": 20, # Larger size for NodeSet nodes
|
||||
"TextDocument": 18, # Larger size for document nodes
|
||||
"DocumentChunk": 18, # Larger size for document nodes
|
||||
"TextSummary": 16, # Medium size for TextSummary nodes
|
||||
"default": 13, # Default size
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the visualizer."""
|
||||
self.graph_engine = None
|
||||
self.nodes_data = []
|
||||
self.edges_data = []
|
||||
self.node_count = 0
|
||||
self.edge_count = 0
|
||||
self.nodeset_count = 0
|
||||
|
||||
async def get_graph_data(self) -> bool:
|
||||
"""Fetch graph data from the graph engine.
|
||||
|
||||
Returns:
|
||||
bool: True if data was successfully retrieved, False otherwise.
|
||||
"""
|
||||
self.graph_engine = await cognee.infrastructure.databases.graph.get_graph_engine()
|
||||
|
||||
# Check if the graph exists and has nodes
|
||||
self.node_count = len(self.graph_engine.graph.nodes())
|
||||
self.edge_count = len(self.graph_engine.graph.edges())
|
||||
logger.info(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
|
||||
print(f"Graph contains {self.node_count} nodes and {self.edge_count} edges")
|
||||
|
||||
if self.node_count == 0:
|
||||
logger.error("The graph is empty! Please run a test script first to generate data.")
|
||||
print("ERROR: The graph is empty! Please run a test script first to generate data.")
|
||||
return False
|
||||
|
||||
graph_data = await self.graph_engine.get_graph_data()
|
||||
self.nodes_data, self.edges_data = graph_data
|
||||
|
||||
# Count NodeSets for status display
|
||||
self.nodeset_count = sum(1 for _, info in self.nodes_data if info.get("type") == "NodeSet")
|
||||
|
||||
return True
|
||||
|
||||
def prepare_node_data(self) -> List[NodeData]:
|
||||
"""Process raw node data to prepare for visualization.
|
||||
|
||||
Returns:
|
||||
List[NodeData]: List of prepared node data objects.
|
||||
"""
|
||||
nodes_list = []
|
||||
|
||||
# Create a lookup for node types for faster access
|
||||
node_type_lookup = {str(node_id): node_info.get("type", "Unknown")
|
||||
for node_id, node_info in self.nodes_data}
|
||||
|
||||
for node_id, node_info in self.nodes_data:
|
||||
# Create a clean copy to avoid modifying the original
|
||||
processed_node = node_info.copy()
|
||||
|
||||
# Remove fields that cause JSON serialization issues
|
||||
self._clean_node_data(processed_node)
|
||||
|
||||
# Add required visualization properties
|
||||
processed_node["id"] = str(node_id)
|
||||
node_type = processed_node.get("type", "default")
|
||||
|
||||
# Apply visual styling based on node type
|
||||
processed_node["color"] = self.NODE_COLORS.get(node_type, self.NODE_COLORS["default"])
|
||||
processed_node["size"] = self.NODE_SIZES.get(node_type, self.NODE_SIZES["default"])
|
||||
|
||||
# Create display names
|
||||
self._format_node_display_name(processed_node, node_type)
|
||||
|
||||
nodes_list.append(processed_node)
|
||||
|
||||
return nodes_list
|
||||
|
||||
@staticmethod
|
||||
def _clean_node_data(node: NodeData) -> None:
|
||||
"""Remove fields that might cause JSON serialization issues.
|
||||
|
||||
Args:
|
||||
node: The node data to clean
|
||||
"""
|
||||
# Remove non-essential fields that might cause serialization issues
|
||||
for key in ["created_at", "updated_at", "raw_data_location"]:
|
||||
if key in node:
|
||||
del node[key]
|
||||
|
||||
@staticmethod
|
||||
def _format_node_display_name(node: NodeData, node_type: str) -> None:
|
||||
"""Format the display name for a node.
|
||||
|
||||
Args:
|
||||
node: The node data to process
|
||||
node_type: The type of the node
|
||||
"""
|
||||
# Set a default name if none exists
|
||||
node["name"] = node.get("name", node.get("id", "Unknown"))
|
||||
|
||||
# Special formatting for NodeSet nodes
|
||||
if node_type == "NodeSet" and "node_id" in node:
|
||||
node["display_name"] = f"NodeSet: {node['node_id']}"
|
||||
else:
|
||||
node["display_name"] = node["name"]
|
||||
|
||||
# Truncate long display names
|
||||
if len(node["display_name"]) > 30:
|
||||
node["display_name"] = f"{node['display_name'][:27]}..."
|
||||
|
||||
def prepare_edge_data(self, nodes_list: List[NodeData]) -> List[EdgeData]:
|
||||
"""Process raw edge data to prepare for visualization.
|
||||
|
||||
Args:
|
||||
nodes_list: The processed node data
|
||||
|
||||
Returns:
|
||||
List[EdgeData]: List of prepared edge data objects.
|
||||
"""
|
||||
links_list = []
|
||||
|
||||
# Create a lookup for node types for faster access
|
||||
node_type_lookup = {node["id"]: node.get("type", "Unknown") for node in nodes_list}
|
||||
|
||||
for source, target, relation, edge_info in self.edges_data:
|
||||
source_str = str(source)
|
||||
target_str = str(target)
|
||||
|
||||
# Skip if source or target not in node_type_lookup (should not happen)
|
||||
if source_str not in node_type_lookup or target_str not in node_type_lookup:
|
||||
continue
|
||||
|
||||
# Get node types
|
||||
source_type = node_type_lookup[source_str]
|
||||
target_type = node_type_lookup[target_str]
|
||||
|
||||
# Create edge data
|
||||
link_data = {
|
||||
"source": source_str,
|
||||
"target": target_str,
|
||||
"relation": relation or "UNKNOWN"
|
||||
}
|
||||
|
||||
# Categorize the edge for styling
|
||||
link_data["connection_type"] = self._categorize_edge(source_type, target_type)
|
||||
|
||||
links_list.append(link_data)
|
||||
|
||||
return links_list
|
||||
|
||||
@staticmethod
|
||||
def _categorize_edge(source_type: str, target_type: str) -> str:
|
||||
"""Categorize an edge based on the connected node types.
|
||||
|
||||
Args:
|
||||
source_type: The type of the source node
|
||||
target_type: The type of the target node
|
||||
|
||||
Returns:
|
||||
str: The category of the edge
|
||||
"""
|
||||
if source_type == "NodeSet" and target_type != "NodeSet":
|
||||
return "nodeset_to_value"
|
||||
elif (source_type in ["TextDocument", "DocumentChunk", "TextSummary"]) and target_type == "NodeSet":
|
||||
return "document_to_nodeset"
|
||||
elif target_type == "NodeSet":
|
||||
return "to_nodeset"
|
||||
elif source_type == "NodeSet":
|
||||
return "from_nodeset"
|
||||
else:
|
||||
return "standard"
|
||||
|
||||
def generate_html(self, nodes_list: List[NodeData], links_list: List[EdgeData]) -> str:
|
||||
"""Generate the HTML visualization with D3.js.
|
||||
|
||||
Args:
|
||||
nodes_list: The processed node data
|
||||
links_list: The processed edge data
|
||||
|
||||
Returns:
|
||||
str: The HTML content for the visualization
|
||||
"""
|
||||
# Use embedded template directly - more reliable than file access
|
||||
html_template = self._get_embedded_html_template()
|
||||
|
||||
# Generate the HTML content with custom JSON encoder for datetime objects
|
||||
html_content = html_template.replace("{nodes}", json.dumps(nodes_list, cls=DateTimeEncoder))
|
||||
html_content = html_content.replace("{links}", json.dumps(links_list, cls=DateTimeEncoder))
|
||||
html_content = html_content.replace("{node_count}", str(self.node_count))
|
||||
html_content = html_content.replace("{edge_count}", str(self.edge_count))
|
||||
html_content = html_content.replace("{nodeset_count}", str(self.nodeset_count))
|
||||
|
||||
return html_content
|
||||
|
||||
def save_html(self, html_content: str) -> str:
|
||||
"""Save the HTML content to a file and open it in the browser.
|
||||
|
||||
Args:
|
||||
html_content: The HTML content to save
|
||||
|
||||
Returns:
|
||||
str: The path to the saved file
|
||||
"""
|
||||
# Create the output file path
|
||||
output_path = Path.cwd() / "enhanced_nodeset_visualization.html"
|
||||
|
||||
# Write the HTML content to the file
|
||||
with open(output_path, "w") as f:
|
||||
f.write(html_content)
|
||||
|
||||
logger.info(f"Enhanced visualization saved to: {output_path}")
|
||||
print(f"Enhanced visualization saved to: {output_path}")
|
||||
|
||||
# Open the visualization in the default web browser
|
||||
file_url = f"file://{output_path}"
|
||||
logger.info(f"Opening visualization in browser: {file_url}")
|
||||
print(f"Opening enhanced visualization in browser: {file_url}")
|
||||
webbrowser.open(file_url)
|
||||
|
||||
return str(output_path)
|
||||
|
||||
@staticmethod
|
||||
def _get_embedded_html_template() -> str:
|
||||
"""Get the embedded HTML template as a fallback.
|
||||
|
||||
Returns:
|
||||
str: The HTML template
|
||||
"""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Cognee NodeSet Visualization</title>
|
||||
<script src="https://d3js.org/d3.v5.min.js"></script>
|
||||
<style>
|
||||
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', Arial, sans-serif; }
|
||||
|
||||
svg { width: 100vw; height: 100vh; display: block; }
|
||||
.links line { stroke-width: 2px; }
|
||||
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
|
||||
.node-label { font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }
|
||||
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', Arial, sans-serif; pointer-events: none; }
|
||||
|
||||
/* NodeSet specific styles */
|
||||
.links line.nodeset_to_value { stroke: #ff00ff; stroke-width: 3px; stroke-dasharray: 5, 5; }
|
||||
.links line.document_to_nodeset { stroke: #fc0; stroke-width: 3px; }
|
||||
.links line.to_nodeset { stroke: #0cf; stroke-width: 2px; }
|
||||
.links line.from_nodeset { stroke: #0fc; stroke-width: 2px; }
|
||||
|
||||
.nodes circle.nodeset { stroke: white; stroke-width: 2px; filter: drop-shadow(0 0 10px rgba(255,0,255,0.8)); }
|
||||
.nodes circle.document { stroke: white; stroke-width: 1.5px; filter: drop-shadow(0 0 8px rgba(255,255,0,0.6)); }
|
||||
.node-label.nodeset { font-size: 6px; font-weight: bold; fill: white; }
|
||||
.node-label.document { font-size: 5.5px; font-weight: bold; fill: white; }
|
||||
|
||||
/* Legend */
|
||||
.legend { position: fixed; top: 10px; left: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
|
||||
.legend-item { display: flex; align-items: center; margin-bottom: 5px; }
|
||||
.legend-color { width: 15px; height: 15px; margin-right: 10px; border-radius: 50%; }
|
||||
.legend-label { font-size: 14px; }
|
||||
|
||||
/* Edge legend */
|
||||
.edge-legend { position: fixed; top: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
|
||||
.edge-legend-item { display: flex; align-items: center; margin-bottom: 10px; }
|
||||
.edge-line { width: 30px; height: 3px; margin-right: 10px; }
|
||||
.edge-label { font-size: 14px; }
|
||||
|
||||
/* Controls */
|
||||
.controls { position: fixed; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; display: flex; gap: 10px; }
|
||||
button { background: #333; color: white; border: none; padding: 5px 10px; border-radius: 4px; cursor: pointer; }
|
||||
button:hover { background: #555; }
|
||||
|
||||
/* Status message */
|
||||
.status { position: fixed; bottom: 10px; right: 10px; background: rgba(0,0,0,0.7); padding: 10px; border-radius: 5px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<svg></svg>
|
||||
|
||||
<!-- Node Legend -->
|
||||
<div class="legend">
|
||||
<h3 style="margin-top: 0;">Node Types</h3>
|
||||
<div class="legend-item">
|
||||
<div class="legend-color" style="background-color: #ff00ff;"></div>
|
||||
<div class="legend-label">NodeSet</div>
|
||||
</div>
|
||||
<div class="legend-item">
|
||||
<div class="legend-color" style="background-color: #a83232;"></div>
|
||||
<div class="legend-label">TextDocument</div>
|
||||
</div>
|
||||
<div class="legend-item">
|
||||
<div class="legend-color" style="background-color: #801212;"></div>
|
||||
<div class="legend-label">DocumentChunk</div>
|
||||
</div>
|
||||
<div class="legend-item">
|
||||
<div class="legend-color" style="background-color: #1077f4;"></div>
|
||||
<div class="legend-label">TextSummary</div>
|
||||
</div>
|
||||
<div class="legend-item">
|
||||
<div class="legend-color" style="background-color: #f47710;"></div>
|
||||
<div class="legend-label">Entity</div>
|
||||
</div>
|
||||
<div class="legend-item">
|
||||
<div class="legend-color" style="background-color: #6510f4;"></div>
|
||||
<div class="legend-label">EntityType</div>
|
||||
</div>
|
||||
<div class="legend-item">
|
||||
<div class="legend-color" style="background-color: #999999;"></div>
|
||||
<div class="legend-label">Unknown</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Edge Legend -->
|
||||
<div class="edge-legend">
|
||||
<h3 style="margin-top: 0;">Edge Types</h3>
|
||||
<div class="edge-legend-item">
|
||||
<div class="edge-line" style="background-color: #fc0;"></div>
|
||||
<div class="edge-label">Document → NodeSet</div>
|
||||
</div>
|
||||
<div class="edge-legend-item">
|
||||
<div class="edge-line" style="background-color: #ff00ff; height: 3px; background: linear-gradient(to right, #ff00ff 50%, transparent 50%); background-size: 10px 3px; background-repeat: repeat-x;"></div>
|
||||
<div class="edge-label">NodeSet → Value</div>
|
||||
</div>
|
||||
<div class="edge-legend-item">
|
||||
<div class="edge-line" style="background-color: #0cf;"></div>
|
||||
<div class="edge-label">Any → NodeSet</div>
|
||||
</div>
|
||||
<div class="edge-legend-item">
|
||||
<div class="edge-line" style="background-color: rgba(255, 255, 255, 0.4);"></div>
|
||||
<div class="edge-label">Standard Connection</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Controls -->
|
||||
<div class="controls">
|
||||
<button id="center-btn">Center Graph</button>
|
||||
<button id="highlight-nodesets">Highlight NodeSets</button>
|
||||
<button id="highlight-documents">Highlight Documents</button>
|
||||
<button id="reset-highlight">Reset Highlight</button>
|
||||
</div>
|
||||
|
||||
<!-- Status -->
|
||||
<div class="status">
|
||||
<div>Nodes: {node_count}</div>
|
||||
<div>Edges: {edge_count}</div>
|
||||
<div>NodeSets: {nodeset_count}</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
var nodes = {nodes};
|
||||
var links = {links};
|
||||
|
||||
var svg = d3.select("svg"),
|
||||
width = window.innerWidth,
|
||||
height = window.innerHeight;
|
||||
|
||||
var container = svg.append("g");
|
||||
|
||||
// Count NodeSets for status display
|
||||
const nodesetCount = nodes.filter(n => n.type === "NodeSet").length;
|
||||
document.querySelector('.status').innerHTML = `
|
||||
<div>Nodes: ${nodes.length}</div>
|
||||
<div>Edges: ${links.length}</div>
|
||||
<div>NodeSets: ${nodesetCount}</div>
|
||||
`;
|
||||
|
||||
var simulation = d3.forceSimulation(nodes)
|
||||
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
|
||||
.force("charge", d3.forceManyBody().strength(-300))
|
||||
.force("center", d3.forceCenter(width / 2, height / 2))
|
||||
.force("x", d3.forceX().strength(0.1).x(width / 2))
|
||||
.force("y", d3.forceY().strength(0.1).y(height / 2));
|
||||
|
||||
var link = container.append("g")
|
||||
.attr("class", "links")
|
||||
.selectAll("line")
|
||||
.data(links)
|
||||
.enter().append("line")
|
||||
.attr("stroke-width", d => {
|
||||
if (d.connection_type === 'document_to_nodeset' || d.connection_type === 'nodeset_to_value') {
|
||||
return 3;
|
||||
}
|
||||
return 2;
|
||||
})
|
||||
.attr("stroke", d => {
|
||||
switch(d.connection_type) {
|
||||
case 'document_to_nodeset': return "#fc0";
|
||||
case 'nodeset_to_value': return "#ff00ff";
|
||||
case 'to_nodeset': return "#0cf";
|
||||
case 'from_nodeset': return "#0fc";
|
||||
default: return "rgba(255, 255, 255, 0.4)";
|
||||
}
|
||||
})
|
||||
.attr("stroke-dasharray", d => d.connection_type === 'nodeset_to_value' ? "5,5" : null)
|
||||
.attr("class", d => d.connection_type);
|
||||
|
||||
var edgeLabels = container.append("g")
|
||||
.attr("class", "edge-labels")
|
||||
.selectAll("text")
|
||||
.data(links)
|
||||
.enter().append("text")
|
||||
.attr("class", "edge-label")
|
||||
.text(d => d.relation);
|
||||
|
||||
var nodeGroup = container.append("g")
|
||||
.attr("class", "nodes")
|
||||
.selectAll("g")
|
||||
.data(nodes)
|
||||
.enter().append("g");
|
||||
|
||||
var node = nodeGroup.append("circle")
|
||||
.attr("r", d => d.size || 13)
|
||||
.attr("fill", d => d.color)
|
||||
.attr("class", d => {
|
||||
if (d.type === "NodeSet") return "nodeset";
|
||||
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "document";
|
||||
return "";
|
||||
})
|
||||
.call(d3.drag()
|
||||
.on("start", dragstarted)
|
||||
.on("drag", dragged)
|
||||
.on("end", dragended));
|
||||
|
||||
nodeGroup.append("text")
|
||||
.attr("class", d => {
|
||||
if (d.type === "NodeSet") return "node-label nodeset";
|
||||
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "node-label document";
|
||||
return "node-label";
|
||||
})
|
||||
.attr("dy", 4)
|
||||
.attr("font-size", d => {
|
||||
if (d.type === "NodeSet") return "6px";
|
||||
if (d.type === "TextDocument" || d.type === "DocumentChunk") return "5.5px";
|
||||
return "5px";
|
||||
})
|
||||
.attr("text-anchor", "middle")
|
||||
.text(d => d.display_name || d.name);
|
||||
|
||||
node.append("title").text(d => {
|
||||
// Create a formatted tooltip with node properties
|
||||
let props = Object.entries(d)
|
||||
.filter(([key]) => !["x", "y", "vx", "vy", "index", "fx", "fy", "color", "display_name"].includes(key))
|
||||
.map(([key, value]) => `${key}: ${value}`)
|
||||
.join("\\n");
|
||||
return props;
|
||||
});
|
||||
|
||||
simulation.on("tick", function() {
|
||||
link.attr("x1", d => d.source.x)
|
||||
.attr("y1", d => d.source.y)
|
||||
.attr("x2", d => d.target.x)
|
||||
.attr("y2", d => d.target.y);
|
||||
|
||||
edgeLabels
|
||||
.attr("x", d => (d.source.x + d.target.x) / 2)
|
||||
.attr("y", d => (d.source.y + d.target.y) / 2 - 5);
|
||||
|
||||
node.attr("cx", d => d.x)
|
||||
.attr("cy", d => d.y);
|
||||
|
||||
nodeGroup.select("text")
|
||||
.attr("x", d => d.x)
|
||||
.attr("y", d => d.y)
|
||||
.attr("dy", 4)
|
||||
.attr("text-anchor", "middle");
|
||||
});
|
||||
|
||||
// Add zoom behavior
|
||||
const zoom = d3.zoom()
|
||||
.scaleExtent([0.1, 8])
|
||||
.on("zoom", function() {
|
||||
container.attr("transform", d3.event.transform);
|
||||
});
|
||||
|
||||
svg.call(zoom);
|
||||
|
||||
// Button controls
|
||||
document.getElementById("center-btn").addEventListener("click", function() {
|
||||
svg.transition().duration(750).call(
|
||||
zoom.transform,
|
||||
d3.zoomIdentity.translate(width / 2, height / 2).scale(1)
|
||||
);
|
||||
});
|
||||
|
||||
document.getElementById("highlight-nodesets").addEventListener("click", function() {
|
||||
highlightNodes("NodeSet");
|
||||
});
|
||||
|
||||
document.getElementById("highlight-documents").addEventListener("click", function() {
|
||||
highlightNodes(["TextDocument", "DocumentChunk"]);
|
||||
});
|
||||
|
||||
document.getElementById("reset-highlight").addEventListener("click", function() {
|
||||
resetHighlight();
|
||||
});
|
||||
|
||||
function highlightNodes(typeToHighlight) {
|
||||
// Dim all nodes and links
|
||||
node.transition().duration(300)
|
||||
.attr("opacity", 0.2);
|
||||
link.transition().duration(300)
|
||||
.attr("opacity", 0.2);
|
||||
nodeGroup.selectAll("text").transition().duration(300)
|
||||
.attr("opacity", 0.2);
|
||||
|
||||
// Create arrays for types if a single string is provided
|
||||
const typesToHighlight = Array.isArray(typeToHighlight) ? typeToHighlight : [typeToHighlight];
|
||||
|
||||
// Highlight matching nodes and their connected nodes
|
||||
const highlightedNodeIds = new Set();
|
||||
|
||||
// First, find all nodes of the target type
|
||||
nodes.forEach(n => {
|
||||
if (typesToHighlight.includes(n.type)) {
|
||||
highlightedNodeIds.add(n.id);
|
||||
}
|
||||
});
|
||||
|
||||
// Find all connected nodes (both directions)
|
||||
links.forEach(l => {
|
||||
if (highlightedNodeIds.has(l.source.id || l.source)) {
|
||||
highlightedNodeIds.add(l.target.id || l.target);
|
||||
}
|
||||
if (highlightedNodeIds.has(l.target.id || l.target)) {
|
||||
highlightedNodeIds.add(l.source.id || l.source);
|
||||
}
|
||||
});
|
||||
|
||||
// Highlight the nodes
|
||||
node.filter(d => highlightedNodeIds.has(d.id))
|
||||
.transition().duration(300)
|
||||
.attr("opacity", 1);
|
||||
|
||||
// Highlight the labels
|
||||
nodeGroup.selectAll("text")
|
||||
.filter(d => highlightedNodeIds.has(d.id))
|
||||
.transition().duration(300)
|
||||
.attr("opacity", 1);
|
||||
|
||||
// Highlight the links between highlighted nodes
|
||||
link.filter(d => {
|
||||
const sourceId = d.source.id || d.source;
|
||||
const targetId = d.target.id || d.target;
|
||||
return highlightedNodeIds.has(sourceId) && highlightedNodeIds.has(targetId);
|
||||
})
|
||||
.transition().duration(300)
|
||||
.attr("opacity", 1);
|
||||
}
|
||||
|
||||
function resetHighlight() {
|
||||
node.transition().duration(300).attr("opacity", 1);
|
||||
link.transition().duration(300).attr("opacity", 1);
|
||||
nodeGroup.selectAll("text").transition().duration(300).attr("opacity", 1);
|
||||
}
|
||||
|
||||
function dragstarted(d) {
|
||||
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
|
||||
d.fx = d.x;
|
||||
d.fy = d.y;
|
||||
}
|
||||
|
||||
function dragged(d) {
|
||||
d.fx = d3.event.x;
|
||||
d.fy = d3.event.y;
|
||||
}
|
||||
|
||||
function dragended(d) {
|
||||
if (!d3.event.active) simulation.alphaTarget(0);
|
||||
d.fx = null;
|
||||
d.fy = null;
|
||||
}
|
||||
|
||||
window.addEventListener("resize", function() {
|
||||
width = window.innerWidth;
|
||||
height = window.innerHeight;
|
||||
svg.attr("width", width).attr("height", height);
|
||||
simulation.force("center", d3.forceCenter(width / 2, height / 2));
|
||||
simulation.alpha(1).restart();
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
async def create_visualization(self) -> Optional[str]:
|
||||
"""Main method to create the visualization.
|
||||
|
||||
Returns:
|
||||
Optional[str]: Path to the saved visualization file, or None if unsuccessful
|
||||
"""
|
||||
print("Creating enhanced NodeSet visualization...")
|
||||
|
||||
# Get graph data
|
||||
if not await self.get_graph_data():
|
||||
return None
|
||||
|
||||
try:
|
||||
# Process nodes
|
||||
nodes_list = self.prepare_node_data()
|
||||
|
||||
# Process edges
|
||||
links_list = self.prepare_edge_data(nodes_list)
|
||||
|
||||
# Generate HTML
|
||||
html_content = self.generate_html(nodes_list, links_list)
|
||||
|
||||
# Save to file and open in browser
|
||||
return self.save_html(html_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating visualization: {str(e)}")
|
||||
print(f"Error creating visualization: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the script."""
|
||||
visualizer = NodeSetVisualizer()
|
||||
await visualizer.create_visualization()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the async main function
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.close()
|
||||