diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index e50128c9..88a75ba5 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -9,6 +9,7 @@ from typing import Any, Union, final
 import numpy as np
 import configparser
 import ssl
+import itertools

 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

@@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage):
         query: str,
         readonly: bool = True,
         upsert: bool = False,
+        params: dict[str, Any] | None = None,
     ) -> list[dict[str, Any]]:
         """
         Query the graph by taking a cypher query, converting it to an
@@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage):
             if readonly:
                 data = await self.db.query(
                     query,
+                    params,
                     multirows=True,
                     with_age=True,
                     graph_name=self.graph_name,
@@ -3384,12 +3387,15 @@ class PGGraphStorage(BaseGraphStorage):
             logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
             raise

-    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
+    async def get_nodes_batch(
+        self, node_ids: list[str], batch_size: int = 1000
+    ) -> dict[str, dict]:
         """
         Retrieve multiple nodes in one query using UNWIND.

         Args:
             node_ids: List of node entity IDs to fetch.
+            batch_size: Batch size for the query

         Returns:
             A dictionary mapping each node_id to its node data (or None if not found).
@@ -3397,45 +3403,61 @@ class PGGraphStorage(BaseGraphStorage):
         if not node_ids:
             return {}

-        # Format node IDs for the query
-        formatted_ids = ", ".join(
-            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
-        )
-
-        query = """SELECT * FROM cypher('%s', $$
-            UNWIND [%s] AS node_id
-            MATCH (n:base {entity_id: node_id})
-            RETURN node_id, n
-        $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
-
-        results = await self._query(query)
+        seen = set()
+        unique_ids = []
+        for nid in node_ids:
+            nid_norm = self._normalize_node_id(nid)
+            if nid_norm not in seen:
+                seen.add(nid_norm)
+                unique_ids.append(nid_norm)

         # Build result dictionary
         nodes_dict = {}
-        for result in results:
-            if result["node_id"] and result["n"]:
-                node_dict = result["n"]["properties"]
-                # Process string result, parse it to JSON dictionary
-                if isinstance(node_dict, str):
-                    try:
-                        node_dict = json.loads(node_dict)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
-                        )
-
-                # Remove the 'base' label if present in a 'labels' property
-                # if "labels" in node_dict:
-                #     node_dict["labels"] = [
-                #         label for label in node_dict["labels"] if label != "base"
-                #     ]
-
-                nodes_dict[result["node_id"]] = node_dict
+        for i in range(0, len(unique_ids), batch_size):
+            batch = unique_ids[i : i + batch_size]
+
+            query = f"""
+                WITH input(v, ord) AS (
+                    SELECT v, ord
+                    FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
+                ),
+                ids(node_id, ord) AS (
+                    SELECT (to_json(v)::text)::agtype AS node_id, ord
+                    FROM input
+                )
+                SELECT i.node_id::text AS node_id,
+                       b.properties
+                FROM {self.graph_name}.base AS b
+                JOIN ids i
+                  ON ag_catalog.agtype_access_operator(
+                         VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
+                     ) = i.node_id
+                ORDER BY i.ord;
+            """
+
+            results = await self._query(query, params={"ids": batch})
+
+            for result in results:
+                if result["node_id"] and result["properties"]:
+                    node_dict = result["properties"]
+
+                    # Process string result, parse it to JSON dictionary
+                    if isinstance(node_dict, str):
+                        try:
+                            node_dict = json.loads(node_dict)
+                        except json.JSONDecodeError:
+                            logger.warning(
+                                f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
+                            )
+
+                    nodes_dict[result["node_id"]] = node_dict

         return nodes_dict

-    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
+    async def node_degrees_batch(
+        self, node_ids: list[str], batch_size: int = 500
+    ) -> dict[str, int]:
         """
         Retrieve the degree for multiple nodes in a single query using UNWIND.
         Calculates the total degree by counting distinct relationships.
@@ -3443,6 +3465,7 @@ class PGGraphStorage(BaseGraphStorage):

         Args:
             node_ids: List of node labels (entity_id values) to look up.
+            batch_size: Batch size for the query

         Returns:
             A dictionary mapping each node_id to its degree (total number of relationships).
@@ -3451,44 +3474,66 @@ class PGGraphStorage(BaseGraphStorage):
         if not node_ids:
             return {}

-        # Format node IDs for the query
-        formatted_ids = ", ".join(
-            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
-        )
-
-        outgoing_query = """SELECT * FROM cypher('%s', $$
-            UNWIND [%s] AS node_id
-            MATCH (n:base {entity_id: node_id})
-            OPTIONAL MATCH (n)-[r]->(a)
-            RETURN node_id, count(a) AS out_degree
-        $$) AS (node_id text, out_degree bigint)""" % (
-            self.graph_name,
-            formatted_ids,
-        )
-
-        incoming_query = """SELECT * FROM cypher('%s', $$
-            UNWIND [%s] AS node_id
-            MATCH (n:base {entity_id: node_id})
-            OPTIONAL MATCH (n)<-[r]-(b)
-            RETURN node_id, count(b) AS in_degree
-        $$) AS (node_id text, in_degree bigint)""" % (
-            self.graph_name,
-            formatted_ids,
-        )
-
-        outgoing_results = await self._query(outgoing_query)
-        incoming_results = await self._query(incoming_query)
+        seen = set()
+        unique_ids: list[str] = []
+        for nid in node_ids:
+            n = self._normalize_node_id(nid)
+            if n not in seen:
+                seen.add(n)
+                unique_ids.append(n)

         out_degrees = {}
         in_degrees = {}

-        for result in outgoing_results:
-            if result["node_id"] is not None:
-                out_degrees[result["node_id"]] = int(result["out_degree"])
-
-        for result in incoming_results:
-            if result["node_id"] is not None:
-                in_degrees[result["node_id"]] = int(result["in_degree"])
+        for i in range(0, len(unique_ids), batch_size):
+            batch = unique_ids[i : i + batch_size]
+
+            query = f"""
+                WITH input(v, ord) AS (
+                    SELECT v, ord
+                    FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
+                ),
+                ids(node_id, ord) AS (
+                    SELECT (to_json(v)::text)::agtype AS node_id, ord
+                    FROM input
+                ),
+                vids AS (
+                    SELECT b.id AS vid, i.node_id, i.ord
+                    FROM {self.graph_name}.base AS b
+                    JOIN ids i
+                      ON ag_catalog.agtype_access_operator(
+                             VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
+                         ) = i.node_id
+                ),
+                deg_out AS (
+                    SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
+                    FROM {self.graph_name}."DIRECTED" AS d
+                    JOIN vids v ON v.vid = d.start_id
+                    GROUP BY d.start_id
+                ),
+                deg_in AS (
+                    SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
+                    FROM {self.graph_name}."DIRECTED" AS d
+                    JOIN vids v ON v.vid = d.end_id
+                    GROUP BY d.end_id
+                )
+                SELECT v.node_id::text AS node_id,
+                       COALESCE(o.out_degree, 0) AS out_degree,
+                       COALESCE(n.in_degree, 0) AS in_degree
+                FROM vids v
+                LEFT JOIN deg_out o ON o.vid = v.vid
+                LEFT JOIN deg_in n ON n.vid = v.vid
+                ORDER BY v.ord;
+            """
+
+            combined_results = await self._query(query, params={"ids": batch})
+
+            for row in combined_results:
+                node_id = row["node_id"]
+                if not node_id:
+                    continue
+                out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
+                in_degrees[node_id] = int(row.get("in_degree", 0) or 0)

         degrees_dict = {}
         for node_id in node_ids:
@@ -3532,7 +3577,7 @@ class PGGraphStorage(BaseGraphStorage):

         return edge_degrees_dict

     async def get_edges_batch(
-        self, pairs: list[dict[str, str]]
+        self, pairs: list[dict[str, str]], batch_size: int = 500
     ) -> dict[tuple[str, str], dict]:
         """
         Retrieve edge properties for multiple (src, tgt) pairs in one query.
@@ -3540,6 +3585,7 @@ class PGGraphStorage(BaseGraphStorage):

         Args:
             pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
+            batch_size: Batch size for the query

         Returns:
             A dictionary mapping (src, tgt) tuples to their edge properties.
@@ -3547,76 +3593,108 @@ class PGGraphStorage(BaseGraphStorage):
         if not pairs:
             return {}

-        src_nodes = []
-        tgt_nodes = []
-        for pair in pairs:
-            src_nodes.append(self._normalize_node_id(pair["src"]))
-            tgt_nodes.append(self._normalize_node_id(pair["tgt"]))
-
-        src_array = ", ".join([f'"{src}"' for src in src_nodes])
-        tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
-
-        forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-            WITH [{src_array}] AS sources, [{tgt_array}] AS targets
-            UNWIND range(0, size(sources)-1) AS i
-            MATCH (a:base {{entity_id: sources[i]}})-[r]->(b:base {{entity_id: targets[i]}})
-            RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
-        $$) AS (source text, target text, edge_properties agtype)"""
-
-        backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-            WITH [{src_array}] AS sources, [{tgt_array}] AS targets
-            UNWIND range(0, size(sources)-1) AS i
-            MATCH (a:base {{entity_id: sources[i]}})<-[r]-(b:base {{entity_id: targets[i]}})
-            RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
-        $$) AS (source text, target text, edge_properties agtype)"""
-
-        forward_results = await self._query(forward_query)
-        backward_results = await self._query(backward_query)
-
-        edges_dict = {}
-
-        for result in forward_results:
-            if result["source"] and result["target"] and result["edge_properties"]:
-                edge_props = result["edge_properties"]
-
-                # Process string result, parse it to JSON dictionary
-                if isinstance(edge_props, str):
-                    try:
-                        edge_props = json.loads(edge_props)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
-                        )
-                        continue
-
-                edges_dict[(result["source"], result["target"])] = edge_props
-
-        for result in backward_results:
-            if result["source"] and result["target"] and result["edge_properties"]:
-                edge_props = result["edge_properties"]
-
-                # Process string result, parse it to JSON dictionary
-                if isinstance(edge_props, str):
-                    try:
-                        edge_props = json.loads(edge_props)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
-                        )
-                        continue
-
-                edges_dict[(result["source"], result["target"])] = edge_props
+        seen = set()
+        uniq_pairs: list[dict[str, str]] = []
+        for p in pairs:
+            s = self._normalize_node_id(p["src"])
+            t = self._normalize_node_id(p["tgt"])
+            key = (s, t)
+            if s and t and key not in seen:
+                seen.add(key)
+                uniq_pairs.append(p)
+
+        edges_dict: dict[tuple[str, str], dict] = {}
+
+        forward_cypher = """
+            UNWIND $pairs AS p
+            WITH p.src AS src_eid, p.tgt AS tgt_eid
+            MATCH (a:base {entity_id: src_eid})
+            MATCH (b:base {entity_id: tgt_eid})
+            MATCH (a)-[r]->(b)
+            RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
+        backward_cypher = """
+            UNWIND $pairs AS p
+            WITH p.src AS src_eid, p.tgt AS tgt_eid
+            MATCH (a:base {entity_id: src_eid})
+            MATCH (b:base {entity_id: tgt_eid})
+            MATCH (a)<-[r]-(b)
+            RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
+
+        def dollar_quote(s: str, tag_prefix: str = "AGE") -> str:
+            # Dollar-quote s with a tag that is guaranteed not to occur in s.
+            s = "" if s is None else str(s)
+            for n in itertools.count(1):
+                wrapper = f"${tag_prefix}{n}$"
+                if wrapper not in s:
+                    return f"{wrapper}{s}{wrapper}"
+
+        sql_fwd = f"""
+            SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
+                                 {dollar_quote(forward_cypher)}::cstring,
+                                 $1::agtype)
+            AS (source text, target text, edge_properties agtype)
+        """
+        sql_bwd = f"""
+            SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
+                                 {dollar_quote(backward_cypher)}::cstring,
+                                 $1::agtype)
+            AS (source text, target text, edge_properties agtype)
+        """
+
+        for i in range(0, len(uniq_pairs), batch_size):
+            batch = uniq_pairs[i : i + batch_size]
+
+            batch_pairs = [{"src": p["src"], "tgt": p["tgt"]} for p in batch]
+            pg_params = {"params": json.dumps({"pairs": batch_pairs}, ensure_ascii=False)}
+
+            forward_results = await self._query(sql_fwd, params=pg_params)
+            backward_results = await self._query(sql_bwd, params=pg_params)
+
+            for result in forward_results:
+                if result["source"] and result["target"] and result["edge_properties"]:
+                    edge_props = result["edge_properties"]
+
+                    # Process string result, parse it to JSON dictionary
+                    if isinstance(edge_props, str):
+                        try:
+                            edge_props = json.loads(edge_props)
+                        except json.JSONDecodeError:
+                            logger.warning(
+                                f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
+                            )
+                            continue
+
+                    edges_dict[(result["source"], result["target"])] = edge_props
+
+            for result in backward_results:
+                if result["source"] and result["target"] and result["edge_properties"]:
+                    edge_props = result["edge_properties"]
+
+                    # Process string result, parse it to JSON dictionary
+                    if isinstance(edge_props, str):
+                        try:
+                            edge_props = json.loads(edge_props)
+                        except json.JSONDecodeError:
+                            logger.warning(
+                                f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
+                            )
+                            continue
+
+                    edges_dict[(result["source"], result["target"])] = edge_props

         return edges_dict

     async def get_nodes_edges_batch(
-        self, node_ids: list[str]
+        self, node_ids: list[str], batch_size: int = 500
     ) -> dict[str, list[tuple[str, str]]]:
         """
         Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.

         Args:
             node_ids: List of node IDs to get edges for
+            batch_size: Batch size for the query

         Returns:
             Dictionary mapping node IDs to lists of (source, target) edge tuples
@@ -3624,49 +3702,62 @@ class PGGraphStorage(BaseGraphStorage):
         if not node_ids:
             return {}

-        # Format node IDs for the query
-        formatted_ids = ", ".join(
-            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
-        )
+        seen = set()
+        unique_ids: list[str] = []
+        for nid in node_ids:
+            n = self._normalize_node_id(nid)
+            if n and n not in seen:
+                seen.add(n)
+                unique_ids.append(n)

-        outgoing_query = """SELECT * FROM cypher('%s', $$
-            UNWIND [%s] AS node_id
-            MATCH (n:base {entity_id: node_id})
-            OPTIONAL MATCH (n:base)-[]->(connected:base)
-            RETURN node_id, connected.entity_id AS connected_id
-        $$) AS (node_id text, connected_id text)""" % (
-            self.graph_name,
-            formatted_ids,
-        )
+        edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}

-        incoming_query = """SELECT * FROM cypher('%s', $$
-            UNWIND [%s] AS node_id
-            MATCH (n:base {entity_id: node_id})
-            OPTIONAL MATCH (n:base)<-[]-(connected:base)
-            RETURN node_id, connected.entity_id AS connected_id
-        $$) AS (node_id text, connected_id text)""" % (
-            self.graph_name,
-            formatted_ids,
-        )
+        for i in range(0, len(unique_ids), batch_size):
+            batch = unique_ids[i : i + batch_size]
+            # Format node IDs for the query
+            formatted_ids = ", ".join([f'"{n}"' for n in batch])

-        outgoing_results = await self._query(outgoing_query)
-        incoming_results = await self._query(incoming_query)
+            outgoing_query = """SELECT * FROM cypher('%s', $$
+                UNWIND [%s] AS node_id
+                MATCH (n:base {entity_id: node_id})
+                OPTIONAL MATCH (n:base)-[]->(connected:base)
+                RETURN node_id, connected.entity_id AS connected_id
+            $$) AS (node_id text, connected_id text)""" % (
+                self.graph_name,
+                formatted_ids,
+            )

-        nodes_edges_dict = {node_id: [] for node_id in node_ids}
+            incoming_query = """SELECT * FROM cypher('%s', $$
+                UNWIND [%s] AS node_id
+                MATCH (n:base {entity_id: node_id})
+                OPTIONAL MATCH (n:base)<-[]-(connected:base)
+                RETURN node_id, connected.entity_id AS connected_id
+            $$) AS (node_id text, connected_id text)""" % (
+                self.graph_name,
+                formatted_ids,
+            )

-        for result in outgoing_results:
-            if result["node_id"] and result["connected_id"]:
-                nodes_edges_dict[result["node_id"]].append(
-                    (result["node_id"], result["connected_id"])
-                )
+            outgoing_results = await self._query(outgoing_query)
+            incoming_results = await self._query(incoming_query)

-        for result in incoming_results:
-            if result["node_id"] and result["connected_id"]:
-                nodes_edges_dict[result["node_id"]].append(
-                    (result["connected_id"], result["node_id"])
-                )
+            for result in outgoing_results:
+                if result["node_id"] and result["connected_id"]:
+                    edges_norm[result["node_id"]].append(
+                        (result["node_id"], result["connected_id"])
+                    )

-        return nodes_edges_dict
+            for result in incoming_results:
+                if result["node_id"] and result["connected_id"]:
+                    edges_norm[result["node_id"]].append(
+                        (result["connected_id"], result["node_id"])
+                    )
+
+        out: dict[str, list[tuple[str, str]]] = {}
+        for orig in node_ids:
+            n = self._normalize_node_id(orig)
+            out[orig] = edges_norm.get(n, [])
+
+        return out

     async def get_all_labels(self) -> list[str]:
         """
@@ -4491,50 +4582,86 @@ SQL_TEMPLATES = {
                      update_time = EXCLUDED.update_time
                      """,
     "relationships": """
-    WITH relevant_chunks AS (
-        SELECT id as chunk_id
-        FROM LIGHTRAG_VDB_CHUNKS
-        WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
-    )
-    SELECT r.source_id as src_id,
-           r.target_id as tgt_id,
-           EXTRACT(EPOCH FROM r.create_time)::BIGINT as created_at
-    FROM LIGHTRAG_VDB_RELATION r
-    JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
-    WHERE r.workspace = $1
-    AND r.content_vector <=> '[{embedding_string}]'::vector < $3
-    ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
-    LIMIT $4
-    """,
+    WITH relevant_chunks AS (
+        SELECT id AS chunk_id
+        FROM LIGHTRAG_VDB_CHUNKS
+        WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
+    ),
+    rc AS (
+        SELECT array_agg(chunk_id) AS chunk_arr
+        FROM relevant_chunks
+    ),
+    cand AS (
+        SELECT r.id,
+               r.source_id AS src_id,
+               r.target_id AS tgt_id,
+               r.chunk_ids,
+               r.create_time,
+               r.content_vector <=> '[{embedding_string}]'::vector AS dist
+        FROM LIGHTRAG_VDB_RELATION r
+        WHERE r.workspace = $1
+        ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
+        LIMIT ($4 * 50)
+    )
+    SELECT c.src_id,
+           c.tgt_id,
+           EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
+    FROM cand c
+    JOIN rc ON TRUE
+    WHERE c.dist < $3
+      AND c.chunk_ids && (rc.chunk_arr::varchar[])
+    ORDER BY c.dist, c.id
+    LIMIT $4;
+    """,
     "entities": """
-    WITH relevant_chunks AS (
-        SELECT id as chunk_id
-        FROM LIGHTRAG_VDB_CHUNKS
-        WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
-    )
-    SELECT e.entity_name,
-           EXTRACT(EPOCH FROM e.create_time)::BIGINT as created_at
-    FROM LIGHTRAG_VDB_ENTITY e
-    JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
-    WHERE e.workspace = $1
-    AND e.content_vector <=> '[{embedding_string}]'::vector < $3
-    ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
-    LIMIT $4
-    """,
+    WITH relevant_chunks AS (
+        SELECT id AS chunk_id
+        FROM LIGHTRAG_VDB_CHUNKS
+        WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
+    ),
+    rc AS (
+        SELECT array_agg(chunk_id) AS chunk_arr
+        FROM relevant_chunks
+    ),
+    cand AS (
+        SELECT e.id,
+               e.entity_name,
+               e.chunk_ids,
+               e.create_time,
+               e.content_vector <=> '[{embedding_string}]'::vector AS dist
+        FROM LIGHTRAG_VDB_ENTITY e
+        WHERE e.workspace = $1
+        ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
+        LIMIT ($4 * 50)
+    )
+    SELECT c.entity_name,
+           EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
+    FROM cand c
+    JOIN rc ON TRUE
+    WHERE c.dist < $3
+      AND c.chunk_ids && (rc.chunk_arr::varchar[])
+    ORDER BY c.dist, c.id
+    LIMIT $4;
+    """,
     "chunks": """
-    WITH relevant_chunks AS (
-        SELECT id as chunk_id
-        FROM LIGHTRAG_VDB_CHUNKS
-        WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
-    )
-    SELECT id, content, file_path,
-           EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
-    FROM LIGHTRAG_VDB_CHUNKS
-    WHERE workspace = $1
-    AND id IN (SELECT chunk_id FROM relevant_chunks)
-    AND content_vector <=> '[{embedding_string}]'::vector < $3
-    ORDER BY content_vector <=> '[{embedding_string}]'::vector
-    LIMIT $4
-    """,
+    WITH relevant_chunks AS (
+        SELECT id AS chunk_id
+        FROM LIGHTRAG_VDB_CHUNKS
+        WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
+    ),
+    rc AS (
+        SELECT array_agg(chunk_id) AS chunk_arr
+        FROM relevant_chunks
+    ),
+    cand AS (
+        SELECT id, content, file_path, create_time,
+               content_vector <=> '[{embedding_string}]'::vector AS dist
+        FROM LIGHTRAG_VDB_CHUNKS
+        WHERE workspace = $1
+        ORDER BY content_vector <=> '[{embedding_string}]'::vector
+        LIMIT ($4 * 50)
+    )
+    SELECT c.id,
+           c.content,
+           c.file_path,
+           EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
+    FROM cand c
+    JOIN rc ON TRUE
+    WHERE c.dist < $3
+      AND c.id = ANY (rc.chunk_arr)
+    ORDER BY c.dist, c.id
+    LIMIT $4;
+    """,
     # DROP tables
     "drop_specifiy_table_workspace": """
         DELETE FROM {table_name} WHERE workspace=$1
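
Usage notes (supplementary to the patch):

The patch passes Cypher parameters to Apache AGE by dollar-quoting the graph
name and query text with a collision-free tag and binding the whole parameter
map as one JSON-encoded agtype value, so entity ids are never spliced into the
query text the way the old formatted_ids path did. A minimal standalone sketch
of that pattern follows; the DSN, the graph name "mygraph", and the sample
pair are assumptions for illustration, and it relies on asyncpg issuing
fetch() as a prepared statement, which AGE expects for parametrized cypher()
calls.

import asyncio
import itertools
import json

import asyncpg


def dollar_quote(s: str, tag_prefix: str = "AGE") -> str:
    """Wrap s in a PostgreSQL dollar-quote whose tag never occurs in s."""
    s = "" if s is None else str(s)
    for n in itertools.count(1):
        wrapper = f"${tag_prefix}{n}$"
        if wrapper not in s:
            return f"{wrapper}{s}{wrapper}"


async def main() -> None:
    conn = await asyncpg.connect("postgresql://postgres:postgres@localhost/postgres")
    try:
        await conn.execute("LOAD 'age'; SET search_path = ag_catalog, '$user', public;")
        cypher = """
            UNWIND $pairs AS p
            MATCH (a:base {entity_id: p.src})-[r]->(b:base {entity_id: p.tgt})
            RETURN p.src AS source, p.tgt AS target, properties(r) AS edge_properties"""
        sql = f"""
            SELECT * FROM cypher({dollar_quote('mygraph')}::name,
                                 {dollar_quote(cypher)}::cstring,
                                 $1::agtype)
            AS (source agtype, target agtype, edge_properties agtype)
        """
        # The whole parameter map travels as a single bind value.
        params = json.dumps({"pairs": [{"src": "node1", "tgt": "node2"}]})
        for row in await conn.fetch(sql, params):
            print(row["source"], row["target"], row["edge_properties"])
    finally:
        await conn.close()


asyncio.run(main())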
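
The rewritten "relationships"/"entities"/"chunks" templates swap the join
order: instead of filtering by chunk membership before sorting every surviving
row by embedding distance, they take the top ($4 * 50) nearest rows per
workspace from the vector index first, then apply the distance threshold and
the chunk-overlap filter (&&) to that small candidate set. The fixed 50x
oversampling is a recall/latency trade-off: if fewer than $4 of those
candidates overlap the relevant chunks, the query returns fewer rows even
though more matches exist further down the index. A tiny sketch of that knob,
with a hypothetical helper name that is not in the patch:

def candidate_limit(top_k: int, multiplier: int = 50) -> int:
    # Stage 1 of the template scans at most top_k * multiplier rows before
    # the threshold and chunk-overlap filters run; e.g. top_k=20 -> 1000.
    return top_k * multiplier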
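
The batched graph lookups keep their call sites but now deduplicate ids and
split the work into fixed-size batches. A usage sketch, assuming an already
initialized PGGraphStorage instance named storage (setup omitted):

import asyncio


async def demo(storage) -> None:
    ids = [f"entity-{n}" for n in range(2500)]

    # ceil(2500 / 1000) = 3 round trips under the default batch_size=1000;
    # duplicate ids are dropped before batching.
    nodes = await storage.get_nodes_batch(ids)

    # One combined in/out degree query per 500-id batch, replacing the two
    # full-list UNWIND cypher() calls the old implementation issued.
    degrees = await storage.node_degrees_batch(ids, batch_size=500)

    missing = [i for i in ids if i not in nodes]
    print(len(nodes), len(degrees), len(missing))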