From a7da48e05c0bab4c8a9c116491b1df3556183e9c Mon Sep 17 00:00:00 2001 From: Matt23-star Date: Sat, 16 Aug 2025 22:35:22 +0800 Subject: [PATCH] feat: add batch size parameter to node and edge retrieval methods --- lightrag/kg/postgres_impl.py | 127 ++++++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 40 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index e50128c9..18957699 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -3384,12 +3384,13 @@ class PGGraphStorage(BaseGraphStorage): logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") raise - async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + async def get_nodes_batch(self, node_ids: list[str], batch_size: int = 1000) -> dict[str, dict]: """ Retrieve multiple nodes in one query using UNWIND. Args: node_ids: List of node entity IDs to fetch. + batch_size: Batch size for the query Returns: A dictionary mapping each node_id to its node data (or None if not found). @@ -3435,7 +3436,7 @@ class PGGraphStorage(BaseGraphStorage): return nodes_dict - async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + async def node_degrees_batch(self, node_ids: list[str], batch_size: int = 500) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. Calculates the total degree by counting distinct relationships. @@ -3443,6 +3444,7 @@ class PGGraphStorage(BaseGraphStorage): Args: node_ids: List of node labels (entity_id values) to look up. + batch_size: Batch size for the query Returns: A dictionary mapping each node_id to its degree (total number of relationships). @@ -3532,7 +3534,7 @@ class PGGraphStorage(BaseGraphStorage): return edge_degrees_dict async def get_edges_batch( - self, pairs: list[dict[str, str]] + self, pairs: list[dict[str, str]], batch_size: int = 500 ) -> dict[tuple[str, str], dict]: """ Retrieve edge properties for multiple (src, tgt) pairs in one query. @@ -3540,6 +3542,7 @@ class PGGraphStorage(BaseGraphStorage): Args: pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] + batch_size: Batch size for the query Returns: A dictionary mapping (src, tgt) tuples to their edge properties. @@ -3547,53 +3550,87 @@ class PGGraphStorage(BaseGraphStorage): if not pairs: return {} - src_nodes = [] - tgt_nodes = [] - for pair in pairs: - src_nodes.append(self._normalize_node_id(pair["src"])) - tgt_nodes.append(self._normalize_node_id(pair["tgt"])) + seen = set() + uniq_pairs: list[dict[str, str]] = [] + for p in pairs: + s = self._normalize_node_id(p["src"]) + t = self._normalize_node_id(p["tgt"]) + key = (s, t) + if s and t and key not in seen: + seen.add(key) + uniq_pairs.append(p) - src_array = ", ".join([f'"{src}"' for src in src_nodes]) - tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) + edges_dict: dict[tuple[str, str], dict] = {} - forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - WITH [{src_array}] AS sources, [{tgt_array}] AS targets - UNWIND range(0, size(sources)-1) AS i - MATCH (a:base {{entity_id: sources[i]}})-[r]->(b:base {{entity_id: targets[i]}}) - RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties - $$) AS (source text, target text, edge_properties agtype)""" + for i in range(0, len(uniq_pairs), batch_size): + batch = uniq_pairs[i:i + batch_size] - backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - WITH [{src_array}] AS sources, [{tgt_array}] AS targets - UNWIND range(0, size(sources)-1) AS i - MATCH (a:base {{entity_id: sources[i]}})<-[r]-(b:base {{entity_id: targets[i]}}) - RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties - $$) AS (source text, target text, edge_properties agtype)""" + pairs = [{"src": p["src"], "tgt": p["tgt"]} for p in batch] - forward_results = await self._query(forward_query) - backward_results = await self._query(backward_query) + forward_cypher = """ + UNWIND $pairs AS p + WITH p.src AS src_eid, p.tgt AS tgt_eid + MATCH (a:base {entity_id: src_eid}) + MATCH (b:base {entity_id: tgt_eid}) + MATCH (a)-[r]->(b) + RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties""" + backward_cypher = """ + UNWIND $pairs AS p + WITH p.src AS src_eid, p.tgt AS tgt_eid + MATCH (a:base {entity_id: src_eid}) + MATCH (b:base {entity_id: tgt_eid}) + MATCH (a)<-[r]-(b) + RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties""" - edges_dict = {} + def dollar_quote(s: str, tag_prefix="AGE"): + s = "" if s is None else str(s) + for i in itertools.count(1): + tag = f"{tag_prefix}{i}" + wrapper = f"${tag}$" + if wrapper not in s: + return f"{wrapper}{s}{wrapper}" - for result in forward_results: - if result["source"] and result["target"] and result["edge_properties"]: - edge_props = result["edge_properties"] + sql_fwd = f""" + SELECT * FROM cypher({dollar_quote(self.graph_name)}::name, + {dollar_quote(forward_cypher)}::cstring, + $1::agtype) + AS (source text, target text, edge_properties agtype) + """ - # Process string result, parse it to JSON dictionary - if isinstance(edge_props, str): - try: - edge_props = json.loads(edge_props) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse edge properties string: {edge_props}" - ) - continue + sql_bwd = f""" + SELECT * FROM cypher({dollar_quote(self.graph_name)}::name, + {dollar_quote(backward_cypher)}::cstring, + $1::agtype) + AS (source text, target text, edge_properties agtype) + """ - edges_dict[(result["source"], result["target"])] = edge_props + pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)} + + forward_results = await self._query(sql_fwd, params=pg_params) + backward_results = await self._query(sql_bwd, params=pg_params) + + for result in forward_results: + if result["source"] and result["target"] and result["edge_properties"]: + edge_props = result["edge_properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(edge_props, str): + try: + edge_props = json.loads(edge_props) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge properties string: {edge_props}" + ) + continue + + edges_dict[(result["source"], result["target"])] = edge_props for result in backward_results: if result["source"] and result["target"] and result["edge_properties"]: edge_props = result["edge_properties"] + for result in backward_results: + if result["source"] and result["target"] and result["edge_properties"]: + edge_props = result["edge_properties"] # Process string result, parse it to JSON dictionary if isinstance(edge_props, str): @@ -3604,19 +3641,29 @@ class PGGraphStorage(BaseGraphStorage): f"[{self.workspace}] Failed to parse edge properties string: {edge_props}" ) continue + # Process string result, parse it to JSON dictionary + if isinstance(edge_props, str): + try: + edge_props = json.loads(edge_props) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge properties string: {edge_props}" + ) + continue - edges_dict[(result["source"], result["target"])] = edge_props + edges_dict[(result["source"], result["target"])] = edge_props return edges_dict async def get_nodes_edges_batch( - self, node_ids: list[str] + self, node_ids: list[str], batch_size: int = 500 ) -> dict[str, list[tuple[str, str]]]: """ Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation. Args: node_ids: List of node IDs to get edges for + batch_size: Batch size for the query Returns: Dictionary mapping node IDs to lists of (source, target) edge tuples