Merge branch 'pg-optimization'

This commit is contained in:
yangdx 2025-08-18 22:34:08 +08:00
commit ee15629f26

View file

@ -9,6 +9,7 @@ from typing import Any, Union, final
import numpy as np import numpy as np
import configparser import configparser
import ssl import ssl
import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@ -3051,6 +3052,7 @@ class PGGraphStorage(BaseGraphStorage):
query: str, query: str,
readonly: bool = True, readonly: bool = True,
upsert: bool = False, upsert: bool = False,
params: dict[str, Any] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Query the graph by taking a cypher query, converting it to an Query the graph by taking a cypher query, converting it to an
@ -3066,6 +3068,7 @@ class PGGraphStorage(BaseGraphStorage):
if readonly: if readonly:
data = await self.db.query( data = await self.db.query(
query, query,
params,
multirows=True, multirows=True,
with_age=True, with_age=True,
graph_name=self.graph_name, graph_name=self.graph_name,
@ -3384,12 +3387,15 @@ class PGGraphStorage(BaseGraphStorage):
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(
self, node_ids: list[str], batch_size: int = 1000
) -> dict[str, dict]:
""" """
Retrieve multiple nodes in one query using UNWIND. Retrieve multiple nodes in one query using UNWIND.
Args: Args:
node_ids: List of node entity IDs to fetch. node_ids: List of node entity IDs to fetch.
batch_size: Batch size for the query
Returns: Returns:
A dictionary mapping each node_id to its node data (or None if not found). A dictionary mapping each node_id to its node data (or None if not found).
@ -3397,45 +3403,61 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) nid_norm = self._normalize_node_id(nid)
if nid_norm not in seen:
query = """SELECT * FROM cypher('%s', $$ seen.add(nid_norm)
UNWIND [%s] AS node_id unique_ids.append(nid_norm)
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
results = await self._query(query)
# Build result dictionary # Build result dictionary
nodes_dict = {} nodes_dict = {}
for result in results:
if result["node_id"] and result["n"]:
node_dict = result["n"]["properties"]
# Process string result, parse it to JSON dictionary for i in range(0, len(unique_ids), batch_size):
if isinstance(node_dict, str): batch = unique_ids[i : i + batch_size]
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
)
# Remove the 'base' label if present in a 'labels' property query = f"""
# if "labels" in node_dict: WITH input(v, ord) AS (
# node_dict["labels"] = [ SELECT v, ord
# label for label in node_dict["labels"] if label != "base" FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
# ] ),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
)
SELECT i.node_id::text AS node_id,
b.properties
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
ORDER BY i.ord;
"""
nodes_dict[result["node_id"]] = node_dict results = await self._query(query, params={"ids": batch})
for result in results:
if result["node_id"] and result["properties"]:
node_dict = result["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
)
nodes_dict[result["node_id"]] = node_dict
return nodes_dict return nodes_dict
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(
self, node_ids: list[str], batch_size: int = 500
) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
Calculates the total degree by counting distinct relationships. Calculates the total degree by counting distinct relationships.
@ -3443,6 +3465,7 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
node_ids: List of node labels (entity_id values) to look up. node_ids: List of node labels (entity_id values) to look up.
batch_size: Batch size for the query
Returns: Returns:
A dictionary mapping each node_id to its degree (total number of relationships). A dictionary mapping each node_id to its degree (total number of relationships).
@ -3451,44 +3474,66 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids: list[str] = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) n = self._normalize_node_id(nid)
if n not in seen:
outgoing_query = """SELECT * FROM cypher('%s', $$ seen.add(n)
UNWIND [%s] AS node_id unique_ids.append(n)
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[r]->(a)
RETURN node_id, count(a) AS out_degree
$$) AS (node_id text, out_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)<-[r]-(b)
RETURN node_id, count(b) AS in_degree
$$) AS (node_id text, in_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
out_degrees = {} out_degrees = {}
in_degrees = {} in_degrees = {}
for result in outgoing_results: for i in range(0, len(unique_ids), batch_size):
if result["node_id"] is not None: batch = unique_ids[i : i + batch_size]
out_degrees[result["node_id"]] = int(result["out_degree"])
for result in incoming_results: query = f"""
if result["node_id"] is not None: WITH input(v, ord) AS (
in_degrees[result["node_id"]] = int(result["in_degree"]) SELECT v, ord
FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
),
ids(node_id, ord) AS (
SELECT (to_json(v)::text)::agtype AS node_id, ord
FROM input
),
vids AS (
SELECT b.id AS vid, i.node_id, i.ord
FROM {self.graph_name}.base AS b
JOIN ids i
ON ag_catalog.agtype_access_operator(
VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
) = i.node_id
),
deg_out AS (
SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.start_id
GROUP BY d.start_id
),
deg_in AS (
SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
FROM {self.graph_name}."DIRECTED" AS d
JOIN vids v ON v.vid = d.end_id
GROUP BY d.end_id
)
SELECT v.node_id::text AS node_id,
COALESCE(o.out_degree, 0) AS out_degree,
COALESCE(n.in_degree, 0) AS in_degree
FROM vids v
LEFT JOIN deg_out o ON o.vid = v.vid
LEFT JOIN deg_in n ON n.vid = v.vid
ORDER BY v.ord;
"""
combined_results = await self._query(query, params={"ids": batch})
for row in combined_results:
node_id = row["node_id"]
if not node_id:
continue
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
degrees_dict = {} degrees_dict = {}
for node_id in node_ids: for node_id in node_ids:
@ -3532,7 +3577,7 @@ class PGGraphStorage(BaseGraphStorage):
return edge_degrees_dict return edge_degrees_dict
async def get_edges_batch( async def get_edges_batch(
self, pairs: list[dict[str, str]] self, pairs: list[dict[str, str]], batch_size: int = 500
) -> dict[tuple[str, str], dict]: ) -> dict[tuple[str, str], dict]:
""" """
Retrieve edge properties for multiple (src, tgt) pairs in one query. Retrieve edge properties for multiple (src, tgt) pairs in one query.
@ -3540,6 +3585,7 @@ class PGGraphStorage(BaseGraphStorage):
Args: Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
batch_size: Batch size for the query
Returns: Returns:
A dictionary mapping (src, tgt) tuples to their edge properties. A dictionary mapping (src, tgt) tuples to their edge properties.
@ -3547,76 +3593,108 @@ class PGGraphStorage(BaseGraphStorage):
if not pairs: if not pairs:
return {} return {}
src_nodes = [] seen = set()
tgt_nodes = [] uniq_pairs: list[dict[str, str]] = []
for pair in pairs: for p in pairs:
src_nodes.append(self._normalize_node_id(pair["src"])) s = self._normalize_node_id(p["src"])
tgt_nodes.append(self._normalize_node_id(pair["tgt"])) t = self._normalize_node_id(p["tgt"])
key = (s, t)
if s and t and key not in seen:
seen.add(key)
uniq_pairs.append(p)
src_array = ", ".join([f'"{src}"' for src in src_nodes]) edges_dict: dict[tuple[str, str], dict] = {}
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ for i in range(0, len(uniq_pairs), batch_size):
WITH [{src_array}] AS sources, [{tgt_array}] AS targets batch = uniq_pairs[i : i + batch_size]
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})-[r]->(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ pairs = [{"src": p["src"], "tgt": p["tgt"]} for p in batch]
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})<-[r]-(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
forward_results = await self._query(forward_query) forward_cypher = """
backward_results = await self._query(backward_query) UNWIND $pairs AS p
WITH p.src AS src_eid, p.tgt AS tgt_eid
MATCH (a:base {entity_id: src_eid})
MATCH (b:base {entity_id: tgt_eid})
MATCH (a)-[r]->(b)
RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
backward_cypher = """
UNWIND $pairs AS p
WITH p.src AS src_eid, p.tgt AS tgt_eid
MATCH (a:base {entity_id: src_eid})
MATCH (b:base {entity_id: tgt_eid})
MATCH (a)<-[r]-(b)
RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
edges_dict = {} def dollar_quote(s: str, tag_prefix="AGE"):
s = "" if s is None else str(s)
for i in itertools.count(1):
tag = f"{tag_prefix}{i}"
wrapper = f"${tag}$"
if wrapper not in s:
return f"{wrapper}{s}{wrapper}"
for result in forward_results: sql_fwd = f"""
if result["source"] and result["target"] and result["edge_properties"]: SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
edge_props = result["edge_properties"] {dollar_quote(forward_cypher)}::cstring,
$1::agtype)
AS (source text, target text, edge_properties agtype)
"""
# Process string result, parse it to JSON dictionary sql_bwd = f"""
if isinstance(edge_props, str): SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
try: {dollar_quote(backward_cypher)}::cstring,
edge_props = json.loads(edge_props) $1::agtype)
except json.JSONDecodeError: AS (source text, target text, edge_properties agtype)
logger.warning( """
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
edges_dict[(result["source"], result["target"])] = edge_props pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)}
for result in backward_results: forward_results = await self._query(sql_fwd, params=pg_params)
if result["source"] and result["target"] and result["edge_properties"]: backward_results = await self._query(sql_bwd, params=pg_params)
edge_props = result["edge_properties"]
# Process string result, parse it to JSON dictionary for result in forward_results:
if isinstance(edge_props, str): if result["source"] and result["target"] and result["edge_properties"]:
try: edge_props = result["edge_properties"]
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
)
continue
edges_dict[(result["source"], result["target"])] = edge_props # Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
)
continue
edges_dict[(result["source"], result["target"])] = edge_props
for result in backward_results:
if result["source"] and result["target"] and result["edge_properties"]:
edge_props = result["edge_properties"]
# Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge properties string: {edge_props}"
)
continue
edges_dict[(result["source"], result["target"])] = edge_props
return edges_dict return edges_dict
async def get_nodes_edges_batch( async def get_nodes_edges_batch(
self, node_ids: list[str] self, node_ids: list[str], batch_size: int = 500
) -> dict[str, list[tuple[str, str]]]: ) -> dict[str, list[tuple[str, str]]]:
""" """
Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation. Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
Args: Args:
node_ids: List of node IDs to get edges for node_ids: List of node IDs to get edges for
batch_size: Batch size for the query
Returns: Returns:
Dictionary mapping node IDs to lists of (source, target) edge tuples Dictionary mapping node IDs to lists of (source, target) edge tuples
@ -3624,49 +3702,62 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids: if not node_ids:
return {} return {}
# Format node IDs for the query seen = set()
formatted_ids = ", ".join( unique_ids: list[str] = []
['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] for nid in node_ids:
) n = self._normalize_node_id(nid)
if n and n not in seen:
seen.add(n)
unique_ids.append(n)
outgoing_query = """SELECT * FROM cypher('%s', $$ edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)-[]->(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
incoming_query = """SELECT * FROM cypher('%s', $$ for i in range(0, len(unique_ids), batch_size):
UNWIND [%s] AS node_id batch = unique_ids[i : i + batch_size]
MATCH (n:base {entity_id: node_id}) # Format node IDs for the query
OPTIONAL MATCH (n:base)<-[]-(connected:base) formatted_ids = ", ".join([f'"{n}"' for n in batch])
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query) outgoing_query = """SELECT * FROM cypher('%s', $$
incoming_results = await self._query(incoming_query) UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)-[]->(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
nodes_edges_dict = {node_id: [] for node_id in node_ids} incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)<-[]-(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
for result in outgoing_results: outgoing_results = await self._query(outgoing_query)
if result["node_id"] and result["connected_id"]: incoming_results = await self._query(incoming_query)
nodes_edges_dict[result["node_id"]].append(
(result["node_id"], result["connected_id"])
)
for result in incoming_results: for result in outgoing_results:
if result["node_id"] and result["connected_id"]: if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append( edges_norm[result["node_id"]].append(
(result["connected_id"], result["node_id"]) (result["node_id"], result["connected_id"])
) )
return nodes_edges_dict for result in incoming_results:
if result["node_id"] and result["connected_id"]:
edges_norm[result["node_id"]].append(
(result["connected_id"], result["node_id"])
)
out: dict[str, list[tuple[str, str]]] = {}
for orig in node_ids:
n = self._normalize_node_id(orig)
out[orig] = edges_norm.get(n, [])
return out
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """
@ -4491,50 +4582,86 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time
""", """,
"relationships": """ "relationships": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (SELECT id as chunk_id
SELECT id as chunk_id FROM LIGHTRAG_VDB_CHUNKS
FROM LIGHTRAG_VDB_CHUNKS WHERE $2
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
) )
SELECT r.source_id as src_id, r.target_id as tgt_id, , rc AS (
EXTRACT(EPOCH FROM r.create_time)::BIGINT as created_at SELECT array_agg(chunk_id) AS chunk_arr
FROM LIGHTRAG_VDB_RELATION r FROM relevant_chunks
JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids) ), cand AS (
WHERE r.workspace = $1 SELECT
AND r.content_vector <=> '[{embedding_string}]'::vector < $3 r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector FROM LIGHTRAG_VDB_RELATION r
LIMIT $4 WHERE r.workspace = $1
""", ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50)
)
SELECT c.src_id,
c.tgt_id,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""",
"entities": """ "entities": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (SELECT id as chunk_id
SELECT id as chunk_id FROM LIGHTRAG_VDB_CHUNKS
FROM LIGHTRAG_VDB_CHUNKS WHERE $2
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
) )
SELECT e.entity_name, , rc AS (
EXTRACT(EPOCH FROM e.create_time)::BIGINT as created_at SELECT array_agg(chunk_id) AS chunk_arr
FROM LIGHTRAG_VDB_ENTITY e FROM relevant_chunks
JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids) ), cand AS (
WHERE e.workspace = $1 SELECT
AND e.content_vector <=> '[{embedding_string}]'::vector < $3 e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector FROM LIGHTRAG_VDB_ENTITY e
LIMIT $4 WHERE e.workspace = $1
""", ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50)
)
SELECT c.entity_name,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY c.dist, c.id
LIMIT $4;
""",
"chunks": """ "chunks": """
WITH relevant_chunks AS ( WITH relevant_chunks AS (SELECT id as chunk_id
SELECT id as chunk_id FROM LIGHTRAG_VDB_CHUNKS
FROM LIGHTRAG_VDB_CHUNKS WHERE $2
WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
) )
SELECT id, content, file_path, , rc AS (
EXTRACT(EPOCH FROM create_time)::BIGINT as created_at SELECT array_agg(chunk_id) AS chunk_arr
FROM LIGHTRAG_VDB_CHUNKS FROM relevant_chunks
WHERE workspace = $1 ), cand AS (
AND id IN (SELECT chunk_id FROM relevant_chunks) SELECT
AND content_vector <=> '[{embedding_string}]'::vector < $3 id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
ORDER BY content_vector <=> '[{embedding_string}]'::vector FROM LIGHTRAG_VDB_CHUNKS
LIMIT $4 WHERE workspace = $1
""", ORDER BY content_vector <=> '[{embedding_string}]'::vector
LIMIT ($4 * 50)
)
SELECT c.id,
c.content,
c.file_path,
EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
FROM cand c
JOIN rc ON TRUE
WHERE c.dist < $3
AND c.id = ANY (rc.chunk_arr)
ORDER BY c.dist, c.id
LIMIT $4;
""",
# DROP tables # DROP tables
"drop_specifiy_table_workspace": """ "drop_specifiy_table_workspace": """
DELETE FROM {table_name} WHERE workspace=$1 DELETE FROM {table_name} WHERE workspace=$1