Merge branch 'pg-optimization'

This commit is contained in:
yangdx 2025-08-18 22:34:08 +08:00
commit ee15629f26

View file

@@ -9,6 +9,7 @@ from typing import Any, Union, final
import numpy as np import numpy as np
import configparser import configparser
import ssl import ssl
import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage):
query: str, query: str,
readonly: bool = True, readonly: bool = True,
upsert: bool = False, upsert: bool = False,
params: dict[str, Any] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Query the graph by taking a cypher query, converting it to an Query the graph by taking a cypher query, converting it to an
@@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage):
if readonly: if readonly:
data = await self.db.query( data = await self.db.query(
query, query,
params,
multirows=True, multirows=True,
with_age=True, with_age=True,
graph_name=self.graph_name, graph_name=self.graph_name,
@@ -3384,12 +3387,15 @@ class PGGraphStorage(BaseGraphStorage):
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(
self, node_ids: list[str], batch_size: int = 1000
) -> dict[str, dict]:
""" """
Retrieve multiple nodes in one query using UNWIND. Retrieve multiple nodes in one query using UNWIND.
Args: Args:
node_ids: List of node entity IDs to fetch. node_ids: List of node entity IDs to fetch.
batch_size: Batch size for the query
Returns: Returns:
A dictionary mapping each node_id to its node data (or None if not found). A dictionary mapping each node_id to its node data (or None if not found).
@@ -3397,24 +3403,44 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) nid_norm = self._normalize_node_id(nid)
if nid_norm not in seen:
query = """SELECT * FROM cypher('%s', $$ seen.add(nid_norm)
UNWIND [%s] AS node_id unique_ids.append(nid_norm)
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
results = await self._query(query)
# Build result dictionary # Build result dictionary
nodes_dict = {} nodes_dict = {}
for i in range(0, len(unique_ids), batch_size):
batch = unique_ids[i : i + batch_size]
query = f"""
WITH input(v, ord) AS (
SELECT v, ord
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
)
SELECT i.node_id::text AS node_id,
b.properties
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
ORDER BY i.ord;
"""
results = await self._query(query, params={"ids": batch})
for result in results: for result in results:
if result["node_id"] and result["n"]: if result["node_id"] and result["properties"]:
node_dict = result["n"]["properties"] node_dict = result["properties"]
# Process string result, parse it to JSON dictionary # Process string result, parse it to JSON dictionary
if isinstance(node_dict, str): if isinstance(node_dict, str):
@@ -3422,20 +3448,16 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict) node_dict = json.loads(node_dict)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning( logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" f"Failed to parse node string in batch: {node_dict}"
) )
# Remove the 'base' label if present in a 'labels' property
# if "labels" in node_dict:
# node_dict["labels"] = [
# label for label in node_dict["labels"] if label != "base"
# ]
nodes_dict[result["node_id"]] = node_dict nodes_dict[result["node_id"]] = node_dict
return nodes_dict return nodes_dict
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(
self, node_ids: list[str], batch_size: int = 500
) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
Calculates the total degree by counting distinct relationships. Calculates the total degree by counting distinct relationships.
@@ -3443,6 +3465,7 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_ids: List of node labels (entity_id values) to look up. node_ids: List of node labels (entity_id values) to look up.
batch_size: Batch size for the query
Returns: Returns:
A dictionary mapping each node_id to its degree (total number of relationships). A dictionary mapping each node_id to its degree (total number of relationships).
@@ -3451,44 +3474,66 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids: list[str] = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) n = self._normalize_node_id(nid)
if n not in seen:
outgoing_query = """SELECT * FROM cypher('%s', $$ seen.add(n)
UNWIND [%s] AS node_id unique_ids.append(n)
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[r]->(a)
RETURN node_id, count(a) AS out_degree
$$) AS (node_id text, out_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)<-[r]-(b)
RETURN node_id, count(b) AS in_degree
$$) AS (node_id text, in_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
out_degrees = {} out_degrees = {}
in_degrees = {} in_degrees = {}
for result in outgoing_results: for i in range(0, len(unique_ids), batch_size):
if result["node_id"] is not None: batch = unique_ids[i : i + batch_size]
out_degrees[result["node_id"]] = int(result["out_degree"])
for result in incoming_results: query = f"""
if result["node_id"] is not None: WITH input(v, ord) AS (
in_degrees[result["node_id"]] = int(result["in_degree"]) SELECT v, ord
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
),
vids AS (
SELECT b.id AS vid, i.node_id, i.ord
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
),
deg_out AS (
SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.start_id
GROUP BY d.start_id
),
deg_in AS (
SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.end_id
GROUP BY d.end_id
)
SELECT v.node_id::text AS node_id,
COALESCE(o.out_degree, 0) AS out_degree,
COALESCE(n.in_degree, 0) AS in_degree
FROM vids v
LEFT JOIN deg_out o ON o.vid = v.vid
LEFT JOIN deg_in n ON n.vid = v.vid
ORDER BY v.ord;
"""
combined_results = await self._query(query, params={"ids": batch})
for row in combined_results:
node_id = row["node_id"]
if not node_id:
continue
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
degrees_dict = {} degrees_dict = {}
for node_id in node_ids: for node_id in node_ids:
@@ -3532,7 +3577,7 @@ class PGGraphStorage(BaseGraphStorage):
return edge_degrees_dict return edge_degrees_dict
async def get_edges_batch( async def get_edges_batch(
self, pairs: list[dict[str, str]] self, pairs: list[dict[str, str]], batch_size: int = 500
) -> dict[tuple[str, str], dict]: ) -> dict[tuple[str, str], dict]:
""" """
Retrieve edge properties for multiple (src, tgt) pairs in one query. Retrieve edge properties for multiple (src, tgt) pairs in one query.
@@ -3540,6 +3585,7 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
batch_size: Batch size for the query
Returns: Returns:
A dictionary mapping (src, tgt) tuples to their edge properties. A dictionary mapping (src, tgt) tuples to their edge properties.
@@ -3547,33 +3593,64 @@ class PGGraphStorage(BaseGraphStorage):
if not pairs: if not pairs:
return {} return {}
src_nodes = [] seen = set()
tgt_nodes = [] uniq_pairs: list[dict[str, str]] = []
for pair in pairs: for p in pairs:
src_nodes.append(self._normalize_node_id(pair["src"])) s = self._normalize_node_id(p["src"])
tgt_nodes.append(self._normalize_node_id(pair["tgt"])) t = self._normalize_node_id(p["tgt"])
key = (s, t)
if s and t and key not in seen:
seen.add(key)
uniq_pairs.append(p)
src_array = ", ".join([f'"{src}"' for src in src_nodes]) edges_dict: dict[tuple[str, str], dict] = {}
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ for i in range(0, len(uniq_pairs), batch_size):
WITH [{src_array}] AS sources, [{tgt_array}] AS targets batch = uniq_pairs[i : i + batch_size]
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})-[r]->(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ pairs = [{"src": p["src"], "tgt": p["tgt"]} for p in batch]
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})<-[r]-(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
forward_results = await self._query(forward_query) forward_cypher = """
backward_results = await self._query(backward_query) UNWIND $pairs AS p
WITH p.src AS src_eid, p.tgt AS tgt_eid
MATCH (a:base {entity_id: src_eid})
MATCH (b:base {entity_id: tgt_eid})
MATCH (a)-[r]->(b)
RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
backward_cypher = """
UNWIND $pairs AS p
WITH p.src AS src_eid, p.tgt AS tgt_eid
MATCH (a:base {entity_id: src_eid})
MATCH (b:base {entity_id: tgt_eid})
MATCH (a)<-[r]-(b)
RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
edges_dict = {} def dollar_quote(s: str, tag_prefix="AGE"):
s = "" if s is None else str(s)
for i in itertools.count(1):
tag = f"{tag_prefix}{i}"
wrapper = f"${tag}$"
if wrapper not in s:
return f"{wrapper}{s}{wrapper}"
sql_fwd = f"""
SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
{dollar_quote(forward_cypher)}::cstring,
$1::agtype)
AS (source text, target text, edge_properties agtype)
"""
sql_bwd = f"""
SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
{dollar_quote(backward_cypher)}::cstring,
$1::agtype)
AS (source text, target text, edge_properties agtype)
"""
pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)}
forward_results = await self._query(sql_fwd, params=pg_params)
backward_results = await self._query(sql_bwd, params=pg_params)
for result in forward_results: for result in forward_results:
if result["source"] and result["target"] and result["edge_properties"]: if result["source"] and result["target"] and result["edge_properties"]:
@@ -3585,7 +3662,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props) edge_props = json.loads(edge_props)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning( logger.warning(
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}" f"Failed to parse edge properties string: {edge_props}"
) )
continue continue
@@ -3601,7 +3678,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props) edge_props = json.loads(edge_props)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning( logger.warning(
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}" f"Failed to parse edge properties string: {edge_props}"
) )
continue continue
@@ -3610,13 +3687,14 @@ class PGGraphStorage(BaseGraphStorage):
return edges_dict return edges_dict
async def get_nodes_edges_batch( async def get_nodes_edges_batch(
self, node_ids: list[str] self, node_ids: list[str], batch_size: int = 500
) -> dict[str, list[tuple[str, str]]]: ) -> dict[str, list[tuple[str, str]]]:
""" """
Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation. Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
Args: Args:
node_ids: List of node IDs to get edges for node_ids: List of node IDs to get edges for
batch_size: Batch size for the query
Returns: Returns:
Dictionary mapping node IDs to lists of (source, target) edge tuples Dictionary mapping node IDs to lists of (source, target) edge tuples
@@ -3624,10 +3702,20 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
seen = set()
unique_ids: list[str] = []
for nid in node_ids:
n = self._normalize_node_id(nid)
if n and n not in seen:
seen.add(n)
unique_ids.append(n)
edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}
for i in range(0, len(unique_ids), batch_size):
batch = unique_ids[i : i + batch_size]
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join( formatted_ids = ", ".join([f'"{n}"' for n in batch])
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
)
outgoing_query = """SELECT * FROM cypher('%s', $$ outgoing_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
@@ -3652,21 +3740,24 @@ class PGGraphStorage(BaseGraphStorage):
outgoing_results = await self._query(outgoing_query) outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query) incoming_results = await self._query(incoming_query)
nodes_edges_dict = {node_id: [] for node_id in node_ids}
for result in outgoing_results: for result in outgoing_results:
if result["node_id"] and result["connected_id"]: if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append( edges_norm[result["node_id"]].append(
(result["node_id"], result["connected_id"]) (result["node_id"], result["connected_id"])
) )
for result in incoming_results: for result in incoming_results:
if result["node_id"] and result["connected_id"]: if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append( edges_norm[result["node_id"]].append(
(result["connected_id"], result["node_id"]) (result["connected_id"], result["node_id"])
) )
return nodes_edges_dict out: dict[str, list[tuple[str, str]]] = {}
for orig in node_ids:
n = self._normalize_node_id(orig)
out[orig] = edges_norm.get(n, [])
return out
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """
@@ -4491,49 +4582,85 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time
""", """,
"relationships": """ "relationships": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (SELECT id as chunk_id
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) WHERE $2
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
) )
SELECT r.source_id as src_id, r.target_id as tgt_id, , rc AS (
EXTRACT(EPOCH FROM r.create_time)::BIGINT as created_at SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_RELATION r FROM LIGHTRAG_VDB_RELATION r
JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
WHERE r.workspace = $1 WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $3
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
LIMIT $4 LIMIT ($4 * 50)
)
SELECT c.src_id,
c.tgt_id,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""", """,
"entities": """ "entities": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (SELECT id as chunk_id
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) WHERE $2
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
) )
SELECT e.entity_name, , rc AS (
EXTRACT(EPOCH FROM e.create_time)::BIGINT as created_at SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_ENTITY e FROM LIGHTRAG_VDB_ENTITY e
JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
WHERE e.workspace = $1 WHERE e.workspace = $1
AND e.content_vector <=> '[{embedding_string}]'::vector < $3
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
LIMIT $4 LIMIT ($4 * 50)
)
SELECT c.entity_name,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""", """,
"chunks": """ "chunks": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (SELECT id as chunk_id
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) WHERE $2
:: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
) )
SELECT id, content, file_path, , rc AS (
EXTRACT(EPOCH FROM create_time)::BIGINT as created_at SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
), cand AS (
SELECT
id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
FROM LIGHTRAG_VDB_CHUNKS FROM LIGHTRAG_VDB_CHUNKS
WHERE workspace = $1 WHERE workspace = $1
AND id IN (SELECT chunk_id FROM relevant_chunks)
AND content_vector <=> '[{embedding_string}]'::vector < $3
ORDER BY content_vector <=> '[{embedding_string}]'::vector ORDER BY content_vector <=> '[{embedding_string}]'::vector
LIMIT $4 LIMIT ($4 * 50)
)
SELECT c.id,
c.content,
c.file_path,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.id = ANY (rc.chunk_arr)
ORDER BY c.dist, c.id
LIMIT $4;
""", """,
# DROP tables # DROP tables
"drop_specifiy_table_workspace": """ "drop_specifiy_table_workspace": """