Merge branch 'kenspirit/main'

This commit is contained in:
yangdx 2025-06-26 13:52:42 +08:00
commit 778ad4f23a
3 changed files with 460 additions and 404 deletions

View file

@ -14,8 +14,8 @@ STORAGE_IMPLEMENTATIONS = {
"NetworkXStorage",
"Neo4JStorage",
"PGGraphStorage",
"MongoGraphStorage",
# "AGEStorage",
# "MongoGraphStorage",
# "TiDBGraphStorage",
# "GremlinStorage",
],

File diff suppressed because it is too large Load diff

View file

@ -30,6 +30,7 @@ from lightrag.kg import (
verify_storage_implementation,
)
from lightrag.kg.shared_storage import initialize_share_data
from lightrag.constants import GRAPH_FIELD_SEP
# 模拟的嵌入函数,返回随机向量
@ -437,6 +438,9 @@ async def test_graph_batch_operations(storage):
5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
"""
try:
chunk1_id = "1"
chunk2_id = "2"
chunk3_id = "3"
# 1. 插入测试数据
# 插入节点1: 人工智能
node1_id = "人工智能"
@ -445,6 +449,7 @@ async def test_graph_batch_operations(storage):
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
"keywords": "AI,机器学习,深度学习",
"entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
}
print(f"插入节点1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
@ -456,6 +461,7 @@ async def test_graph_batch_operations(storage):
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
"keywords": "监督学习,无监督学习,强化学习",
"entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
}
print(f"插入节点2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
@ -467,6 +473,7 @@ async def test_graph_batch_operations(storage):
"description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
"keywords": "神经网络,CNN,RNN",
"entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
}
print(f"插入节点3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
@ -498,6 +505,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含机器学习这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
}
print(f"插入边1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
@ -507,6 +515,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含",
"weight": 1.0,
"description": "机器学习领域包含深度学习这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
}
print(f"插入边2: {node2_id} -> {node3_id}")
await storage.upsert_edge(node2_id, node3_id, edge2_data)
@ -516,6 +525,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含自然语言处理这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
}
print(f"插入边3: {node1_id} -> {node4_id}")
await storage.upsert_edge(node1_id, node4_id, edge3_data)
@ -748,6 +758,76 @@ async def test_graph_batch_operations(storage):
print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
# 7. 测试 get_nodes_by_chunk_ids - 批量根据 chunk_ids 获取多个节点
print("== 测试 get_nodes_by_chunk_ids")
print("== 测试单个 chunk_id匹配多个节点")
nodes = await storage.get_nodes_by_chunk_ids([chunk2_id])
assert len(nodes) == 2, f"{chunk1_id} 应有2个节点实际有 {len(nodes)}"
has_node1 = any(node["entity_id"] == node1_id for node in nodes)
has_node2 = any(node["entity_id"] == node2_id for node in nodes)
assert has_node1, f"节点 {node1_id} 应在返回结果中"
assert has_node2, f"节点 {node2_id} 应在返回结果中"
print("== 测试多个 chunk_id部分匹配多个节点")
nodes = await storage.get_nodes_by_chunk_ids([chunk2_id, chunk3_id])
assert (
len(nodes) == 3
), f"{chunk2_id}, {chunk3_id} 应有3个节点实际有 {len(nodes)}"
has_node1 = any(node["entity_id"] == node1_id for node in nodes)
has_node2 = any(node["entity_id"] == node2_id for node in nodes)
has_node3 = any(node["entity_id"] == node3_id for node in nodes)
assert has_node1, f"节点 {node1_id} 应在返回结果中"
assert has_node2, f"节点 {node2_id} 应在返回结果中"
assert has_node3, f"节点 {node3_id} 应在返回结果中"
# 8. 测试 get_edges_by_chunk_ids - 批量根据 chunk_ids 获取多条边
print("== 测试 get_edges_by_chunk_ids")
print("== 测试单个 chunk_id匹配多条边")
edges = await storage.get_edges_by_chunk_ids([chunk2_id])
assert len(edges) == 2, f"{chunk2_id} 应有2条边实际有 {len(edges)}"
has_edge_node1_node2 = any(
edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
)
has_edge_node2_node3 = any(
edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
)
assert has_edge_node1_node2, f"{chunk2_id} 应包含 {node1_id}{node2_id} 的边"
assert has_edge_node2_node3, f"{chunk2_id} 应包含 {node2_id}{node3_id} 的边"
print("== 测试多个 chunk_id部分匹配多条边")
edges = await storage.get_edges_by_chunk_ids([chunk2_id, chunk3_id])
assert (
len(edges) == 3
), f"{chunk2_id}, {chunk3_id} 应有3条边实际有 {len(edges)}"
has_edge_node1_node2 = any(
edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
)
has_edge_node2_node3 = any(
edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
)
has_edge_node1_node4 = any(
edge["source"] == node1_id and edge["target"] == node4_id for edge in edges
)
assert (
has_edge_node1_node2
), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id}{node2_id} 的边"
assert (
has_edge_node2_node3
), f"{chunk2_id}, {chunk3_id} 应包含 {node2_id}{node3_id} 的边"
assert (
has_edge_node1_node4
), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id}{node4_id} 的边"
print("\n批量操作测试完成")
return True