Add search on neighbor nodes which are source to selected one
This commit is contained in:
parent
f40bc43d5e
commit
7c8f65d020
1 changed files with 28 additions and 11 deletions
|
|
@ -710,11 +710,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
def _construct_graph_node(
|
def _construct_graph_node(
|
||||||
self, node_data: dict[str, str], seen_nodes: set
|
self, node_id, node_data: dict[str, str]
|
||||||
) -> KnowledgeGraphNode:
|
) -> KnowledgeGraphNode:
|
||||||
node_id = str(node_data["_id"])
|
return KnowledgeGraphNode(
|
||||||
|
|
||||||
graph_node = KnowledgeGraphNode(
|
|
||||||
id=node_id,
|
id=node_id,
|
||||||
labels=[node_id],
|
labels=[node_id],
|
||||||
properties={
|
properties={
|
||||||
|
|
@ -728,9 +726,6 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
seen_nodes.add(node_id)
|
|
||||||
|
|
||||||
return graph_node, seen_nodes
|
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self,
|
self,
|
||||||
|
|
@ -810,6 +805,25 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
"as": "connected_edges",
|
"as": "connected_edges",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"$unionWith": {
|
||||||
|
"coll": "chunk_entity_relation",
|
||||||
|
"pipeline": [
|
||||||
|
{"$match": {"_id": "LightRAG"}},
|
||||||
|
{
|
||||||
|
"$graphLookup": {
|
||||||
|
"from": "chunk_entity_relation_edges",
|
||||||
|
"startWith": "$_id",
|
||||||
|
"connectFromField": "source_node_id",
|
||||||
|
"connectToField": "target_node_id",
|
||||||
|
"as": "connected_edges",
|
||||||
|
"maxDepth": 3,
|
||||||
|
"depthField": "depth",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
||||||
|
|
@ -817,8 +831,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
# Add the start nodes
|
# Add the start nodes
|
||||||
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
|
node_id = str(doc["_id"])
|
||||||
result.nodes.append(graph_node)
|
if node_id not in seen_nodes:
|
||||||
|
seen_nodes.add(node_id)
|
||||||
|
result.nodes.append(self._construct_graph_node(node_id, doc))
|
||||||
|
|
||||||
if doc.get("connected_edges", []):
|
if doc.get("connected_edges", []):
|
||||||
node_edges.extend(doc.get("connected_edges"))
|
node_edges.extend(doc.get("connected_edges"))
|
||||||
|
|
@ -857,13 +873,14 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
node_ids.append(edge["target_node_id"])
|
node_ids.append(edge["target_node_id"])
|
||||||
seen_nodes.add(edge["target_node_id"])
|
seen_nodes.add(edge["target_node_id"])
|
||||||
|
|
||||||
|
# Filter out all the node whose id is same as label so that we do not check existence next step
|
||||||
cursor = self.collection.find(
|
cursor = self.collection.find(
|
||||||
{"_id": {"$in": [node for node in node_ids if node != label]}}
|
{"_id": {"$in": [node for node in node_ids if node != label]}}
|
||||||
)
|
)
|
||||||
|
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
|
node_id = str(doc["_id"])
|
||||||
result.nodes.append(graph_node)
|
result.nodes.append(self._construct_graph_node(node_id, doc))
|
||||||
|
|
||||||
for edge in node_edges:
|
for edge in node_edges:
|
||||||
if (
|
if (
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue