LightRAG/tests/test_graph_storage.py

1187 lines
55 KiB
Python

#!/usr/bin/env python
"""
General-purpose graph storage test program.
This program selects the graph storage type to use based on the LIGHTRAG_GRAPH_STORAGE configuration in .env,
and tests its basic and advanced operations.
Supported graph storage types include:
- NetworkXStorage
- Neo4JStorage
- MongoDBStorage
- PGGraphStorage
- MemgraphStorage
"""
import asyncio
import os
import sys
import importlib
import numpy as np
from dotenv import load_dotenv
from ascii_colors import ASCIIColors
# Add the project root directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lightrag.types import KnowledgeGraph
from lightrag.kg import (
STORAGE_IMPLEMENTATIONS,
STORAGE_ENV_REQUIREMENTS,
STORAGES,
verify_storage_implementation,
)
from lightrag.kg.shared_storage import initialize_share_data
from lightrag.constants import GRAPH_FIELD_SEP
async def mock_embedding_func(texts):
    """
    Stand-in embedding function used when constructing storages in tests.

    Produces one random 10-dimensional vector per entry in ``texts``. The
    values carry no meaning and differ on every call.
    """
    vector_count = len(texts)
    embedding_dim = 10  # Return 10-dimensional random vectors
    return np.random.rand(vector_count, embedding_dim)
def check_env_file():
    """
    Warn when no .env file is present in the current directory.

    When the file is missing a warning is printed. If stdin is an interactive
    terminal the user is then asked whether to go on; any answer other than
    "yes" (case-insensitive) cancels the run.

    Returns True to continue execution, False to exit.
    """
    if os.path.exists(".env"):
        return True

    ASCIIColors.yellow(
        "Warning: .env file not found in the current directory. This may affect storage configuration loading."
    )

    # Only ask for confirmation when running in an interactive terminal
    if not sys.stdin.isatty():
        return True

    answer = input("Do you want to continue? (yes/no): ")
    if answer.lower() == "yes":
        return True

    ASCIIColors.red("Test program cancelled.")
    return False
async def initialize_graph_storage():
    """
    Create and initialize the graph storage selected through the environment.

    The storage type is read from LIGHTRAG_GRAPH_STORAGE (falling back to
    "NetworkXStorage"). It is validated, its required environment variables
    are checked, its module is imported dynamically and the class is
    instantiated with a minimal test configuration.

    Returns the initialized storage instance, or None when any of these steps
    fails (an error message is printed in that case).
    """
    # Which backend are we testing?
    backend_name = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")

    # Reject unknown storage types early and list the valid ones
    try:
        verify_storage_implementation("GRAPH_STORAGE", backend_name)
    except ValueError as e:
        ASCIIColors.red(f"Error: {e!s}")
        supported = ", ".join(
            STORAGE_IMPLEMENTATIONS["GRAPH_STORAGE"]["implementations"]
        )
        ASCIIColors.yellow(f"Supported graph storage types: {supported}")
        return None

    # Collect required environment variables that are unset or empty
    unset_vars = []
    for var_name in STORAGE_ENV_REQUIREMENTS.get(backend_name, []):
        if not os.getenv(var_name):
            unset_vars.append(var_name)
    if unset_vars:
        ASCIIColors.red(
            f"Error: {backend_name} requires the following environment variables, but they are not set: {', '.join(unset_vars)}"
        )
        return None

    # Resolve and import the implementing module, then fetch the class from it
    module_path = STORAGES.get(backend_name)
    if not module_path:
        ASCIIColors.red(f"Error: Module path for {backend_name} not found.")
        return None
    try:
        storage_class = getattr(
            importlib.import_module(module_path, package="lightrag"),
            backend_name,
        )
    except (ImportError, AttributeError) as e:
        ASCIIColors.red(f"Error: Failed to import {backend_name}: {e!s}")
        return None

    # Minimal configuration handed to the storage constructor
    global_config = {
        "embedding_batch_num": 10,  # Batch size
        "vector_db_storage_cls_kwargs": {
            "cosine_better_than_threshold": 0.5  # Cosine similarity threshold
        },
        "working_dir": os.environ.get("WORKING_DIR", "./rag_storage"),  # Working directory
    }

    # NetworkXStorage needs shared_storage to be set up before it is created
    if backend_name == "NetworkXStorage":
        initialize_share_data()  # Use single-process mode

    try:
        storage = storage_class(
            namespace="test_graph",
            workspace="test_workspace",
            global_config=global_config,
            embedding_func=mock_embedding_func,
        )
        # Initialize the connection
        await storage.initialize()
    except Exception as e:
        ASCIIColors.red(f"Error: Failed to initialize {backend_name}: {e!s}")
        return None
    else:
        return storage
async def test_graph_basic(storage):
    """
    Exercise the basic read/write operations of a graph storage:
    1. upsert_node is called twice to insert two nodes.
    2. upsert_edge connects the two nodes.
    3. get_node reads the first node back and its properties are compared
       with what was written.
    4. get_edge reads the edge in both directions; both results must be
       equal, because the graph is expected to behave as undirected.

    Returns True when every check passes and False otherwise (the error is
    printed). The inserted data is left in the storage.
    """
    try:
        # Fixtures: two entities and the relation that links them
        ai_id = "Artificial Intelligence"
        ai_data = {
            "entity_id": ai_id,
            "description": "Artificial intelligence is a branch of computer science that aims to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a manner similar to human intelligence.",
            "keywords": "AI,Machine Learning,Deep Learning",
            "entity_type": "Technology Field",
        }
        ml_id = "Machine Learning"
        ml_data = {
            "entity_id": ml_id,
            "description": "Machine learning is a branch of artificial intelligence that uses statistical methods to enable computer systems to learn without being explicitly programmed.",
            "keywords": "Supervised Learning,Unsupervised Learning,Reinforcement Learning",
            "entity_type": "Technology Field",
        }
        link_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "The field of artificial intelligence includes the subfield of machine learning.",
        }

        # 1-3. Write both nodes and the edge between them
        print(f"Inserting node 1: {ai_id}")
        await storage.upsert_node(ai_id, ai_data)
        print(f"Inserting node 2: {ml_id}")
        await storage.upsert_node(ml_id, ml_data)
        print(f"Inserting edge: {ai_id} -> {ml_id}")
        await storage.upsert_edge(ai_id, ml_id, link_data)

        # 4. Read the first node back and compare it with what was written
        print(f"Reading node properties: {ai_id}")
        stored_node = await storage.get_node(ai_id)
        if not stored_node:
            print(f"Failed to read node properties: {ai_id}")
            assert False, f"Failed to read node properties: {ai_id}"
        else:
            print(f"Successfully read node properties: {ai_id}")
            for label, key in (
                ("description", "description"),
                ("type", "entity_type"),
                ("keywords", "keywords"),
            ):
                print(f"Node {label}: {stored_node.get(key, f'No {label}')}")
            assert (
                stored_node.get("entity_id") == ai_id
            ), f"Node ID mismatch: expected {ai_id}, got {stored_node.get('entity_id')}"
            for key, failure in (
                ("description", "Node description mismatch"),
                ("entity_type", "Node type mismatch"),
            ):
                assert stored_node.get(key) == ai_data[key], failure

        # 5. Read the edge in the direction it was written
        print(f"Reading edge properties: {ai_id} -> {ml_id}")
        forward_edge = await storage.get_edge(ai_id, ml_id)
        if not forward_edge:
            print(f"Failed to read edge properties: {ai_id} -> {ml_id}")
            assert False, f"Failed to read edge properties: {ai_id} -> {ml_id}"
        else:
            print(f"Successfully read edge properties: {ai_id} -> {ml_id}")
            for key in ("relationship", "description", "weight"):
                print(f"Edge {key}: {forward_edge.get(key, f'No {key}')}")
            for key in ("relationship", "description", "weight"):
                assert forward_edge.get(key) == link_data[key], f"Edge {key} mismatch"

        # 5.1 Undirected check: the reversed lookup must return the same data
        print(f"Reading reverse edge properties: {ml_id} -> {ai_id}")
        backward_edge = await storage.get_edge(ml_id, ai_id)
        if not backward_edge:
            print(f"Failed to read reverse edge properties: {ml_id} -> {ai_id}")
            assert (
                False
            ), f"Failed to read reverse edge properties: {ml_id} -> {ai_id}, undirected graph property verification failed"
        else:
            print(f"Successfully read reverse edge properties: {ml_id} -> {ai_id}")
            for key in ("relationship", "description", "weight"):
                print(f"Reverse edge {key}: {backward_edge.get(key, f'No {key}')}")
            assert (
                forward_edge == backward_edge
            ), "Forward and reverse edge properties are not consistent, undirected graph property verification failed"
            print("Undirected graph property verification successful: forward and reverse edge properties are consistent")

        print("Basic tests completed, data is preserved in the database.")
        return True
    except Exception as e:
        # Assertion failures land here too; report them and signal failure
        ASCIIColors.red(f"An error occurred during the test: {e!s}")
        return False
async def test_graph_advanced(storage):
    """
    Test advanced graph database operations on a three-node chain
    (Artificial Intelligence - Machine Learning - Deep Learning):
    1. Use node_degree to get the degree of each node.
    2. Use edge_degree to get the degree of an edge, in both directions.
    3. Use get_node_edges to get all edges of a node.
    4. Use get_all_labels to get all labels.
    5. Use get_knowledge_graph to get a knowledge graph.
    6. Use delete_node to delete a single node.
    7. Use remove_edges to delete an edge (verified in both directions).
    8. Use remove_nodes to delete multiple nodes.

    The exact-count assertions (3 labels, 3 graph nodes, 2 graph edges) only
    hold when the storage contains nothing but the data inserted here.
    This function does not call drop; the first node is still present in the
    storage when it returns.

    Returns True when every check passes and False otherwise (the error is
    printed).
    """
    try:
        # 1. Insert test data
        # Insert node 1: Artificial Intelligence
        node1_id = "Artificial Intelligence"
        node1_data = {
            "entity_id": node1_id,
            "description": "Artificial intelligence is a branch of computer science that aims to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a manner similar to human intelligence.",
            "keywords": "AI,Machine Learning,Deep Learning",
            "entity_type": "Technology Field",
        }
        print(f"Inserting node 1: {node1_id}")
        await storage.upsert_node(node1_id, node1_data)

        # Insert node 2: Machine Learning
        node2_id = "Machine Learning"
        node2_data = {
            "entity_id": node2_id,
            "description": "Machine learning is a branch of artificial intelligence that uses statistical methods to enable computer systems to learn without being explicitly programmed.",
            "keywords": "Supervised Learning,Unsupervised Learning,Reinforcement Learning",
            "entity_type": "Technology Field",
        }
        print(f"Inserting node 2: {node2_id}")
        await storage.upsert_node(node2_id, node2_data)

        # Insert node 3: Deep Learning
        node3_id = "Deep Learning"
        node3_data = {
            "entity_id": node3_id,
            "description": "Deep learning is a branch of machine learning that uses multi-layered neural networks to simulate the learning process of the human brain.",
            "keywords": "Neural Networks,CNN,RNN",
            "entity_type": "Technology Field",
        }
        print(f"Inserting node 3: {node3_id}")
        await storage.upsert_node(node3_id, node3_data)

        # Insert edge 1: Artificial Intelligence -> Machine Learning
        edge1_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "The field of artificial intelligence includes the subfield of machine learning.",
        }
        print(f"Inserting edge 1: {node1_id} -> {node2_id}")
        await storage.upsert_edge(node1_id, node2_id, edge1_data)

        # Insert edge 2: Machine Learning -> Deep Learning
        edge2_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "The field of machine learning includes the subfield of deep learning.",
        }
        print(f"Inserting edge 2: {node2_id} -> {node3_id}")
        await storage.upsert_edge(node2_id, node3_id, edge2_data)

        # 2. Test node_degree - get the degree of a node
        print(f"== Testing node_degree: {node1_id}")
        node1_degree = await storage.node_degree(node1_id)
        print(f"Degree of node {node1_id}: {node1_degree}")
        assert node1_degree == 1, f"Degree of node {node1_id} should be 1, but got {node1_degree}"

        # 2.1 Test degrees of all nodes
        print("== Testing degrees of all nodes")
        node2_degree = await storage.node_degree(node2_id)
        node3_degree = await storage.node_degree(node3_id)
        print(f"Degree of node {node2_id}: {node2_degree}")
        print(f"Degree of node {node3_id}: {node3_degree}")
        assert node2_degree == 2, f"Degree of node {node2_id} should be 2, but got {node2_degree}"
        assert node3_degree == 1, f"Degree of node {node3_id} should be 1, but got {node3_degree}"

        # 3. Test edge_degree - get the degree of an edge
        # Expected value 3 matches the two endpoint degrees checked above (1 + 2)
        print(f"== Testing edge_degree: {node1_id} -> {node2_id}")
        edge_degree = await storage.edge_degree(node1_id, node2_id)
        print(f"Degree of edge {node1_id} -> {node2_id}: {edge_degree}")
        assert (
            edge_degree == 3
        ), f"Degree of edge {node1_id} -> {node2_id} should be 3, but got {edge_degree}"

        # 3.1 Test reverse edge degree - verify undirected graph property
        print(f"== Testing reverse edge degree: {node2_id} -> {node1_id}")
        reverse_edge_degree = await storage.edge_degree(node2_id, node1_id)
        print(f"Degree of reverse edge {node2_id} -> {node1_id}: {reverse_edge_degree}")
        assert (
            edge_degree == reverse_edge_degree
        ), "Degrees of forward and reverse edges are not consistent, undirected graph property verification failed"
        print("Undirected graph property verification successful: degrees of forward and reverse edges are consistent")

        # 4. Test get_node_edges - get all edges of a node
        print(f"== Testing get_node_edges: {node2_id}")
        node2_edges = await storage.get_node_edges(node2_id)
        print(f"All edges of node {node2_id}: {node2_edges}")
        assert (
            len(node2_edges) == 2
        ), f"Node {node2_id} should have 2 edges, but got {len(node2_edges)}"

        # 4.1 Verify undirected graph property of node edges
        print("== Verifying undirected graph property of node edges")
        # Check if it includes connections with node1 and node3 (regardless of direction)
        has_connection_with_node1 = False
        has_connection_with_node3 = False
        for edge in node2_edges:
            endpoints = (edge[0], edge[1])
            # Check for connection with node1 (regardless of direction)
            if endpoints in ((node1_id, node2_id), (node2_id, node1_id)):
                has_connection_with_node1 = True
            # Check for connection with node3 (regardless of direction)
            if endpoints in ((node2_id, node3_id), (node3_id, node2_id)):
                has_connection_with_node3 = True
        assert (
            has_connection_with_node1
        ), f"Edge list of node {node2_id} should include a connection with {node1_id}"
        assert (
            has_connection_with_node3
        ), f"Edge list of node {node2_id} should include a connection with {node3_id}"
        print(f"Undirected graph property verification successful: edge list of node {node2_id} contains all relevant edges")

        # 5. Test get_all_labels - get all labels
        print("== Testing get_all_labels")
        all_labels = await storage.get_all_labels()
        print(f"All labels: {all_labels}")
        assert len(all_labels) == 3, f"Should have 3 labels, but got {len(all_labels)}"
        for expected_label in (node1_id, node2_id, node3_id):
            assert expected_label in all_labels, f"{expected_label} should be in the label list"

        # 6. Test get_knowledge_graph - get a knowledge graph
        print("== Testing get_knowledge_graph")
        kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
        # Check the type before touching .nodes/.edges so that a wrong return
        # type is reported by this message rather than by an attribute error
        assert isinstance(kg, KnowledgeGraph), "The returned result should be of type KnowledgeGraph"
        print(f"Number of nodes in knowledge graph: {len(kg.nodes)}")
        print(f"Number of edges in knowledge graph: {len(kg.edges)}")
        assert len(kg.nodes) == 3, f"The knowledge graph should have 3 nodes, but got {len(kg.nodes)}"
        assert len(kg.edges) == 2, f"The knowledge graph should have 2 edges, but got {len(kg.edges)}"

        # 7. Test delete_node - delete a node
        print(f"== Testing delete_node: {node3_id}")
        await storage.delete_node(node3_id)
        node3_props = await storage.get_node(node3_id)
        print(f"Querying node properties after deletion {node3_id}: {node3_props}")
        assert node3_props is None, f"Node {node3_id} should have been deleted"

        # Re-insert node 3 (and its edge) for subsequent tests
        await storage.upsert_node(node3_id, node3_data)
        await storage.upsert_edge(node2_id, node3_id, edge2_data)

        # 8. Test remove_edges - delete edges
        print(f"== Testing remove_edges: {node2_id} -> {node3_id}")
        await storage.remove_edges([(node2_id, node3_id)])
        edge_props = await storage.get_edge(node2_id, node3_id)
        print(f"Querying edge properties after deletion {node2_id} -> {node3_id}: {edge_props}")
        assert edge_props is None, f"Edge {node2_id} -> {node3_id} should have been deleted"

        # 8.1 Verify undirected graph property of edge deletion
        print(f"== Verifying undirected graph property of edge deletion: {node3_id} -> {node2_id}")
        reverse_edge_props = await storage.get_edge(node3_id, node2_id)
        print(f"Querying reverse edge properties after deletion {node3_id} -> {node2_id}: {reverse_edge_props}")
        assert (
            reverse_edge_props is None
        ), f"Reverse edge {node3_id} -> {node2_id} should also be deleted, undirected graph property verification failed"
        print("Undirected graph property verification successful: deleting an edge in one direction also deletes the reverse edge")

        # 9. Test remove_nodes - delete multiple nodes
        print(f"== Testing remove_nodes: [{node2_id}, {node3_id}]")
        await storage.remove_nodes([node2_id, node3_id])
        node2_props = await storage.get_node(node2_id)
        node3_props = await storage.get_node(node3_id)
        print(f"Querying node properties after deletion {node2_id}: {node2_props}")
        print(f"Querying node properties after deletion {node3_id}: {node3_props}")
        assert node2_props is None, f"Node {node2_id} should have been deleted"
        assert node3_props is None, f"Node {node3_id} should have been deleted"

        print("\nAdvanced tests completed.")
        return True
    except Exception as e:
        # Assertion failures land here too; report them and signal failure
        ASCIIColors.red(f"An error occurred during the test: {str(e)}")
        return False
async def test_graph_batch_operations(storage):
    """
    Exercise the batch APIs of a graph storage on a five-node, six-edge graph:
    1. get_nodes_batch - properties of several nodes at once.
    2. node_degrees_batch - degrees of several nodes at once.
    3. edge_degrees_batch - degrees of several edges at once.
    4. get_edges_batch - properties of several edges at once, queried in
       both directions to check undirected behaviour.
    5. get_nodes_edges_batch - all edges of several nodes at once.

    Returns True when every check passes and False otherwise (the error is
    printed).
    """
    try:
        chunk1_id, chunk2_id, chunk3_id = "1", "2", "3"

        # 1. Test data -------------------------------------------------------
        node1_id = "Artificial Intelligence"
        node1_data = {
            "entity_id": node1_id,
            "description": "Artificial intelligence is a branch of computer science that aims to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a manner similar to human intelligence.",
            "keywords": "AI,Machine Learning,Deep Learning",
            "entity_type": "Technology Field",
            "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
        }
        node2_id = "Machine Learning"
        node2_data = {
            "entity_id": node2_id,
            "description": "Machine learning is a branch of artificial intelligence that uses statistical methods to enable computer systems to learn without being explicitly programmed.",
            "keywords": "Supervised Learning,Unsupervised Learning,Reinforcement Learning",
            "entity_type": "Technology Field",
            "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
        }
        node3_id = "Deep Learning"
        node3_data = {
            "entity_id": node3_id,
            "description": "Deep learning is a branch of machine learning that uses multi-layered neural networks to simulate the learning process of the human brain.",
            "keywords": "Neural Networks,CNN,RNN",
            "entity_type": "Technology Field",
            "source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
        }
        node4_id = "Natural Language Processing"
        node4_data = {
            "entity_id": node4_id,
            "description": "Natural language processing is a branch of artificial intelligence that focuses on enabling computers to understand and process human language.",
            "keywords": "NLP,Text Analysis,Language Models",
            "entity_type": "Technology Field",
        }
        node5_id = "Computer Vision"
        node5_data = {
            "entity_id": node5_id,
            "description": "Computer vision is a branch of artificial intelligence that focuses on enabling computers to gain information from images or videos.",
            "keywords": "CV,Image Recognition,Object Detection",
            "entity_type": "Technology Field",
        }

        edge1_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "The field of artificial intelligence includes the subfield of machine learning.",
            "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
        }
        edge2_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "The field of machine learning includes the subfield of deep learning.",
            "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
        }
        edge3_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "The field of artificial intelligence includes the subfield of natural language processing.",
            "source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
        }
        edge4_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "The field of artificial intelligence includes the subfield of computer vision.",
        }
        edge5_data = {
            "relationship": "applied to",
            "weight": 0.8,
            "description": "Deep learning techniques are applied in the field of natural language processing.",
        }
        edge6_data = {
            "relationship": "applied to",
            "weight": 0.8,
            "description": "Deep learning techniques are applied in the field of computer vision.",
        }

        # Insert all nodes first, then all edges, in fixture order
        node_fixtures = [
            (node1_id, node1_data),
            (node2_id, node2_data),
            (node3_id, node3_data),
            (node4_id, node4_data),
            (node5_id, node5_data),
        ]
        for position, (entity_id, entity_data) in enumerate(node_fixtures, start=1):
            print(f"Inserting node {position}: {entity_id}")
            await storage.upsert_node(entity_id, entity_data)

        edge_fixtures = [
            (node1_id, node2_id, edge1_data),  # AI -> Machine Learning
            (node2_id, node3_id, edge2_data),  # Machine Learning -> Deep Learning
            (node1_id, node4_id, edge3_data),  # AI -> Natural Language Processing
            (node1_id, node5_id, edge4_data),  # AI -> Computer Vision
            (node3_id, node4_id, edge5_data),  # Deep Learning -> NLP
            (node3_id, node5_id, edge6_data),  # Deep Learning -> Computer Vision
        ]
        for position, (src_id, tgt_id, relation) in enumerate(edge_fixtures, start=1):
            print(f"Inserting edge {position}: {src_id} -> {tgt_id}")
            await storage.upsert_edge(src_id, tgt_id, relation)

        # 2. get_nodes_batch --------------------------------------------------
        print("== Testing get_nodes_batch")
        node_ids = [node1_id, node2_id, node3_id]
        nodes_dict = await storage.get_nodes_batch(node_ids)
        print(f"Batch get node properties result: {nodes_dict.keys()}")
        assert len(nodes_dict) == 3, f"Should return 3 nodes, but got {len(nodes_dict)}"
        for entity_id in node_ids:
            assert entity_id in nodes_dict, f"{entity_id} should be in the result"
        for entity_id, written in zip(node_ids, (node1_data, node2_data, node3_data)):
            assert (
                nodes_dict[entity_id]["description"] == written["description"]
            ), f"{entity_id} description mismatch"

        # 3. node_degrees_batch -----------------------------------------------
        print("== Testing node_degrees_batch")
        node_degrees = await storage.node_degrees_batch(node_ids)
        print(f"Batch get node degrees result: {node_degrees}")
        assert (
            len(node_degrees) == 3
        ), f"Should return degrees of 3 nodes, but got {len(node_degrees)}"
        for entity_id in node_ids:
            assert entity_id in node_degrees, f"{entity_id} should be in the result"
        for entity_id, wanted in zip(node_ids, (3, 2, 3)):
            assert (
                node_degrees[entity_id] == wanted
            ), f"Degree of {entity_id} should be {wanted}, but got {node_degrees[entity_id]}"

        # 4. edge_degrees_batch -----------------------------------------------
        print("== Testing edge_degrees_batch")
        edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
        edge_degrees = await storage.edge_degrees_batch(edges)
        print(f"Batch get edge degrees result: {edge_degrees}")
        assert (
            len(edge_degrees) == 3
        ), f"Should return degrees of 3 edges, but got {len(edge_degrees)}"
        for pair in edges:
            assert pair in edge_degrees, f"Edge {pair[0]} -> {pair[1]} should be in the result"
        # Verify edge degrees (sum of source and target node degrees)
        for pair in edges:
            assert (
                edge_degrees[pair] == 5
            ), f"Degree of edge {pair[0]} -> {pair[1]} should be 5, but got {edge_degrees[pair]}"

        # 5. get_edges_batch --------------------------------------------------
        print("== Testing get_edges_batch")
        # Convert list of tuples to list of dicts for Neo4j style
        edge_dicts = [{"src": pair[0], "tgt": pair[1]} for pair in edges]
        edges_dict = await storage.get_edges_batch(edge_dicts)
        print(f"Batch get edge properties result: {edges_dict.keys()}")
        assert len(edges_dict) == 3, f"Should return properties of 3 edges, but got {len(edges_dict)}"
        for pair in edges:
            assert pair in edges_dict, f"Edge {pair[0]} -> {pair[1]} should be in the result"
        for pair, written in zip(edges, (edge1_data, edge2_data, edge5_data)):
            assert (
                edges_dict[pair]["relationship"] == written["relationship"]
            ), f"Edge {pair[0]} -> {pair[1]} relationship mismatch"

        # 5.1 Same lookup with endpoints swapped - verify undirected property
        print("== Testing batch get of reverse edges")
        reverse_edge_dicts = [{"src": pair[1], "tgt": pair[0]} for pair in edges]
        reverse_edges_dict = await storage.get_edges_batch(reverse_edge_dicts)
        print(f"Batch get reverse edge properties result: {reverse_edges_dict.keys()}")
        assert (
            len(reverse_edges_dict) == 3
        ), f"Should return properties of 3 reverse edges, but got {len(reverse_edges_dict)}"
        # Forward and reverse lookups must agree property for property
        for (src, tgt), props in edges_dict.items():
            flipped = (tgt, src)
            assert flipped in reverse_edges_dict, f"Reverse edge {tgt} -> {src} should be in the result"
            assert (
                props == reverse_edges_dict[flipped]
            ), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent"
        print("Undirected graph property verification successful: properties of batch-retrieved forward and reverse edges are consistent")

        # 6. get_nodes_edges_batch --------------------------------------------
        print("== Testing get_nodes_edges_batch")
        nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
        print(f"Batch get node edges result: {nodes_edges.keys()}")
        assert (
            len(nodes_edges) == 2
        ), f"Should return edges for 2 nodes, but got {len(nodes_edges)}"
        for entity_id in (node1_id, node3_id):
            assert entity_id in nodes_edges, f"{entity_id} should be in the result"
        for entity_id in (node1_id, node3_id):
            assert (
                len(nodes_edges[entity_id]) == 3
            ), f"{entity_id} should have 3 edges, but has {len(nodes_edges[entity_id])}"

        # 6.1 Verify undirected property of batch-retrieved node edges
        print("== Verifying undirected property of batch-retrieved node edges")

        def split_by_direction(center_id, pairs):
            # Partition (src, tgt) pairs into those leaving and those entering center_id
            leaving = [(src, tgt) for src, tgt in pairs if src == center_id]
            entering = [(src, tgt) for src, tgt in pairs if tgt == center_id]
            return leaving, entering

        node1_outgoing_edges, node1_incoming_edges = split_by_direction(
            node1_id, nodes_edges[node1_id]
        )
        print(f"Outgoing edges of node {node1_id}: {node1_outgoing_edges}")
        print(f"Incoming edges of node {node1_id}: {node1_incoming_edges}")
        # Node 1 must have an outgoing edge to Machine Learning, Natural Language
        # Processing, and Computer Vision
        for neighbor_id in (node2_id, node4_id, node5_id):
            reaches_neighbor = any(tgt == neighbor_id for _, tgt in node1_outgoing_edges)
            assert reaches_neighbor, f"Edge list of node {node1_id} should include an edge to {neighbor_id}"

        node3_outgoing_edges, node3_incoming_edges = split_by_direction(
            node3_id, nodes_edges[node3_id]
        )
        print(f"Outgoing edges of node {node3_id}: {node3_outgoing_edges}")
        print(f"Incoming edges of node {node3_id}: {node3_incoming_edges}")
        # Node 3 must be connected to Machine Learning, Natural Language Processing,
        # and Computer Vision (ignoring direction)
        for neighbor_id in (node2_id, node4_id, node5_id):
            accepted = ((neighbor_id, node3_id), (node3_id, neighbor_id))
            is_connected = any(
                (src, tgt) in accepted for src, tgt in nodes_edges[node3_id]
            )
            assert (
                is_connected
            ), f"Edge list of node {node3_id} should include a connection with {neighbor_id}"
        print("Undirected graph property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)")

        print("\nBatch operations tests completed.")
        return True
    except Exception as e:
        # Assertion failures land here too; report them and signal failure
        ASCIIColors.red(f"An error occurred during the test: {e!s}")
        return False
async def _check_special_character_node(storage, node_id, expected_data):
    """
    Read one node back and assert that its ID and description round-tripped unchanged.

    Args:
        storage: Storage object under test; only ``get_node`` is awaited here.
        node_id: ID of the node to read back.
        expected_data: The dict that was written for this node; its
            ``description`` value is compared with the stored one.

    Raises:
        AssertionError: If the node cannot be read or a stored value differs.
    """
    node_props = await storage.get_node(node_id)
    if not node_props:
        print(f"Failed to read node properties: {node_id}")
        # Raised explicitly rather than via `assert False`, so a missing node is
        # still reported when Python runs with -O (which strips assert statements).
        raise AssertionError(f"Failed to read node properties: {node_id}")

    print(f"Successfully read node: {node_id}")
    print(f"Node description: {node_props.get('description', 'No description')}")
    # Verify node ID is saved correctly
    assert (
        node_props.get("entity_id") == node_id
    ), f"Node ID mismatch: expected {node_id}, got {node_props.get('entity_id')}"
    # Verify description is saved correctly
    assert (
        node_props.get("description") == expected_data["description"]
    ), f"Node description mismatch: expected {expected_data['description']}, got {node_props.get('description')}"
    print(f"Node {node_id} special character verification successful")


async def _check_special_character_edge(storage, src_id, tgt_id, expected_data):
    """
    Read one edge back and assert that its relationship and description round-tripped unchanged.

    Args:
        storage: Storage object under test; only ``get_edge`` is awaited here.
        src_id: Source node ID used when the edge was written.
        tgt_id: Target node ID used when the edge was written.
        expected_data: The dict that was written for this edge; its
            ``relationship`` and ``description`` values are compared.

    Raises:
        AssertionError: If the edge cannot be read or a stored value differs.
    """
    edge_props = await storage.get_edge(src_id, tgt_id)
    if not edge_props:
        print(f"Failed to read edge properties: {src_id} -> {tgt_id}")
        # Explicit raise for the same reason as in the node check: it survives -O.
        raise AssertionError(f"Failed to read edge properties: {src_id} -> {tgt_id}")

    print(f"Successfully read edge: {src_id} -> {tgt_id}")
    print(f"Edge relationship: {edge_props.get('relationship', 'No relationship')}")
    print(f"Edge description: {edge_props.get('description', 'No description')}")
    # Verify edge relationship is saved correctly
    assert (
        edge_props.get("relationship") == expected_data["relationship"]
    ), f"Edge relationship mismatch: expected {expected_data['relationship']}, got {edge_props.get('relationship')}"
    # Verify edge description is saved correctly
    assert (
        edge_props.get("description") == expected_data["description"]
    ), f"Edge description mismatch: expected {expected_data['description']}, got {edge_props.get('description')}"
    print(f"Edge {src_id} -> {tgt_id} special character verification successful")


async def test_graph_special_characters(storage):
    """
    Test the graph database's handling of special characters:
    1. Test node names and descriptions containing single quotes, double quotes, and backslashes.
    2. Test edge descriptions containing single quotes, double quotes, and backslashes.
    3. Verify that special characters are saved and retrieved correctly.

    Args:
        storage: Graph storage object under test. This function awaits its
            ``upsert_node``, ``upsert_edge``, ``get_node`` and ``get_edge`` methods.

    Returns:
        True if every check passed; False if any exception (including a failed
        assertion) was raised, after printing the error in red. The inserted
        nodes and edges are not removed by this function.
    """
    try:
        # 1. Node whose name contains single quotes
        node1_id = "Node with 'single quotes'"
        node1_data = {
            "entity_id": node1_id,
            "description": "This description contains 'single quotes', \"double quotes\", and \\backslashes",
            "keywords": "special characters,quotes,escaping",
            "entity_type": "Test Node",
        }
        # 2. Node whose name contains double quotes
        node2_id = 'Node with "double quotes"'
        node2_data = {
            "entity_id": node2_id,
            "description": "This description contains both 'single quotes' and \"double quotes\" and \\a\\path",
            "keywords": "special characters,quotes,JSON",
            "entity_type": "Test Node",
        }
        # 3. Node whose name contains backslashes (including a trailing one)
        node3_id = "Node with \\backslashes\\"
        node3_data = {
            "entity_id": node3_id,
            "description": "This description contains a Windows path C:\\Program Files\\ and escape characters \\n\\t",
            "keywords": "backslashes,paths,escaping",
            "entity_type": "Test Node",
        }
        node_fixtures = [
            (node1_id, node1_data),
            (node2_id, node2_data),
            (node3_id, node3_data),
        ]
        for index, (node_id, node_data) in enumerate(node_fixtures, start=1):
            print(f"Inserting node with special characters {index}: {node_id}")
            await storage.upsert_node(node_id, node_data)

        # 4. Test special characters in edge description
        edge1_data = {
            "relationship": "special 'relationship'",
            "weight": 1.0,
            "description": "This edge description contains 'single quotes', \"double quotes\", and \\backslashes",
        }
        print(f"Inserting edge with special characters: {node1_id} -> {node2_id}")
        await storage.upsert_edge(node1_id, node2_id, edge1_data)

        # 5. Test more complex combination of special characters in edge description
        edge2_data = {
            "relationship": 'complex "relationship"\\type',
            "weight": 0.8,
            "description": "Contains SQL injection attempt: SELECT * FROM users WHERE name='admin'--",
        }
        print(f"Inserting edge with complex special characters: {node2_id} -> {node3_id}")
        await storage.upsert_edge(node2_id, node3_id, edge2_data)

        # 6. Verify that node special characters are saved correctly
        print("\n== Verifying node special characters")
        for node_id, original_data in node_fixtures:
            await _check_special_character_node(storage, node_id, original_data)

        # 7. Verify that edge special characters are saved correctly
        print("\n== Verifying edge special characters")
        edge_fixtures = [
            (node1_id, node2_id, edge1_data),
            (node2_id, node3_id, edge2_data),
        ]
        for src_id, tgt_id, original_data in edge_fixtures:
            await _check_special_character_edge(storage, src_id, tgt_id, original_data)

        print("\nSpecial character tests completed, data is preserved in the database.")
        return True
    except Exception as e:
        # Failed assertions land here too: AssertionError is an Exception subclass.
        ASCIIColors.red(f"An error occurred during the test: {str(e)}")
        return False
async def test_graph_undirected_property(storage):
    """
    Specifically test the undirected graph property of the storage:
    1. Verify that after inserting an edge in one direction, a reverse query can retrieve the same result.
    2. Verify that edge properties are consistent in forward and reverse queries.
    3. Verify that after deleting an edge in one direction, the edge in the other direction is also deleted.
    4. Verify the undirected property in batch operations.

    Args:
        storage: Graph storage object under test. This function awaits its
            ``upsert_node``, ``upsert_edge``, ``get_edge``, ``edge_degree``,
            ``remove_edges``, ``get_edges_batch`` and ``get_nodes_edges_batch``
            methods.

    Returns:
        True if every check passed; False if any exception (including a failed
        assertion) was raised, after printing the error in red.
    """
    try:
        # 1. Insert test data
        node1_id = "Computer Science"
        node2_id = "Data Structures"
        node3_id = "Algorithms"
        node_fixtures = [
            (
                node1_id,
                {
                    "entity_id": node1_id,
                    "description": "Computer science is the study of computers and their applications.",
                    "keywords": "computer,science,technology",
                    "entity_type": "Discipline",
                },
            ),
            (
                node2_id,
                {
                    "entity_id": node2_id,
                    "description": "A data structure is a fundamental concept in computer science used to organize and store data.",
                    "keywords": "data,structure,organization",
                    "entity_type": "Concept",
                },
            ),
            (
                node3_id,
                {
                    "entity_id": node3_id,
                    "description": "An algorithm is a set of steps and methods for solving problems.",
                    "keywords": "algorithm,steps,methods",
                    "entity_type": "Concept",
                },
            ),
        ]
        for index, (node_id, node_data) in enumerate(node_fixtures, start=1):
            print(f"Inserting node {index}: {node_id}")
            await storage.upsert_node(node_id, node_data)

        # 2. Test undirected property after edge insertion
        print("\n== Testing undirected property after edge insertion")
        # Insert edge 1: Computer Science -> Data Structures
        edge1_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "Computer science includes the concept of data structures.",
        }
        print(f"Inserting edge 1: {node1_id} -> {node2_id}")
        await storage.upsert_edge(node1_id, node2_id, edge1_data)

        # Verify forward query
        forward_edge = await storage.get_edge(node1_id, node2_id)
        print(f"Forward edge properties: {forward_edge}")
        assert forward_edge is not None, f"Failed to read forward edge properties: {node1_id} -> {node2_id}"

        # Verify reverse query
        reverse_edge = await storage.get_edge(node2_id, node1_id)
        print(f"Reverse edge properties: {reverse_edge}")
        assert reverse_edge is not None, f"Failed to read reverse edge properties: {node2_id} -> {node1_id}"

        # Verify that forward and reverse edge properties are consistent
        assert (
            forward_edge == reverse_edge
        ), "Forward and reverse edge properties are inconsistent, undirected property verification failed"
        print("Undirected property verification successful: forward and reverse edge properties are consistent")

        # 3. Test undirected property of edge degree
        print("\n== Testing undirected property of edge degree")
        # Insert edge 2: Computer Science -> Algorithms
        edge2_data = {
            "relationship": "includes",
            "weight": 1.0,
            "description": "Computer science includes the concept of algorithms.",
        }
        print(f"Inserting edge 2: {node1_id} -> {node3_id}")
        await storage.upsert_edge(node1_id, node3_id, edge2_data)

        # Verify degrees of forward and reverse edges
        forward_degree = await storage.edge_degree(node1_id, node2_id)
        reverse_degree = await storage.edge_degree(node2_id, node1_id)
        print(f"Degree of forward edge {node1_id} -> {node2_id}: {forward_degree}")
        print(f"Degree of reverse edge {node2_id} -> {node1_id}: {reverse_degree}")
        assert (
            forward_degree == reverse_degree
        ), "Degrees of forward and reverse edges are inconsistent, undirected property verification failed"
        print("Undirected property verification successful: degrees of forward and reverse edges are consistent")

        # 4. Test undirected property of edge deletion
        print("\n== Testing undirected property of edge deletion")
        # Delete forward edge
        print(f"Deleting edge: {node1_id} -> {node2_id}")
        await storage.remove_edges([(node1_id, node2_id)])

        # Verify forward edge is deleted
        forward_edge = await storage.get_edge(node1_id, node2_id)
        print(f"Querying forward edge properties after deletion {node1_id} -> {node2_id}: {forward_edge}")
        assert forward_edge is None, f"Edge {node1_id} -> {node2_id} should have been deleted"

        # Verify reverse edge is also deleted
        reverse_edge = await storage.get_edge(node2_id, node1_id)
        print(f"Querying reverse edge properties after deletion {node2_id} -> {node1_id}: {reverse_edge}")
        assert (
            reverse_edge is None
        ), f"Reverse edge {node2_id} -> {node1_id} should also be deleted, undirected property verification failed"
        print("Undirected property verification successful: deleting an edge in one direction also deletes the reverse edge")

        # 5. Test undirected property in batch operations
        print("\n== Testing undirected property in batch operations")
        # Re-insert edge 1, which step 4 removed
        await storage.upsert_edge(node1_id, node2_id, edge1_data)

        # Batch get edge properties
        edge_dicts = [
            {"src": node1_id, "tgt": node2_id},
            {"src": node1_id, "tgt": node3_id},
        ]
        reverse_edge_dicts = [
            {"src": node2_id, "tgt": node1_id},
            {"src": node3_id, "tgt": node1_id},
        ]
        edges_dict = await storage.get_edges_batch(edge_dicts)
        reverse_edges_dict = await storage.get_edges_batch(reverse_edge_dicts)
        print(f"Batch get forward edge properties result: {edges_dict.keys()}")
        print(f"Batch get reverse edge properties result: {reverse_edges_dict.keys()}")

        # The comparison loop below only visits keys that the forward lookup
        # actually returned, so an empty forward result would pass vacuously.
        # Require every requested forward edge to be present first.
        for pair in edge_dicts:
            assert (
                pair["src"],
                pair["tgt"],
            ) in edges_dict, f"Forward edge {pair['src']} -> {pair['tgt']} should be in the result"

        # Verify that properties of forward and reverse edges are consistent
        for (src, tgt), props in edges_dict.items():
            assert (
                tgt,
                src,
            ) in reverse_edges_dict, f"Reverse edge {tgt} -> {src} should be in the result"
            assert (
                props == reverse_edges_dict[(tgt, src)]
            ), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent"
        print("Undirected property verification successful: properties of batch-retrieved forward and reverse edges are consistent")

        # 6. Test undirected property of batch-retrieved node edges
        print("\n== Testing undirected property of batch-retrieved node edges")
        nodes_edges = await storage.get_nodes_edges_batch([node1_id, node2_id])
        print(f"Batch get node edges result: {nodes_edges.keys()}")
        node1_edges = nodes_edges[node1_id]
        node2_edges = nodes_edges[node2_id]

        # Check if node 1 has edges to node 2 and node 3. Node 1 was passed as
        # the source in every upsert_edge call above, and these two checks
        # require it in the source position of the returned tuples.
        has_edge_to_node2 = any(
            (src == node1_id and tgt == node2_id) for src, tgt in node1_edges
        )
        has_edge_to_node3 = any(
            (src == node1_id and tgt == node3_id) for src, tgt in node1_edges
        )
        assert has_edge_to_node2, f"Edge list of node {node1_id} should include an edge to {node2_id}"
        assert has_edge_to_node3, f"Edge list of node {node1_id} should include an edge to {node3_id}"

        # Check if node 2 has a connection with node 1 (either direction is accepted)
        has_edge_to_node1 = any(
            (src == node2_id and tgt == node1_id)
            or (src == node1_id and tgt == node2_id)
            for src, tgt in node2_edges
        )
        assert (
            has_edge_to_node1
        ), f"Edge list of node {node2_id} should include a connection with {node1_id}"
        print("Undirected property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)")

        print("\nUndirected property tests completed.")
        return True
    except Exception as e:
        # Failed assertions land here too: AssertionError is an Exception subclass.
        ASCIIColors.red(f"An error occurred during the test: {str(e)}")
        return False
async def main():
    """
    Interactive entry point of the graph storage test program.

    Prints the banner, checks for and loads ``.env``, initializes the configured
    graph storage, asks which test to run, drops the existing storage data for
    any valid choice, runs the selected test(s) and finally finalizes the
    storage. With "All Tests" the tests run in menu order and stop after the
    first one that returns a falsy result.

    Returns:
        None. Returns early if the ``.env`` check is declined or storage
        initialization fails.
    """
    # Display program title (string kept byte-for-byte; built from pieces only
    # so the literal can be indented with the code)
    banner = (
        "\n"
        "╔══════════════════════════════════════════════════════════════╗\n"
        "║ General Graph Storage Test Program ║\n"
        "╚══════════════════════════════════════════════════════════════╝\n"
    )
    ASCIIColors.cyan(banner)

    # Check for .env file
    if not check_env_file():
        return

    # Load environment variables
    load_dotenv(dotenv_path=".env", override=False)

    # Get graph storage type
    graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
    ASCIIColors.magenta(f"\nCurrently configured graph storage type: {graph_storage_type}")
    supported_types = ", ".join(STORAGE_IMPLEMENTATIONS["GRAPH_STORAGE"]["implementations"])
    ASCIIColors.white(f"Supported graph storage types: {supported_types}")

    # Initialize storage instance
    storage = await initialize_graph_storage()
    if not storage:
        ASCIIColors.red("Failed to initialize storage instance, exiting test program.")
        return

    try:
        # Menu key, title used in the "All Tests" progress line, and coroutine
        # function, listed in the order "All Tests" runs them.
        test_suite = [
            ("1", "Basic Test", test_graph_basic),
            ("2", "Advanced Test", test_graph_advanced),
            ("3", "Batch Operations Test", test_graph_batch_operations),
            ("4", "Undirected Property Test", test_graph_undirected_property),
            ("5", "Special Characters Test", test_graph_special_characters),
        ]
        single_tests = {key: run_test for key, _, run_test in test_suite}
        run_all_key = "6"

        # Display test options
        ASCIIColors.yellow("\nPlease select a test type:")
        ASCIIColors.white("1. Basic Test (Node and edge insertion, reading)")
        ASCIIColors.white("2. Advanced Test (Degree, labels, knowledge graph, deletion, etc.)")
        ASCIIColors.white("3. Batch Operations Test (Batch get node/edge properties, degrees, etc.)")
        ASCIIColors.white("4. Undirected Property Test (Verify undirected properties of the storage)")
        ASCIIColors.white("5. Special Characters Test (Verify handling of single/double quotes, backslashes, etc.)")
        ASCIIColors.white("6. All Tests")

        # Strip surrounding whitespace so an answer such as "6 " is not rejected.
        choice = input("\nEnter your choice (1/2/3/4/5/6): ").strip()

        if choice not in single_tests and choice != run_all_key:
            ASCIIColors.red("Invalid choice")
        else:
            # Clean data before running tests
            ASCIIColors.yellow("\nCleaning data before running tests...")
            await storage.drop()
            ASCIIColors.green("Data cleanup complete\n")

            if choice in single_tests:
                await single_tests[choice](storage)
            else:
                # Each test runs only if the previous one succeeded.
                for _, title, run_test in test_suite:
                    ASCIIColors.cyan(f"\n=== Starting {title} ===")
                    if not await run_test(storage):
                        break
    finally:
        # Close connection
        if storage:
            await storage.finalize()
            ASCIIColors.green("\nStorage connection closed.")
# Script entry point: when the file is executed directly (not imported), run
# the interactive test menu to completion on a new asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())