Remove unused chunk-based node/edge retrieval methods

This commit is contained in:
yangdx 2025-11-06 18:17:10 +08:00
parent 831e658ed8
commit 807d2461d3
6 changed files with 0 additions and 333 deletions

View file

@@ -19,7 +19,6 @@ from typing import (
from .utils import EmbeddingFunc from .utils import EmbeddingFunc
from .types import KnowledgeGraph from .types import KnowledgeGraph
from .constants import ( from .constants import (
GRAPH_FIELD_SEP,
DEFAULT_TOP_K, DEFAULT_TOP_K,
DEFAULT_CHUNK_TOP_K, DEFAULT_CHUNK_TOP_K,
DEFAULT_MAX_ENTITY_TOKENS, DEFAULT_MAX_ENTITY_TOKENS,
@@ -528,56 +527,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = edges if edges is not None else [] result[node_id] = edges if edges is not None else []
return result return result
@abstractmethod
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all nodes that are associated with the given chunk_ids.

    Abstract: storage backends must provide an implementation.

    Args:
        chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.

    Returns:
        list[dict]: A list of nodes, where each node is a dictionary of its
            properties. An empty list if no matching nodes are found.
    """
@abstractmethod
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all edges that are associated with the given chunk_ids.

    Args:
        chunk_ids (list[str]): A list of chunk IDs to find associated edges for.

    Returns:
        list[dict]: A list of edges, where each edge is a dictionary of its
            properties, plus "source"/"target" keys naming its endpoints.
            An empty list if no matching edges are found.

    NOTE(review): this method is marked @abstractmethod yet ships a default
    body; subclasses are forced to override but may delegate via super().
    Confirm that mix is intentional.
    """
    # Default implementation iterates through all nodes and their edges, which is inefficient.
    # This method should be overridden by subclasses for better performance.
    all_edges = []
    all_labels = await self.get_all_labels()
    # Canonicalized (src, tgt) pairs already emitted, so an undirected edge
    # seen from both endpoints is only processed once.
    processed_edges = set()
    for label in all_labels:
        edges = await self.get_node_edges(label)
        if edges:
            for src_id, tgt_id in edges:
                # Avoid processing the same edge twice in an undirected graph
                edge_tuple = tuple(sorted((src_id, tgt_id)))
                if edge_tuple in processed_edges:
                    continue
                processed_edges.add(edge_tuple)
                edge = await self.get_edge(src_id, tgt_id)
                if edge and "source_id" in edge:
                    # source_id is a GRAPH_FIELD_SEP-joined list of chunk ids;
                    # keep the edge when it shares at least one with the query.
                    source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
                    if not source_ids.isdisjoint(chunk_ids):
                        # Add source and target to the edge dict for easier processing later
                        edge_with_nodes = edge.copy()
                        edge_with_nodes["source"] = src_id
                        edge_with_nodes["target"] = tgt_id
                        all_edges.append(edge_with_nodes)
    return all_edges
@abstractmethod @abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph. """Insert a new node or update an existing node in the graph.

View file

@@ -8,7 +8,6 @@ import configparser
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm import pipmaster as pm
@@ -784,79 +783,6 @@ class MemgraphStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all nodes that are associated with the given chunk_ids.

    Args:
        chunk_ids: List of chunk IDs to find associated nodes for

    Returns:
        list[dict]: A list of nodes, where each node is a dictionary of its
            properties (with "id" mirroring "entity_id").
            An empty list if no matching nodes are found.

    Raises:
        RuntimeError: If the Memgraph driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )
    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # A node matches when any requested chunk id appears in its
        # separator-joined source_id property; DISTINCT deduplicates nodes
        # matched by more than one chunk id.
        query = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (n:`{workspace_label}`)
            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
            RETURN DISTINCT n
        """
        result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
        nodes = []
        async for record in result:
            node = record["n"]
            node_dict = dict(node)
            # Expose entity_id under "id" for uniform downstream access.
            node_dict["id"] = node_dict.get("entity_id")
            nodes.append(node_dict)
        # Drain the cursor so the session can be released cleanly.
        await result.consume()
        return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all edges that are associated with the given chunk_ids.

    Args:
        chunk_ids: List of chunk IDs to find associated edges for

    Returns:
        list[dict]: A list of edges, where each edge is a dictionary of its
            properties plus "source"/"target" endpoint ids.
            An empty list if no matching edges are found.

    Raises:
        RuntimeError: If the Memgraph driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )
    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # The undirected MATCH sees each edge from both endpoints; the CASE
        # expressions canonicalize the (source, target) ordering so DISTINCT
        # collapses the duplicate rows.
        query = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
            WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
            // Ensure we only return each unique edge once by ordering the source and target
            WITH a, b, r,
                CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
                CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
            RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
        """
        result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
        edges = []
        async for record in result:
            edge_properties = record["properties"]
            # Attach endpoint ids so callers receive self-contained edge dicts.
            edge_properties["source"] = record["source"]
            edge_properties["target"] = record["target"]
            edges.append(edge_properties)
        # Drain the cursor before leaving the session context.
        await result.consume()
        return edges
async def get_knowledge_graph( async def get_knowledge_graph(
self, self,
node_label: str, node_label: str,

View file

@@ -1036,45 +1036,6 @@ class MongoGraphStorage(BaseGraphStorage):
return result return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all nodes that are associated with the given chunk_ids.

    Args:
        chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.

    Returns:
        list[dict]: A list of nodes, where each node is a dictionary of its
            properties. An empty list if no matching nodes are found.
    """
    if not chunk_ids:
        # Short-circuit: an empty $in filter could never match anything.
        return []
    matches = self.collection.find({"source_ids": {"$in": chunk_ids}})
    return [document async for document in matches]
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all edges that are associated with the given chunk_ids.

    Args:
        chunk_ids (list[str]): A list of chunk IDs to find associated edges for.

    Returns:
        list[dict]: A list of edges, where each edge is a dictionary of its
            properties plus "source"/"target" endpoint ids.
            An empty list if no matching edges are found.
    """
    if not chunk_ids:
        # Nothing can match an empty id list; skip the query entirely.
        return []
    found: list[dict] = []
    async for document in self.edge_collection.find(
        {"source_ids": {"$in": chunk_ids}}
    ):
        # Mirror the stored endpoint ids under the generic source/target keys.
        document["source"] = document["source_node_id"]
        document["target"] = document["target_node_id"]
        found.append(document)
    return found
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# UPSERTS # UPSERTS

View file

@@ -16,7 +16,6 @@ import logging
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm import pipmaster as pm
@@ -904,49 +903,6 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are fully consumed await result.consume() # Ensure results are fully consumed
return edges_dict return edges_dict
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all nodes whose source_id references any of the given chunk_ids.

    Args:
        chunk_ids: List of chunk IDs to find associated nodes for.

    Returns:
        list[dict]: One property dict per matching node, with "id" mirroring
            "entity_id". Empty list when nothing matches.

    NOTE(review): no None-guard on self._driver here — presumably
    initialize() is guaranteed to have run; confirm against callers.
    """
    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # A node matches when any requested chunk id appears in its
        # separator-joined source_id property; DISTINCT deduplicates.
        query = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (n:`{workspace_label}`)
            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
            RETURN DISTINCT n
        """
        result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
        nodes = []
        async for record in result:
            node = record["n"]
            node_dict = dict(node)
            # Add node id (entity_id) to the dictionary for easier access
            node_dict["id"] = node_dict.get("entity_id")
            nodes.append(node_dict)
        # Drain the cursor so the session can be released cleanly.
        await result.consume()
        return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Get all edges whose source_id references any of the given chunk_ids.

    Args:
        chunk_ids: List of chunk IDs to find associated edges for.

    Returns:
        list[dict]: One property dict per matching edge, with
            "source"/"target" endpoint entity ids attached.

    NOTE(review): the undirected MATCH sees each relationship from both
    endpoints and the endpoints are not canonically ordered, so DISTINCT may
    still emit an edge twice as (a, b) and (b, a) — confirm whether callers
    deduplicate.
    """
    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # An edge matches when any requested chunk id appears in its
        # separator-joined source_id property.
        query = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
        """
        result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
        edges = []
        async for record in result:
            edge_properties = record["properties"]
            # Attach endpoint ids so callers receive self-contained dicts.
            edge_properties["source"] = record["source"]
            edge_properties["target"] = record["target"]
            edges.append(edge_properties)
        # Drain the cursor before leaving the session context.
        await result.consume()
        return edges
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),

View file

@@ -5,7 +5,6 @@ from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger from lightrag.utils import logger
from lightrag.base import BaseGraphStorage from lightrag.base import BaseGraphStorage
from lightrag.constants import GRAPH_FIELD_SEP
import networkx as nx import networkx as nx
from .shared_storage import ( from .shared_storage import (
get_storage_lock, get_storage_lock,
@@ -470,33 +469,6 @@ class NetworkXStorage(BaseGraphStorage):
) )
return result return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Return every node whose source_id field references any of chunk_ids.

    Each returned dict is a copy of the node's attributes with an extra
    "id" key holding the node identifier.
    """
    wanted = set(chunk_ids)
    graph = await self._get_graph()
    hits: list[dict] = []
    for name, attrs in graph.nodes(data=True):
        if "source_id" not in attrs:
            continue
        # source_id is a GRAPH_FIELD_SEP-joined list of chunk ids; keep the
        # node when it shares at least one id with the query set.
        if wanted.intersection(attrs["source_id"].split(GRAPH_FIELD_SEP)):
            enriched = dict(attrs)
            enriched["id"] = name
            hits.append(enriched)
    return hits
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Return every edge whose source_id field references any of chunk_ids.

    Each returned dict is a copy of the edge's attributes with extra
    "source"/"target" keys naming its endpoints.
    """
    wanted = set(chunk_ids)
    graph = await self._get_graph()
    hits: list[dict] = []
    for src, dst, attrs in graph.edges(data=True):
        if "source_id" not in attrs:
            continue
        # Keep the edge when its separator-joined chunk-id list overlaps
        # the requested ids.
        if wanted.intersection(attrs["source_id"].split(GRAPH_FIELD_SEP)):
            enriched = dict(attrs)
            enriched["source"] = src
            enriched["target"] = dst
            hits.append(enriched)
    return hits
async def get_all_nodes(self) -> list[dict]: async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph. """Get all nodes in the graph.

View file

@@ -33,7 +33,6 @@ from ..base import (
) )
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
import pipmaster as pm import pipmaster as pm
@@ -4175,102 +4174,6 @@ class PGGraphStorage(BaseGraphStorage):
labels.append(result["label"]) labels.append(result["label"])
return labels return labels
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """
    Retrieves nodes from the graph that are associated with a given list of chunk IDs.

    This method uses a Cypher query with UNWIND to efficiently find all nodes
    where the `source_id` property contains any of the specified chunk IDs.

    Args:
        chunk_ids: Chunk IDs to look up. An empty list short-circuits to [].

    Returns:
        list[dict]: One property dict per matching node, each with an "id"
            key mirroring its "entity_id". Rows whose properties cannot be
            parsed are skipped with a warning instead of crashing.
    """
    if not chunk_ids:
        # Nothing can match; avoid a pointless database round-trip.
        return []
    # The string representation of the list for the cypher query
    chunk_ids_str = json.dumps(chunk_ids)
    query = f"""
        SELECT * FROM cypher('{self.graph_name}', $$
            UNWIND {chunk_ids_str} AS chunk_id
            MATCH (n:base)
            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
            RETURN n
        $$) AS (n agtype);
    """
    results = await self._query(query)

    # Build result list
    nodes = []
    for result in results or []:
        if not result["n"]:
            continue
        node_dict = result["n"]["properties"]
        # Some driver paths return the properties as a JSON string.
        if isinstance(node_dict, str):
            try:
                node_dict = json.loads(node_dict)
            except json.JSONDecodeError:
                logger.warning(
                    f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
                )
                # BUG FIX: previously execution fell through and indexed into
                # the unparsed string, raising TypeError; skip the row instead.
                continue
        node_dict["id"] = node_dict["entity_id"]
        nodes.append(node_dict)
    return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """
    Retrieves edges from the graph that are associated with a given list of chunk IDs.

    This method uses a Cypher query with UNWIND to efficiently find all edges
    where the `source_id` property contains any of the specified chunk IDs.

    Args:
        chunk_ids: Chunk IDs to look up. An empty list short-circuits to [].

    Returns:
        list[dict]: One property dict per matching edge, with
            "source"/"target" holding the endpoint entity ids. Rows whose
            payloads cannot be parsed are skipped with a warning.
    """
    if not chunk_ids:
        # Nothing can match; avoid a pointless database round-trip.
        return []

    def _parsed(value, kind: str):
        # agtype payloads sometimes arrive as JSON strings; decode them,
        # returning None (and logging) when the text is not valid JSON.
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                logger.warning(
                    f"[{self.workspace}] Failed to parse {kind} string in batch: {value}"
                )
                return None
        return value

    chunk_ids_str = json.dumps(chunk_ids)
    query = f"""
        SELECT * FROM cypher('{self.graph_name}', $$
            UNWIND {chunk_ids_str} AS chunk_id
            MATCH ()-[r]-()
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
            RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
        $$) AS (edge agtype, source agtype, target agtype);
    """
    results = await self._query(query)

    edges = []
    for item in results or []:
        edge_properties = _parsed(item["edge"]["properties"], "edge")
        source_properties = _parsed(item["source"]["properties"], "node")
        target_properties = _parsed(item["target"]["properties"], "node")
        # BUG FIX: the old code kept unparsed strings after a decode failure
        # and then assigned items into them (TypeError); require real dicts.
        if not (
            isinstance(edge_properties, dict)
            and isinstance(source_properties, dict)
            and isinstance(target_properties, dict)
        ):
            continue
        if edge_properties and source_properties and target_properties:
            edge_properties["source"] = source_properties["entity_id"]
            edge_properties["target"] = target_properties["entity_id"]
            edges.append(edge_properties)
    return edges
async def _bfs_subgraph( async def _bfs_subgraph(
self, node_label: str, max_depth: int, max_nodes: int self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph: ) -> KnowledgeGraph: