added fixes
This commit is contained in:
parent
ec68a8cd2d
commit
e7a14b9c60
11 changed files with 1171 additions and 8 deletions
67  cognee/.env.template  Normal file
@@ -0,0 +1,67 @@
# Cognee Configuration file
# Copy this file to .env and fill in the values

# Default User Configuration
DEFAULT_USER_EMAIL=""
DEFAULT_USER_PASSWORD=""

# Vector Database configuration
VECTOR_DB_TYPE=ChromaDB
VECTOR_DB_PATH=./chromadb
CHROMA_PERSIST_DIRECTORY=./chroma_db

# Graph Database configuration
GRAPH_DB_TYPE=NetworkX
GRAPH_DB_PATH=./.data/graph.json

# Content Storage configuration
CONTENT_STORAGE_TYPE=FileSystem
CONTENT_STORAGE_PATH=./storage

# Application settings
APP_NAME=Cognee
APP_ENVIRONMENT=production
LOG_LEVEL=INFO

# LLM configuration
LLM_PROVIDER=openai
LLM_MODEL=gpt-4
LLM_API_KEY=sk-...
LLM_ENDPOINT=
LLM_API_VERSION=
LLM_TEMPERATURE=0.0
LLM_STREAMING=false
LLM_MAX_TOKENS=16384

# Rate limiting configuration
LLM_RATE_LIMIT_ENABLED=false
LLM_RATE_LIMIT=60/minute
LLM_RATE_LIMIT_STRATEGY=moving-window
LLM_RATE_LIMIT_STORAGE=memory
# For Redis storage
# LLM_RATE_LIMIT_REDIS_URL=redis://localhost:6379/0
# For Memcached storage
# LLM_RATE_LIMIT_MEMCACHED_HOST=localhost
# LLM_RATE_LIMIT_MEMCACHED_PORT=11211

# Embedding configuration
EMBEDDING_PROVIDER=openai
EMBEDDING_MODEL=text-embedding-3-small
EMBEDDING_DIMENSIONS=1536
EMBEDDING_API_KEY=sk-...

# MongoDB configuration (optional)
# MONGODB_URI=mongodb://localhost:27017
# MONGODB_DB_NAME=cognee

# Metrics configuration (optional)
METRICS_ENABLED=false
METRICS_PORT=9090

# Monitoring configuration
MONITORING_TOOL=None
# For Langfuse (optional)
# LANGFUSE_HOST=https://cloud.langfuse.com
# LANGFUSE_PUBLIC_KEY=pk-...
# LANGFUSE_SECRET_KEY=sk-...
# LANGFUSE_PROJECT_ID=...
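As a rough illustration of how the variables in this template might be consumed (this is not cognee's actual settings loader, which is not part of this commit), the sketch below reads a few of them with python-dotenv and `os.getenv`. The variable names and defaults come from the template above; everything else is an assumption.

```python
import os

from dotenv import load_dotenv  # python-dotenv; assumed to be installed

# Load the .env file created from the template above (current working directory).
load_dotenv()

# Hypothetical illustration only: read a handful of the variables defined in the
# template, falling back to the same values the template ships with.
llm_provider = os.getenv("LLM_PROVIDER", "openai")
llm_model = os.getenv("LLM_MODEL", "gpt-4")
rate_limit_enabled = os.getenv("LLM_RATE_LIMIT_ENABLED", "false").lower() == "true"
rate_limit = os.getenv("LLM_RATE_LIMIT", "60/minute")  # e.g. a limit string for a rate limiter

print(llm_provider, llm_model, rate_limit_enabled, rate_limit)
```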
350  cognee/examples/layered_graph_db_example.py  Normal file
@@ -0,0 +1,350 @@
"""
Example demonstrating how to use LayeredKnowledgeGraphDP with database adapters.

This example shows how to:
1. Create a layered knowledge graph
2. Set a database adapter
3. Add nodes, edges, and layers with automatic persistence to the database
4. Retrieve graph data from the database
"""

import asyncio
import uuid
import logging
import json
from uuid import UUID

from cognee.modules.graph.datapoint_layered_graph import (
    GraphNode,
    GraphEdge,
    GraphLayer,
    LayeredKnowledgeGraphDP,
)
from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter
from cognee.infrastructure.databases.graph import get_graph_engine

# Set up logging
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


async def retrieve_graph_manually(graph_id, adapter):
    """Retrieve a graph manually from the NetworkX adapter."""
    graph_db = adapter._graph_db
    if not hasattr(graph_db, "graph") or not graph_db.graph:
        await graph_db.load_graph_from_file()

    graph_id_str = str(graph_id)
    logger.info(f"Looking for graph with ID: {graph_id_str}")

    if hasattr(graph_db, "graph") and graph_db.graph.has_node(graph_id_str):
        # Get the graph node data
        graph_data = graph_db.graph.nodes[graph_id_str]
        logger.info(f"Found graph node data: {graph_data}")

        # Create the graph instance
        graph = LayeredKnowledgeGraphDP(
            id=graph_id,
            name=graph_data.get("name", ""),
            description=graph_data.get("description", ""),
            metadata=graph_data.get("metadata", {}),
        )

        # Set the adapter
        graph.set_adapter(adapter)

        # Find and add all layers, nodes, and edges
        nx_graph = graph_db.graph

        # Find all layers for this graph
        logger.info("Finding layers connected to the graph")
        found_layers = set()
        for source, target, key in nx_graph.edges(graph_id_str, keys=True):
            if key == "CONTAINS_LAYER":
                # Found layer
                layer_data = nx_graph.nodes[target]
                layer_id_str = target
                logger.info(f"Found layer: {layer_id_str} with data: {layer_data}")

                # Convert parent layers
                parent_layers = []
                if "parent_layers" in layer_data:
                    try:
                        if isinstance(layer_data["parent_layers"], str):
                            import json

                            parent_layers = [
                                UUID(p) for p in json.loads(layer_data["parent_layers"])
                            ]
                        elif isinstance(layer_data["parent_layers"], list):
                            parent_layers = [UUID(str(p)) for p in layer_data["parent_layers"]]
                    except Exception as e:
                        logger.error(f"Error processing parent layers: {e}")
                        parent_layers = []

                # Create layer
                try:
                    layer = GraphLayer(
                        id=UUID(layer_id_str),
                        name=layer_data.get("name", ""),
                        description=layer_data.get("description", ""),
                        layer_type=layer_data.get("layer_type", "default"),
                        parent_layers=parent_layers,
                        properties=layer_data.get("properties", {}),
                        metadata=layer_data.get("metadata", {}),
                    )
                    graph.layers[layer.id] = layer
                    found_layers.add(layer_id_str)
                except Exception as e:
                    logger.error(f"Error creating layer object: {e}")

        # Helper function to safely get UUID
        def safe_uuid(value):
            if isinstance(value, UUID):
                return value
            try:
                return UUID(str(value))
            except Exception as e:
                logger.error(f"Error converting to UUID: {value} - {e}")
                return None

        # Find all nodes for this graph
        logger.info("Finding nodes in the layers")
        for node_id, node_data in nx_graph.nodes(data=True):
            # First check if this is a node by its metadata
            if node_data.get("metadata", {}).get("type") == "GraphNode":
                # Get the layer ID from the node data
                if "layer_id" in node_data:
                    layer_id_str = node_data["layer_id"]
                    # Check if this layer ID is in our found layers
                    try:
                        layer_id = safe_uuid(layer_id_str)
                        if layer_id and layer_id in graph.layers:
                            logger.info(f"Found node with ID {node_id} in layer {layer_id_str}")

                            # Create the node
                            node = GraphNode(
                                id=safe_uuid(node_id),
                                name=node_data.get("name", ""),
                                node_type=node_data.get("node_type", ""),
                                description=node_data.get("description", ""),
                                properties=node_data.get("properties", {}),
                                layer_id=layer_id,
                                metadata=node_data.get("metadata", {}),
                            )
                            graph.nodes[node.id] = node
                            graph.node_layer_map[node.id] = layer_id
                    except Exception as e:
                        logger.error(f"Error processing node {node_id}: {e}")

        # Find all edges for this graph
        logger.info("Finding edges in the layers")
        for node_id, node_data in nx_graph.nodes(data=True):
            # First check if this is an edge by its metadata
            if node_data.get("metadata", {}).get("type") == "GraphEdge":
                # Get the layer ID from the edge data
                if "layer_id" in node_data:
                    layer_id_str = node_data["layer_id"]
                    # Check if this layer ID is in our found layers
                    try:
                        layer_id = safe_uuid(layer_id_str)
                        if layer_id and layer_id in graph.layers:
                            source_id = safe_uuid(node_data.get("source_node_id"))
                            target_id = safe_uuid(node_data.get("target_node_id"))

                            if source_id and target_id:
                                logger.info(f"Found edge with ID {node_id} in layer {layer_id_str}")

                                # Create the edge
                                edge = GraphEdge(
                                    id=safe_uuid(node_id),
                                    source_node_id=source_id,
                                    target_node_id=target_id,
                                    relationship_name=node_data.get("relationship_name", ""),
                                    properties=node_data.get("properties", {}),
                                    layer_id=layer_id,
                                    metadata=node_data.get("metadata", {}),
                                )
                                graph.edges[edge.id] = edge
                                graph.edge_layer_map[edge.id] = layer_id
                    except Exception as e:
                        logger.error(f"Error processing edge {node_id}: {e}")

        logger.info(
            f"Manually retrieved graph with {len(graph.layers)} layers, {len(graph.nodes)} nodes, and {len(graph.edges)} edges"
        )
        return graph
    else:
        logger.error(f"Could not find graph with ID {graph_id}")
        return None


def get_nodes_and_edges_from_graph(graph, layer_id):
    """
    Get nodes and edges from a layer in the graph, avoiding database calls.

    Args:
        graph: The LayeredKnowledgeGraphDP instance
        layer_id: The UUID of the layer

    Returns:
        Tuple of (nodes, edges) lists
    """
    nodes = [node for node in graph.nodes.values() if node.layer_id == layer_id]
    edges = [edge for edge in graph.edges.values() if edge.layer_id == layer_id]
    return nodes, edges


async def main():
    logger.info("Starting layered graph database example")

    # Get the default graph engine (typically NetworkXAdapter)
    graph_db = await get_graph_engine()
    logger.info(f"Using graph database adapter: {type(graph_db).__name__}")

    # Create an adapter using the graph engine
    adapter = LayeredGraphDBAdapter(graph_db)

    # Create a new empty graph
    graph = LayeredKnowledgeGraphDP.create_empty(
        name="Example Database Graph",
        description="A graph that persists to the database",
        metadata={
            "type": "LayeredKnowledgeGraph",
            "index_fields": ["name"],
        },  # Ensure proper metadata
    )
    logger.info(f"Created graph with ID: {graph.id}")

    # Set the adapter for this graph
    graph.set_adapter(adapter)

    # Create and add a base layer
    base_layer = GraphLayer.create(
        name="Base Layer", description="The foundation layer of the graph", layer_type="base"
    )
    graph.add_layer(base_layer)
    logger.info(f"Added base layer with ID: {base_layer.id}")

    # Create and add a derived layer that extends the base layer
    derived_layer = GraphLayer.create(
        name="Derived Layer",
        description="A layer that extends the base layer",
        layer_type="derived",
        parent_layers=[base_layer.id],
    )
    graph.add_layer(derived_layer)
    logger.info(f"Added derived layer with ID: {derived_layer.id}")

    # Create and add nodes to the base layer
    node1 = GraphNode.create(
        name="Concept A", node_type="concept", description="A foundational concept"
    )
    graph.add_node(node1, base_layer.id)
    logger.info(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}")

    node2 = GraphNode.create(
        name="Concept B", node_type="concept", description="Another foundational concept"
    )
    graph.add_node(node2, base_layer.id)
    logger.info(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}")

    # Create and add a node to the derived layer
    node3 = GraphNode.create(
        name="Derived Concept",
        node_type="concept",
        description="A concept derived from foundational concepts",
    )
    graph.add_node(node3, derived_layer.id)
    logger.info(f"Added node3 with ID: {node3.id} to layer: {derived_layer.id}")

    # Create and add edges between nodes
    edge1 = GraphEdge.create(
        source_node_id=node1.id, target_node_id=node2.id, relationship_name="RELATES_TO"
    )
    graph.add_edge(edge1, base_layer.id)
    logger.info(f"Added edge1 with ID: {edge1.id} between {node1.id} and {node2.id}")

    edge2 = GraphEdge.create(
        source_node_id=node1.id, target_node_id=node3.id, relationship_name="EXPANDS_TO"
    )
    graph.add_edge(edge2, derived_layer.id)
    logger.info(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}")

    edge3 = GraphEdge.create(
        source_node_id=node2.id, target_node_id=node3.id, relationship_name="CONTRIBUTES_TO"
    )
    graph.add_edge(edge3, derived_layer.id)
    logger.info(f"Added edge3 with ID: {edge3.id} between {node2.id} and {node3.id}")

    # Save the graph state to a file for NetworkXAdapter
    if hasattr(graph_db, "save_graph_to_file"):
        logger.info(f"Saving graph to file: {getattr(graph_db, 'filename', 'unknown')}")
        await graph_db.save_graph_to_file()

    # Persist the entire graph to the database
    # This is optional since the graph is already being persisted incrementally
    # when add_layer, add_node, and add_edge are called
    logger.info("Persisting entire graph to database")
    graph_id = await graph.persist()
    logger.info(f"Graph persisted with ID: {graph_id}")

    # Check if the graph exists in the database
    if hasattr(graph_db, "graph"):
        logger.info(f"Checking if graph exists in memory: {graph_db.graph.has_node(str(graph.id))}")

        # Check the node data
        if graph_db.graph.has_node(str(graph.id)):
            node_data = graph_db.graph.nodes[str(graph.id)]
            logger.info(f"Graph node data: {node_data}")

        # List all nodes in the graph
        logger.info("Nodes in the graph:")
        for node_id, node_data in graph_db.graph.nodes(data=True):
            logger.info(
                f"  Node {node_id}: type={node_data.get('metadata', {}).get('type', 'unknown')}"
            )

    # Try to retrieve the graph using our from_database method
    retrieved_graph = None
    try:
        logger.info(f"Retrieving graph from database with ID: {graph.id}")
        retrieved_graph = await LayeredKnowledgeGraphDP.from_database(graph.id, adapter)
        logger.info(f"Retrieved graph: {retrieved_graph}")
        logger.info(
            f"Retrieved {len(retrieved_graph.layers)} layers, {len(retrieved_graph.nodes)} nodes, and {len(retrieved_graph.edges)} edges"
        )
    except ValueError as e:
        logger.error(f"Error retrieving graph using from_database: {str(e)}")

        # Try using manual retrieval as a fallback
        logger.info("Trying manual retrieval as a fallback")
        retrieved_graph = await retrieve_graph_manually(graph.id, adapter)

        if retrieved_graph:
            logger.info(f"Successfully retrieved graph manually: {retrieved_graph}")
        else:
            logger.error("Failed to retrieve graph manually")
            return

    # Use our helper function to get nodes and edges
    if retrieved_graph:
        # Get nodes in the base layer
        base_nodes, base_edges = get_nodes_and_edges_from_graph(retrieved_graph, base_layer.id)
        logger.info(f"Nodes in base layer: {[node.name for node in base_nodes]}")
        logger.info(f"Edges in base layer: {[edge.relationship_name for edge in base_edges]}")

        # Get nodes in the derived layer
        derived_nodes, derived_edges = get_nodes_and_edges_from_graph(
            retrieved_graph, derived_layer.id
        )
        logger.info(f"Nodes in derived layer: {[node.name for node in derived_nodes]}")
        logger.info(f"Edges in derived layer: {[edge.relationship_name for edge in derived_edges]}")
    else:
        logger.error("No graph was retrieved, cannot display nodes and edges")


if __name__ == "__main__":
    asyncio.run(main())
201  cognee/examples/simplified_layered_graph_example.py  Normal file
@@ -0,0 +1,201 @@
"""
Example demonstrating how to use the simplified LayeredKnowledgeGraph with database adapters.

This example shows how to:
1. Create a layered knowledge graph
2. Add nodes, edges, and layers
3. Retrieve layer data
4. Work with cumulative layers
"""

import asyncio
import logging
import uuid
from uuid import UUID
import os

from cognee.modules.graph.simplified_layered_graph import (
    LayeredKnowledgeGraph,
    GraphNode,
    GraphEdge,
    GraphLayer,
)
from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.databases.graph import get_graph_engine

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


async def main():
    print("Starting simplified layered graph example")

    # Initialize file path for the NetworkXAdapter
    db_dir = os.path.join(os.path.expanduser("~"), "cognee/cognee/.cognee_system/databases")
    os.makedirs(db_dir, exist_ok=True)
    db_file = os.path.join(db_dir, "cognee_graph.pkl")

    # Use NetworkXAdapter for the graph database
    adapter = NetworkXAdapter(filename=db_file)

    # Initialize the adapter by creating or loading the graph
    if not os.path.exists(db_file):
        await adapter.create_empty_graph(db_file)
    await adapter.load_graph_from_file()

    print(f"Using graph database adapter: {adapter.__class__.__name__}")

    # Create an empty graph
    graph = LayeredKnowledgeGraph.create_empty("Test Graph")
    graph.set_adapter(LayeredGraphDBAdapter(adapter))
    print(f"Created graph with ID: {graph.id}")

    # Create layers
    base_layer = await graph.add_layer(
        name="Base Layer", description="The foundation layer with base concepts", layer_type="base"
    )
    print(f"Added base layer with ID: {base_layer.id}")

    intermediate_layer = await graph.add_layer(
        name="Intermediate Layer",
        description="Layer that builds on base concepts",
        layer_type="intermediate",
        parent_layers=[base_layer.id],
    )
    print(f"Added intermediate layer with ID: {intermediate_layer.id}")

    derived_layer = await graph.add_layer(
        name="Derived Layer",
        description="Final layer with derived concepts",
        layer_type="derived",
        parent_layers=[intermediate_layer.id],
    )
    print(f"Added derived layer with ID: {derived_layer.id}")

    # Add nodes to layers
    node1 = await graph.add_node(
        name="Base Concept A",
        node_type="concept",
        properties={"importance": "high"},
        metadata={"source": "example"},
        layer_id=base_layer.id,
    )
    print(f"Added node1 with ID: {node1.id} to layer: {base_layer.id}")

    node2 = await graph.add_node(
        name="Base Concept B",
        node_type="concept",
        properties={"importance": "medium"},
        metadata={"source": "example"},
        layer_id=base_layer.id,
    )
    print(f"Added node2 with ID: {node2.id} to layer: {base_layer.id}")

    node3 = await graph.add_node(
        name="Intermediate Concept",
        node_type="concept",
        properties={"derived_from": ["Base Concept A"]},
        metadata={"source": "example"},
        layer_id=intermediate_layer.id,
    )
    print(f"Added node3 with ID: {node3.id} to layer: {intermediate_layer.id}")

    node4 = await graph.add_node(
        name="Derived Concept",
        node_type="concept",
        properties={"derived_from": ["Intermediate Concept"]},
        metadata={"source": "example"},
        layer_id=derived_layer.id,
    )
    print(f"Added node4 with ID: {node4.id} to layer: {derived_layer.id}")

    # Add edges between nodes
    edge1 = await graph.add_edge(
        source_id=node1.id,
        target_id=node2.id,
        edge_type="RELATES_TO",
        properties={"strength": "high"},
        metadata={"source": "example"},
        layer_id=base_layer.id,
    )
    print(f"Added edge1 with ID: {edge1.id} between {node1.id} and {node2.id}")

    edge2 = await graph.add_edge(
        source_id=node1.id,
        target_id=node3.id,
        edge_type="SUPPORTS",
        properties={"confidence": 0.9},
        metadata={"source": "example"},
        layer_id=intermediate_layer.id,
    )
    print(f"Added edge2 with ID: {edge2.id} between {node1.id} and {node3.id}")

    edge3 = await graph.add_edge(
        source_id=node3.id,
        target_id=node4.id,
        edge_type="EXTENDS",
        properties={"confidence": 0.8},
        metadata={"source": "example"},
        layer_id=derived_layer.id,
    )
    print(f"Added edge3 with ID: {edge3.id} between {node3.id} and {node4.id}")

    # Save the graph to the database
    # The graph is automatically saved when nodes and edges are added,
    # but for NetworkXAdapter we'll save the file explicitly
    if hasattr(adapter, "save_graph_to_file"):
        await adapter.save_graph_to_file(adapter.filename)
        print(f"Saving graph to file: {adapter.filename}")

    # Retrieve all layers
    layers = await graph.get_layers()
    print(f"Retrieved {len(layers)} layers")

    # Load the graph from the database
    print(f"Loading graph with ID: {graph.id} from database")
    # Create a new graph instance from the database
    loaded_graph = LayeredKnowledgeGraph(id=graph.id, name="Test Graph")
    loaded_graph.set_adapter(LayeredGraphDBAdapter(adapter))
    # Load layers, which will also load nodes and edges
    loaded_layers = await loaded_graph.get_layers()
    print(f"Successfully loaded graph: {loaded_graph} with {len(loaded_layers)} layers")

    # Display contents of each layer
    print("\n===== Individual Layer Contents =====")
    for layer in layers:
        # Get nodes and edges in the layer
        nodes = await graph.get_nodes_in_layer(layer.id)
        edges = await graph.get_edges_in_layer(layer.id)

        # Print summary
        print(f"Nodes in {layer.name.lower()} layer: {[node.name for node in nodes]}")
        print(f"Edges in {layer.name.lower()} layer: {[edge.edge_type for edge in edges]}")

    # Display cumulative layer views
    print("\n===== Cumulative Layer Views =====")

    # Intermediate layer - should include base layer nodes/edges
    print("\nCumulative layer graph for intermediate layer:")
    int_nodes, int_edges = await graph.get_cumulative_layer_graph(intermediate_layer.id)
    print(f"Intermediate cumulative nodes: {[node.name for node in int_nodes]}")
    print(f"Intermediate cumulative edges: {[edge.edge_type for edge in int_edges]}")

    # Derived layer - should include all nodes/edges
    print("\nCumulative layer graph for derived layer:")
    derived_nodes, derived_edges = await graph.get_cumulative_layer_graph(derived_layer.id)
    print(f"Derived cumulative nodes: {[node.name for node in derived_nodes]}")
    print(f"Derived cumulative edges: {[edge.edge_type for edge in derived_edges]}")

    # Test helper methods
    print("\n===== Helper Method Results =====")
    base_nodes = await graph.get_nodes_in_layer(base_layer.id)
    base_edges = await graph.get_edges_in_layer(base_layer.id)
    print(f"Base layer contains {len(base_nodes)} nodes and {len(base_edges)} edges")

    print("Example complete")


if __name__ == "__main__":
    asyncio.run(main())
106  cognee/modules/graph/README.md  Normal file
@@ -0,0 +1,106 @@
# Layered Knowledge Graph

This module provides a simplified implementation of a layered knowledge graph, which allows organizing nodes and edges into hierarchical layers.

## Features

- **Hierarchical Layer Structure**: Organize your graph into layers with parent-child relationships
- **Cumulative Views**: Access nodes and edges from a layer and all its parent layers
- **Adapter-based Design**: Connect to different database backends using adapter pattern
- **NetworkX Integration**: Built-in support for NetworkX graph database
- **Type Safety**: Pydantic models ensure type safety and data validation
- **Async API**: All methods are async for better performance

## Components

- **GraphNode**: A node in the graph with a name, type, properties, and metadata
- **GraphEdge**: An edge connecting two nodes with an edge type, properties, and metadata
- **GraphLayer**: A layer in the graph that can contain nodes and edges, and can have parent layers
- **LayeredKnowledgeGraph**: The main graph class that manages layers, nodes, and edges

## Usage Example

```python
import asyncio
from uuid import UUID
from cognee.modules.graph.simplified_layered_graph import LayeredKnowledgeGraph
from cognee.modules.graph.enhanced_layered_graph_adapter import LayeredGraphDBAdapter
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter

async def main():
    # Initialize adapter
    adapter = NetworkXAdapter(filename="graph.pkl")
    await adapter.create_empty_graph("graph.pkl")

    # Create graph
    graph = LayeredKnowledgeGraph.create_empty("My Knowledge Graph")
    graph.set_adapter(LayeredGraphDBAdapter(adapter))

    # Add layers with parent-child relationships
    base_layer = await graph.add_layer(
        name="Base Layer",
        description="Foundation concepts",
        layer_type="base"
    )

    derived_layer = await graph.add_layer(
        name="Derived Layer",
        description="Concepts built upon the base layer",
        layer_type="derived",
        parent_layers=[base_layer.id]  # Parent-child relationship
    )

    # Add nodes to layers
    node1 = await graph.add_node(
        name="Concept A",
        node_type="concept",
        properties={"importance": "high"},
        layer_id=base_layer.id
    )

    node2 = await graph.add_node(
        name="Concept B",
        node_type="concept",
        properties={"importance": "medium"},
        layer_id=derived_layer.id
    )

    # Connect nodes with an edge
    edge = await graph.add_edge(
        source_id=node1.id,
        target_id=node2.id,
        edge_type="RELATES_TO",
        properties={"strength": "high"},
        layer_id=derived_layer.id
    )

    # Get cumulative view (including parent layers)
    nodes, edges = await graph.get_cumulative_layer_graph(derived_layer.id)

    print(f"Nodes in cumulative view: {[n.name for n in nodes]}")
    print(f"Edges in cumulative view: {[e.edge_type for e in edges]}")

if __name__ == "__main__":
    asyncio.run(main())
```

## Design Improvements

The simplified layered graph implementation offers several improvements over the previous approach:

1. **Clear Separation of Concerns**: In-memory operations vs. database operations
2. **More Intuitive API**: Methods have clear, consistent signatures
3. **Better Error Handling**: Comprehensive validation and error reporting
4. **Enhanced Debugging**: Detailed logging throughout
5. **Improved Caching**: Local caches reduce database load
6. **Method Naming Consistency**: All methods follow consistent naming conventions
7. **Reduced Complexity**: Simpler implementation with equivalent functionality

## Best Practices

- Always use the adapter pattern for database operations
- Use the provided factory methods for creating nodes and edges
- Leverage parent-child relationships for organizing related concepts
- Utilize cumulative views to access inherited nodes and edges
- Consider layer types for additional semantic meaning
- Use properties and metadata for storing additional information
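The cumulative view described in the README can be pictured as a walk up the `parent_layers` chain. The following stand-alone sketch illustrates that idea only; it does not use the module's own classes (the `LayeredKnowledgeGraph` class already exposes `get_cumulative_layer_graph` for this), and the `Layer` dataclass and `cumulative_layer_ids` helper are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set
from uuid import UUID, uuid4


@dataclass
class Layer:  # hypothetical stand-in for GraphLayer
    id: UUID
    parent_layers: List[UUID] = field(default_factory=list)


def cumulative_layer_ids(layer_id: UUID, layers: Dict[UUID, Layer]) -> Set[UUID]:
    """Collect a layer plus all of its ancestors by walking parent_layers."""
    seen: Set[UUID] = set()
    stack = [layer_id]
    while stack:
        current = stack.pop()
        if current in seen or current not in layers:
            continue
        seen.add(current)
        stack.extend(layers[current].parent_layers)
    return seen


base = Layer(id=uuid4())
derived = Layer(id=uuid4(), parent_layers=[base.id])
layers = {layer.id: layer for layer in (base, derived)}

# The cumulative view of the derived layer includes the base layer as well.
assert cumulative_layer_ids(derived.id, layers) == {base.id, derived.id}
```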
37  cognee/notebooks/github_analysis_step_by_step.ipynb  Normal file
@@ -0,0 +1,37 @@
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "collapsed": true
      },
      "outputs": [],
      "source": [
        ""
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 2
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython2",
      "version": "2.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
@@ -15,7 +15,7 @@ import inspect
 import json


-async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
+async def ingest_data(data: Any, dataset_name: str, user: User, node_set: Optional[List[str]] = None):
 destination = get_dlt_destination()

 pipeline = dlt.pipeline(
@@ -44,10 +44,10 @@ async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optiona
 "mime_type": file_metadata["mime_type"],
 "content_hash": file_metadata["content_hash"],
 "owner_id": str(user.id),
-"node_set": json.dumps(NodeSet) if NodeSet else None,
+"node_set": json.dumps(node_set) if node_set else None,
 }

-async def store_data_to_dataset(data: Any, dataset_name: str, user: User, NodeSet: Optional[List[str]] = None):
+async def store_data_to_dataset(data: Any, dataset_name: str, user: User, node_set: Optional[List[str]] = None):
 if not isinstance(data, list):
 # Convert data to a list as we work with lists further down.
 data = [data]
@@ -84,8 +84,8 @@ async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optiona
 ).scalar_one_or_none()

 ext_metadata = get_external_metadata_dict(data_item)
-if NodeSet:
-ext_metadata["node_set"] = NodeSet
+if node_set:
+ext_metadata["node_set"] = node_set

 if data_point is not None:
 data_point.name = file_metadata["name"]
@@ -95,7 +95,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optiona
 data_point.owner_id = user.id
 data_point.content_hash = file_metadata["content_hash"]
 data_point.external_metadata = ext_metadata
-data_point.node_set = json.dumps(NodeSet) if NodeSet else None
+data_point.node_set = json.dumps(node_set) if node_set else None
 await session.merge(data_point)
 else:
 data_point = Data(
@@ -107,7 +107,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optiona
 owner_id=user.id,
 content_hash=file_metadata["content_hash"],
 external_metadata=ext_metadata,
-node_set=json.dumps(NodeSet) if NodeSet else None,
+node_set=json.dumps(node_set) if node_set else None,
 token_count=-1,
 )

@@ -132,7 +132,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User, NodeSet: Optiona

 db_engine = get_relational_engine()

-file_paths = await store_data_to_dataset(data, dataset_name, user, NodeSet)
+file_paths = await store_data_to_dataset(data, dataset_name, user, node_set)

 # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
 # can't be used inside the pipeline
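The hunks above only rename the keyword (`NodeSet` to `node_set`); the stored value is still the JSON-serialised list, or `NULL` when no node set is given. A minimal round-trip sketch of the two expressions that appear in the diff, with illustrative values:

```python
import json

# The ingestion side stores the node_set list as a JSON string (or None when empty) ...
node_set = ["finance", "q3-reports"]
stored_value = json.dumps(node_set) if node_set else None  # -> '["finance", "q3-reports"]'

# ... and apply_node_set (below) later parses it back before attaching it to a DataPoint.
restored = json.loads(stored_value) if stored_value else None
assert restored == node_set
```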
88  cognee/tasks/node_set/apply_node_set.py  Normal file
@@ -0,0 +1,88 @@
import json
import logging
from sqlalchemy import select
from typing import List, Any

from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data
from cognee.infrastructure.engine.models.DataPoint import DataPoint

logger = logging.getLogger(__name__)


async def apply_node_set(data_points: List[DataPoint]) -> List[DataPoint]:
    """
    Apply NodeSet values from the relational store to DataPoint instances.

    This task fetches the NodeSet values from the Data model in the relational database
    and applies them to the corresponding DataPoint instances.

    Args:
        data_points: List of DataPoint instances to process

    Returns:
        List of updated DataPoint instances with NodeSet values applied
    """
    logger.info(f"Applying NodeSet values to {len(data_points)} DataPoints")

    if not data_points:
        return data_points

    # Create a map of data_point IDs for efficient lookup
    data_point_map = {str(dp.id): dp for dp in data_points}

    # Get the database engine
    db_engine = get_relational_engine()

    # Get session (handles both sync and async cases for testing)
    session = db_engine.get_async_session()

    try:
        # Handle both AsyncMock and actual async context manager for testing
        if hasattr(session, "__aenter__"):
            # It's a real async context manager
            async with session as sess:
                await _process_data_points(sess, data_point_map)
        else:
            # It's an AsyncMock in tests
            await _process_data_points(session, data_point_map)

    except Exception as e:
        logger.error(f"Error applying NodeSet values: {e}")

    return data_points


async def _process_data_points(session, data_point_map):
    """
    Process data points with the given session.

    This helper function handles the actual database query and NodeSet application.

    Args:
        session: Database session
        data_point_map: Map of data point IDs to DataPoint objects
    """
    # Get all data points from the Data table that have node_set values
    # and correspond to the data_points we're processing
    data_ids = list(data_point_map.keys())

    query = select(Data).where(Data.id.in_(data_ids), Data.node_set.is_not(None))

    result = await session.execute(query)
    data_records = result.scalars().all()

    # Apply NodeSet values to corresponding DataPoint instances
    for data_record in data_records:
        data_point_id = str(data_record.id)
        if data_point_id in data_point_map and data_record.node_set:
            # Parse the JSON string to get the NodeSet
            try:
                node_set = json.loads(data_record.node_set)
                data_point_map[data_point_id].NodeSet = node_set
                logger.debug(f"Applied NodeSet {node_set} to DataPoint {data_point_id}")
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse NodeSet JSON for DataPoint {data_point_id}")
                continue

    logger.info(f"Successfully applied NodeSet values to DataPoints")
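A rough sketch of how this task might be dropped into a pipeline, mirroring the wiring used in the integration test further down (`run_tasks` with `Task(apply_node_set)`); the `documents` argument stands in for whatever DataPoint list earlier pipeline steps produce, and the helper name `annotate_with_node_sets` is an assumption, not part of this commit.

```python
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.tasks.node_set import apply_node_set


async def annotate_with_node_sets(documents):
    # Run apply_node_set as a pipeline task over the given DataPoints.
    pipeline = run_tasks(
        tasks=[Task(apply_node_set)],
        data=documents,
    )

    # Collect the batches the pipeline yields back.
    results = []
    async for batch in pipeline:
        results.extend(batch)
    return results
```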
1  cognee/tests/integration/node_set/__init__.py  Normal file
@@ -0,0 +1 @@
# Node Set integration tests
180  cognee/tests/integration/node_set/node_set_integration_test.py  Normal file
@@ -0,0 +1,180 @@
import os
import json
import asyncio
import pytest
from uuid import uuid4
from unittest.mock import patch, AsyncMock, MagicMock
from contextlib import asynccontextmanager

import cognee
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks
from cognee.tasks.node_set import apply_node_set
from cognee.infrastructure.databases.relational import create_db_and_tables


class TestDocument(DataPoint):
    """Test document model for NodeSet testing."""

    content: str
    metadata: dict = {"index_fields": ["content"]}


@pytest.mark.asyncio
async def test_node_set_add_to_cognify_workflow():
    """
    Test the full NodeSet workflow from add to cognify.

    This test verifies that:
    1. NodeSet data can be added using cognee.add
    2. The NodeSet data is stored in the relational database
    3. The apply_node_set task can retrieve the NodeSet and apply it to DataPoints
    """
    # Clean up any existing data
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Create test data
    test_content = "This is test content"
    test_node_set = ["node1", "node2", "node3"]
    dataset_name = f"test_dataset_{uuid4()}"

    # Create database tables
    await create_db_and_tables()

    # Mock functions to avoid external dependencies
    mock_add_data_points = AsyncMock()
    mock_add_data_points.return_value = []

    # Create a temporary file for the test
    temp_file_path = "/tmp/test_node_set.txt"
    with open(temp_file_path, "w") as f:
        f.write(test_content)

    try:
        # Mock ingest_data to capture and verify NodeSet
        original_ingest_data = cognee.tasks.ingestion.ingest_data

        async def mock_ingest_data(*args, **kwargs):
            # Call the original function but capture the NodeSet parameter
            assert len(args) >= 3, "Expected at least 3 arguments"
            assert args[1] == dataset_name, f"Expected dataset name {dataset_name}"
            assert kwargs.get("NodeSet") == test_node_set or args[3] == test_node_set, (
                "NodeSet not passed correctly"
            )
            return await original_ingest_data(*args, **kwargs)

        # Replace the ingest_data function temporarily
        with patch("cognee.tasks.ingestion.ingest_data", side_effect=mock_ingest_data):
            # Call the add function with NodeSet
            await cognee.add(temp_file_path, dataset_name, NodeSet=test_node_set)

        # Create test DataPoint for apply_node_set to process
        test_document = TestDocument(content=test_content)

        # Test the apply_node_set task
        with patch(
            "cognee.tasks.node_set.apply_node_set.get_relational_engine"
        ) as mock_get_engine:
            # Setup mock engine and session
            mock_session = AsyncMock()
            mock_engine = AsyncMock()

            # Properly mock the async context manager
            @asynccontextmanager
            async def mock_get_session():
                try:
                    yield mock_session
                finally:
                    pass

            mock_engine.get_async_session.return_value = mock_get_session()
            mock_get_engine.return_value = mock_engine

            # Create a mock Data object with our NodeSet
            class MockData:
                def __init__(self, id, node_set):
                    self.id = id
                    self.node_set = node_set

            mock_data = MockData(test_document.id, json.dumps(test_node_set))

            # Setup the mock result
            mock_result = MagicMock()
            mock_result.scalars.return_value.all.return_value = [mock_data]
            mock_session.execute.return_value = mock_result

            # Run the apply_node_set task
            result = await apply_node_set([test_document])

            # Verify the NodeSet was applied
            assert len(result) == 1
            assert result[0].NodeSet == test_node_set

            # Verify the mock interactions
            mock_get_engine.assert_called_once()
            mock_engine.get_async_session.assert_called_once()
    finally:
        # Clean up the temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

        # Clean up after the test
        await cognee.prune.prune_data()


@pytest.mark.asyncio
async def test_node_set_in_cognify_pipeline():
    """
    Test the integration of apply_node_set task in the cognify pipeline.

    This test verifies that the apply_node_set task works correctly
    when run as part of a pipeline with other tasks.
    """
    # Create test data
    test_documents = [TestDocument(content="Document 1"), TestDocument(content="Document 2")]

    # Create a simple mock task that just passes data through
    async def mock_task(data):
        for item in data:
            yield item

    # Mock the apply_node_set function to verify it's called with the right data
    original_apply_node_set = apply_node_set

    apply_node_set_called = False

    async def mock_apply_node_set(data_points):
        nonlocal apply_node_set_called
        apply_node_set_called = True

        # Verify the input
        assert len(data_points) == 2
        assert all(isinstance(dp, TestDocument) for dp in data_points)

        # Apply NodeSet to demonstrate it worked
        for dp in data_points:
            dp.NodeSet = ["test_node"]

        return data_points

    # Create a pipeline with our tasks
    with patch("cognee.tasks.node_set.apply_node_set", side_effect=mock_apply_node_set):
        pipeline = run_tasks(
            tasks=[
                Task(mock_task),  # First task passes data through
                Task(apply_node_set),  # Second task applies NodeSet
            ],
            data=test_documents,
        )

        # Process all results from the pipeline
        results = []
        async for result in pipeline:
            results.extend(result)

        # Verify results
        assert apply_node_set_called, "apply_node_set was not called"
        assert len(results) == 2
        assert all(dp.NodeSet == ["test_node"] for dp in results)
1  cognee/tests/tasks/node_set/__init__.py  Normal file
@@ -0,0 +1 @@
# Node Set task tests
132  cognee/tests/tasks/node_set/apply_node_set_test.py  Normal file
@@ -0,0 +1,132 @@
import pytest
import json
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import uuid4, UUID

from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.tasks.node_set.apply_node_set import apply_node_set


class TestDataPoint(DataPoint):
    """Test DataPoint model for testing apply_node_set task."""

    name: str
    description: str
    metadata: dict = {"index_fields": ["name"]}


@pytest.mark.asyncio
async def test_apply_node_set():
    """Test that apply_node_set applies NodeSet values from the relational store to DataPoint instances."""
    # Create test data
    dp_id1 = uuid4()
    dp_id2 = uuid4()

    # Create test DataPoint instances
    data_points = [
        TestDataPoint(id=dp_id1, name="Test 1", description="Description 1"),
        TestDataPoint(id=dp_id2, name="Test 2", description="Description 2"),
    ]

    # Create mock Data records that would be returned from the database
    node_set1 = ["node1", "node2"]
    node_set2 = ["node3", "node4", "node5"]

    # Create a mock implementation of _process_data_points
    async def mock_process_data_points(session, data_point_map):
        # Apply NodeSet directly to the DataPoints
        for dp_id, node_set in [(dp_id1, node_set1), (dp_id2, node_set2)]:
            dp_id_str = str(dp_id)
            if dp_id_str in data_point_map:
                data_point_map[dp_id_str].NodeSet = node_set

    # Patch the necessary functions
    with (
        patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
        patch(
            "cognee.tasks.node_set.apply_node_set._process_data_points",
            side_effect=mock_process_data_points,
        ),
    ):
        # Call the function being tested
        result = await apply_node_set(data_points)

        # Verify the results
        assert len(result) == 2

        # Check that NodeSet values were applied correctly
        assert result[0].NodeSet == node_set1
        assert result[1].NodeSet == node_set2


@pytest.mark.asyncio
async def test_apply_node_set_empty_list():
    """Test apply_node_set with an empty list of DataPoints."""
    result = await apply_node_set([])
    assert result == []


@pytest.mark.asyncio
async def test_apply_node_set_no_matching_data():
    """Test apply_node_set when there are no matching Data records."""
    # Create test data
    dp_id = uuid4()

    # Create test DataPoint instances
    data_points = [TestDataPoint(id=dp_id, name="Test", description="Description")]

    # Create a mock implementation of _process_data_points that doesn't modify any DataPoints
    async def mock_process_data_points(session, data_point_map):
        # Don't modify anything - simulating no matching records
        pass

    # Patch the necessary functions
    with (
        patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
        patch(
            "cognee.tasks.node_set.apply_node_set._process_data_points",
            side_effect=mock_process_data_points,
        ),
    ):
        # Call the function being tested
        result = await apply_node_set(data_points)

        # Verify the results - NodeSet should remain None
        assert len(result) == 1
        assert result[0].NodeSet is None


@pytest.mark.asyncio
async def test_apply_node_set_invalid_json():
    """Test apply_node_set when there's invalid JSON in the node_set column."""
    # Create test data
    dp_id = uuid4()

    # Create test DataPoint instances
    data_points = [TestDataPoint(id=dp_id, name="Test", description="Description")]

    # Create a mock implementation of _process_data_points that throws the appropriate error
    async def mock_process_data_points(session, data_point_map):
        # Simulate the JSONDecodeError by logging a warning
        from cognee.tasks.node_set.apply_node_set import logger

        logger.warning(f"Failed to parse NodeSet JSON for DataPoint {str(dp_id)}")

    # Patch the necessary functions
    with (
        patch("cognee.tasks.node_set.apply_node_set.get_relational_engine"),
        patch(
            "cognee.tasks.node_set.apply_node_set._process_data_points",
            side_effect=mock_process_data_points,
        ),
        patch("cognee.tasks.node_set.apply_node_set.logger") as mock_logger,
    ):
        # Call the function being tested
        result = await apply_node_set(data_points)

        # Verify the results - NodeSet should remain None
        assert len(result) == 1
        assert result[0].NodeSet is None

        # Verify logger warning was called
        mock_logger.warning.assert_called_once()