feat(postgres): add bulk operations and health check

- Implement bulk upsert_nodes/edges via UNWIND to reduce round trips
- Add health_check for graph connectivity and AGE catalog status
- Switch to parameterized queries to prevent Cypher injection
- Fix node ID sanitization: strip control chars, escape quotes
clssck 2025-12-03 18:19:26 +00:00
parent c5f230a30c
commit a6b87df758
7 changed files with 388 additions and 172 deletions
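The recurring pattern in the diff below: only the graph name is interpolated into cypher(), all values travel in a single JSON payload bound as $1::agtype, and batches are expanded server-side with UNWIND. A minimal standalone sketch of that pattern, assuming an AGE-enabled Postgres reachable via asyncpg with a pre-created graph (the connection string, graph name, and sample nodes are illustrative and not taken from this commit):

import asyncio
import json

import asyncpg  # assumed driver; LightRAG's own DB wrapper is not reproduced here

GRAPH = "demo_graph"  # hypothetical graph name; assumed to already exist

# Only the graph name is interpolated; every value travels inside the $1 agtype map.
BULK_NODE_UPSERT = f"""SELECT * FROM cypher('{GRAPH}', $$
    UNWIND $nodes AS n
    MERGE (v:base {{entity_id: n.entity_id}})
    SET v += n.properties
    RETURN count(v) AS upserted
$$, $1::agtype) AS (upserted bigint)"""


async def main() -> None:
    conn = await asyncpg.connect("postgresql://postgres:postgres@localhost/postgres")
    await conn.execute("LOAD 'age'")
    await conn.execute('SET search_path = ag_catalog, "$user", public')
    batch = [
        {"entity_id": "Alice", "properties": {"entity_id": "Alice", "entity_type": "person"}},
        {"entity_id": "Bob", "properties": {"entity_id": "Bob", "entity_type": "person"}},
    ]
    # One round trip per batch instead of one query per node.
    rows = await conn.fetch(BULK_NODE_UPSERT, json.dumps({"nodes": batch}))
    print(rows[0]["upserted"])
    await conn.close()


asyncio.run(main())

The same shape covers edges (UNWIND $edges ... MERGE (a)-[r:DIRECTED]-(b)), which is what upsert_edges_bulk below does.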

View file

@ -68,6 +68,11 @@ FROM python:3.12-slim
WORKDIR /app
# Add curl for runtime healthchecks and simple diagnostics
RUN apt-get update \
&& apt-get install -y --no-install-recommends curl \
&& rm -rf /var/lib/apt/lists/*
# Install uv for package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

View file

@ -54,8 +54,12 @@ services:
volumes:
- ./data/rag_storage_test:/app/data/rag_storage
- ./data/inputs_test:/app/data/inputs
- ./lightrag:/app/lightrag # Mount source for live reload
# Live reload: Use absolute host path for Docker-in-Docker compatibility (Coder)
- /var/lib/docker/volumes/coder-shared-projects-optimized/_data/LightRAG/lightrag:/app/lightrag
environment:
# Live reload: PYTHONPATH makes mounted /app/lightrag take precedence over site-packages
- PYTHONPATH=/app
# Server
- HOST=0.0.0.0
- PORT=9621
@ -120,7 +124,8 @@ services:
entrypoint: []
command:
- python
- /app/lightrag/api/run_with_gunicorn.py
- -m
- lightrag.api.run_with_gunicorn
- --workers
- "8"
- --llm-binding

View file

@ -1266,6 +1266,9 @@ def create_app(args):
else:
auth_mode = "enabled"
# Optional lightweight graph health probe using the unified health_check interface
graph_health = await rag.chunk_entity_relation_graph.health_check()
# Cleanup expired keyed locks and get status
keyed_lock_info = cleanup_keyed_lock()
@ -1319,6 +1322,7 @@ def create_app(args):
"api_version": api_version_display,
"webui_title": webui_title,
"webui_description": webui_description,
"graph": graph_health,
}
except Exception as e:
logger.error(f"Error getting health status: {str(e)}")

View file

@ -187,6 +187,14 @@ class StorageNameSpace(ABC):
async def index_done_callback(self) -> None:
"""Commit the storage operations after indexing"""
async def health_check(self, max_retries: int = 3) -> dict[str, Any]:
"""Check the health status of the storage
Returns:
dict[str, Any]: Health status dictionary with at least a 'status' field
"""
return {"status": "healthy"}
@abstractmethod
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
@ -527,6 +535,22 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = edges if edges is not None else []
return result
async def upsert_nodes_bulk(
self, nodes: list[tuple[str, dict[str, str]]], batch_size: int = 500
) -> None:
"""Default bulk helper; storage backends can override for batching."""
for node_id, node_data in nodes:
await self.upsert_node(node_id, node_data)
async def upsert_edges_bulk(
self,
edges: list[tuple[str, str, dict[str, str]]],
batch_size: int = 500,
) -> None:
"""Default bulk helper; storage backends can override for batching."""
for src, tgt, edge_data in edges:
await self.upsert_edge(src, tgt, edge_data)
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.

View file

@ -3389,32 +3389,31 @@ class PGDocStatusStorage(DocStatusStorage):
error_msg = EXCLUDED.error_msg,
created_at = EXCLUDED.created_at,
updated_at = EXCLUDED.updated_at"""
batch_data = []
for k, v in data.items():
# Remove timezone information, store utc time in db
created_at = parse_datetime(v.get("created_at"))
updated_at = parse_datetime(v.get("updated_at"))
# chunks_count, chunks_list, track_id, metadata, and error_msg are optional
await self.db.execute(
sql,
{
"workspace": self.workspace,
"id": k,
"content_summary": v["content_summary"],
"content_length": v["content_length"],
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
"status": v["status"],
"file_path": v["file_path"],
"chunks_list": json.dumps(v.get("chunks_list", [])),
"track_id": v.get("track_id"), # Add track_id support
"metadata": json.dumps(
v.get("metadata", {})
), # Add metadata support
"error_msg": v.get("error_msg"), # Add error_msg support
"created_at": created_at, # Use the converted datetime object
"updated_at": updated_at, # Use the converted datetime object
},
)
batch_data.append((
self.workspace,
k,
v["content_summary"],
v["content_length"],
v["chunks_count"] if "chunks_count" in v else -1,
v["status"],
v["file_path"],
json.dumps(v.get("chunks_list", [])),
v.get("track_id"),
json.dumps(v.get("metadata", {})),
v.get("error_msg"),
created_at,
updated_at
))
if batch_data:
await self.db.executemany(sql, batch_data)
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
@ -3487,19 +3486,23 @@ class PGGraphStorage(BaseGraphStorage):
@staticmethod
def _normalize_node_id(node_id: str) -> str:
"""
Normalize node ID to ensure special characters are properly handled in Cypher queries.
"""Best-effort sanitization for identifiers we interpolate into Cypher.
Args:
node_id: The original node ID
Returns:
Normalized node ID suitable for Cypher queries
This avoids common parse errors when the ID ends up inside a Cypher literal.
Control chars are stripped, quotes/backticks are escaped, and the result is
forced to ASCII, so non-ASCII characters are dropped (the mapping is lossy).
"""
# Escape backslashes
normalized_id = node_id
normalized_id = normalized_id.replace("\\", "\\\\")
normalized_id = normalized_id.replace('"', '\\"')
# Drop control characters that can break AGE parsing
normalized_id = re.sub(r"[\x00-\x1F]", "", node_id)
# Escape characters that matter for the interpolated Cypher literal
normalized_id = normalized_id.replace("\\", "\\\\") # backslash
normalized_id = normalized_id.replace('"', '\\"') # double quote
normalized_id = normalized_id.replace("`", "\\`") # backtick
# Force ASCII to avoid encoding surprises (non-ASCII characters are dropped)
normalized_id = normalized_id.encode("ascii", "ignore").decode("ascii")
return normalized_id
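# Illustrative effect of the sanitizer above (examples only):
#   _normalize_node_id('Ada "Countess" Lovelace\x07') -> 'Ada \\"Countess\\" Lovelace'
#   _normalize_node_id("café")                        -> 'caf'  (non-ASCII dropped)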
async def initialize(self):
@ -3549,6 +3552,7 @@ class PGGraphStorage(BaseGraphStorage):
f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
f'CREATE UNIQUE INDEX CONCURRENTLY {self.graph_name}_entity_id_unique ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
]
@ -3563,10 +3567,44 @@ class PGGraphStorage(BaseGraphStorage):
graph_name=self.graph_name,
)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def health_check(self, max_retries: int = 3) -> dict[str, Any]:
"""Check Postgres graph connectivity and status"""
graph_health = {"status": "unknown"}
try:
# Base connectivity
await self.db.execute("SELECT 1")
# AGE catalog available?
graphs = await self.db.query(
"SELECT name FROM ag_catalog.ag_graph",
multirows=True,
with_age=False,
)
graph_exists = False
if graphs:
names = [g.get("name") for g in graphs if "name" in g]
if self.graph_name in names:
graph_exists = True
# Basic Cypher query test on this workspace's graph
if graph_exists:
await self.db.query(
f"SELECT * FROM cypher('{self.graph_name}', $$ RETURN 1 $$) AS (one agtype)",
with_age=True,
graph_name=self.graph_name,
)
graph_health = {
"status": "healthy",
"graphs": graphs if graphs is not None else [],
"workspace_graph": self.graph_name if graph_exists else None,
}
except Exception as exc:
graph_health = {"status": "unhealthy", "detail": str(exc)}
logger.debug(f"Graph health check failed: {exc}")
return graph_health
async def index_done_callback(self) -> None:
# PG handles persistence automatically
@ -3685,31 +3723,9 @@ class PGGraphStorage(BaseGraphStorage):
return d
@staticmethod
def _format_properties(
properties: dict[str, Any], _id: Union[str, None] = None
) -> str:
"""
Convert a dictionary of properties to a string representation that
can be used in a cypher query insert/merge statement.
Args:
properties (dict[str,str]): a dictionary containing node/edge properties
_id (Union[str, None]): the id of the node or None if none exists
Returns:
str: the properties dictionary as a properly formatted string
"""
props = []
# wrap property key in backticks to escape
for k, v in properties.items():
prop = f"`{k}`: {json.dumps(v)}"
props.append(prop)
if _id is not None and "id" not in properties:
props.append(
f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
)
return "{" + ", ".join(props) + "}"
def _format_properties(self, data: dict[str, Any]) -> str:
# Deprecated: Use parameterized queries instead
return ""
async def _query(
self,
@ -3717,6 +3733,7 @@ class PGGraphStorage(BaseGraphStorage):
readonly: bool = True,
upsert: bool = False,
params: dict[str, Any] | None = None,
timeout: float | None = None,
) -> list[dict[str, Any]]:
"""
Query the graph by taking a cypher query, converting it to an
@ -3740,6 +3757,7 @@ class PGGraphStorage(BaseGraphStorage):
else:
data = await self.db.execute(
query,
data=params,
upsert=upsert,
with_age=True,
graph_name=self.graph_name,
@ -3853,16 +3871,13 @@ class PGGraphStorage(BaseGraphStorage):
"""
label = self._normalize_node_id(source_node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: $node_id}})
OPTIONAL MATCH (n)-[]-(connected:base)
RETURN n.entity_id AS source_id, connected.entity_id AS connected_id
$$) AS (source_id text, connected_id text)""" % (
self.graph_name,
label,
)
$$, $1::agtype) AS (source_id text, connected_id text)"""
results = await self._query(query)
results = await self._query(
query, params={"params": json.dumps({"node_id": label}, ensure_ascii=False)}
)
edges = []
for record in results:
source_id = record["source_id"]
@ -3892,20 +3907,24 @@ class PGGraphStorage(BaseGraphStorage):
)
label = self._normalize_node_id(node_id)
properties = self._format_properties(node_data)
# Use parameterized query to prevent injection
cy_params = {
"params": json.dumps(
{"node_id": label, "properties": node_data}, ensure_ascii=False
)
}
query = """SELECT * FROM cypher('%s', $$
MERGE (n:base {entity_id: "%s"})
SET n += %s
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MERGE (n:base {{entity_id: $node_id}})
SET n += $properties
RETURN n
$$) AS (n agtype)""" % (
self.graph_name,
label,
properties,
)
$$, $1::agtype) AS (n agtype)"""
try:
await self._query(query, readonly=False, upsert=True)
await self._query(
query, readonly=False, upsert=True, params=cy_params
)
except Exception:
logger.error(
@ -3931,26 +3950,32 @@ class PGGraphStorage(BaseGraphStorage):
"""
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
edge_properties = self._format_properties(edge_data)
# Use parameterized query for security
cy_params = {
"params": json.dumps(
{
"src_id": src_label,
"tgt_id": tgt_label,
"properties": edge_data
},
ensure_ascii=False
)
}
query = """SELECT * FROM cypher('%s', $$
MATCH (source:base {entity_id: "%s"})
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (source:base {{entity_id: $src_id}})
WITH source
MATCH (target:base {entity_id: "%s"})
MATCH (target:base {{entity_id: $tgt_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += %s
SET r += %s
SET r += $properties
RETURN r
$$) AS (r agtype)""" % (
self.graph_name,
src_label,
tgt_label,
edge_properties,
edge_properties, # https://github.com/HKUDS/LightRAG/issues/1438#issuecomment-2826000195
)
$$, $1::agtype) AS (r agtype)"""
try:
await self._query(query, readonly=False, upsert=True)
await self._query(
query, readonly=False, upsert=True, params=cy_params
)
except Exception:
logger.error(
@ -3958,6 +3983,99 @@ class PGGraphStorage(BaseGraphStorage):
)
raise
async def upsert_nodes_bulk(
self, nodes: list[tuple[str, dict[str, str]]], batch_size: int = 500
) -> None:
"""Bulk upsert nodes using UNWIND for fewer round-trips."""
if not nodes:
return
if batch_size <= 0 or batch_size > 5000:
raise ValueError("batch_size must be between 1 and 5000 for bulk node upserts")
cleaned = []
for node_id, data in nodes:
label = self._normalize_node_id(node_id)
if "entity_id" not in data:
raise ValueError("PostgreSQL: node properties must contain an 'entity_id' field")
node_props = dict(data)
node_props["entity_id"] = label
cleaned.append({"entity_id": label, "properties": node_props})
for i in range(0, len(cleaned), batch_size):
batch = cleaned[i : i + batch_size]
cy_params = {"params": json.dumps({"nodes": batch}, ensure_ascii=False)}
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND $nodes AS n
MERGE (v:base {{entity_id: n.entity_id}})
SET v += n.properties
RETURN count(v) AS upserted
$$, $1::agtype) AS (upserted bigint)"""
try:
await self._query(
query, readonly=False, upsert=True, params=cy_params, timeout=DEFAULT_LLM_TIMEOUT
)
except Exception:
logger.error(
f"[{self.workspace}] POSTGRES, bulk upsert_node failed on batch starting with `{batch[0]['entity_id']}`"
)
raise
async def upsert_edges_bulk(
self,
edges: list[tuple[str, str, dict[str, str]]],
batch_size: int = 500,
) -> None:
"""Bulk upsert edges using UNWIND; keeps current undirected semantics."""
if not edges:
return
if batch_size <= 0 or batch_size > 5000:
raise ValueError("batch_size must be between 1 and 5000 for bulk edge upserts")
cleaned = []
for src, tgt, props in edges:
src_label = self._normalize_node_id(src)
tgt_label = self._normalize_node_id(tgt)
cleaned.append(
{
"src": src_label,
"tgt": tgt_label,
"properties": dict(props),
}
)
for i in range(0, len(cleaned), batch_size):
batch = cleaned[i : i + batch_size]
cy_params = {"params": json.dumps({"edges": batch}, ensure_ascii=False)}
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND $edges AS e
MATCH (source:base {{entity_id: e.src}})
WITH source, e
MATCH (target:base {{entity_id: e.tgt}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += e.properties
SET r += e.properties
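# Duplicate SET kept to mirror the earlier workaround, see https://github.com/HKUDS/LightRAG/issues/1438#issuecomment-2826000195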
RETURN count(r) AS upserted
$$, $1::agtype) AS (upserted bigint)"""
try:
await self._query(
query,
readonly=False,
upsert=True,
params=cy_params,
timeout=DEFAULT_LLM_TIMEOUT,
)
except Exception:
logger.error(
f"[{self.workspace}] POSTGRES, bulk upsert_edge failed on batch starting with `{batch[0]['src']}`-`{batch[0]['tgt']}`"
)
raise
async def delete_node(self, node_id: str) -> None:
"""
Delete a node from the graph.
@ -3967,13 +4085,13 @@ class PGGraphStorage(BaseGraphStorage):
"""
label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: $node_id}})
DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, label)
$$, $1::agtype) AS (n agtype)"""
try:
await self._query(query, readonly=False)
await self._query(
query,
readonly=False,
params={"params": json.dumps({"node_id": label}, ensure_ascii=False)},
)
except Exception as e:
logger.error(f"[{self.workspace}] Error during node deletion: {e}")
raise
@ -3986,16 +4104,21 @@ class PGGraphStorage(BaseGraphStorage):
node_ids (list[str]): A list of node IDs to remove.
"""
node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
if not node_ids:
return
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base)
WHERE n.entity_id IN [%s]
unique_ids = list(dict.fromkeys(node_ids))
cy_params = {"params": json.dumps({"node_ids": unique_ids}, ensure_ascii=False)}
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND $node_ids AS node_id
MATCH (n:base {{entity_id: node_id}})
DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
$$, $1::agtype) AS (n agtype)"""
try:
await self._query(query, readonly=False)
await self._query(query, readonly=False, params=cy_params)
logger.debug(f"[{self.workspace}] Removed {len(unique_ids)} nodes from graph")
except Exception as e:
logger.error(f"[{self.workspace}] Error during node removal: {e}")
raise
@ -4007,23 +4130,38 @@ class PGGraphStorage(BaseGraphStorage):
Args:
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
"""
if not edges:
return
cleaned_edges: list[tuple[str, str]] = []
seen: set[tuple[str, str]] = set()
for source, target in edges:
src_label = self._normalize_node_id(source)
tgt_label = self._normalize_node_id(target)
pair = (self._normalize_node_id(source), self._normalize_node_id(target))
if pair not in seen:
seen.add(pair)
cleaned_edges.append(pair)
query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
if not cleaned_edges:
return
literal_pairs = ", ".join(
[f"{{src: {json.dumps(src)}, tgt: {json.dumps(tgt)}}}" for src, tgt in cleaned_edges]
)
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND $pairs AS pair
MATCH (a:base {{entity_id: pair.src}})-[r]-(b:base {{entity_id: pair.tgt}})
DELETE r
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
$$) AS (r agtype)"""
try:
await self._query(query, readonly=False)
logger.debug(
f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
)
except Exception as e:
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise
try:
await self._query(query, readonly=False, params=cy_params)
logger.debug(
f"[{self.workspace}] Deleted {len(cleaned_edges)} edges (undirected)"
)
except Exception as e:
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise
async def get_nodes_batch(
self, node_ids: list[str], batch_size: int = 1000
@ -4311,8 +4449,10 @@ class PGGraphStorage(BaseGraphStorage):
pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)}
forward_results = await self._query(sql_fwd, params=pg_params)
backward_results = await self._query(sql_bwd, params=pg_params)
forward_results, backward_results = await asyncio.gather(
self._query(sql_fwd, params=pg_params),
self._query(sql_bwd, params=pg_params)
)
for result in forward_results:
if result["source"] and result["target"] and result["edge_properties"]:
@ -4376,31 +4516,26 @@ class PGGraphStorage(BaseGraphStorage):
for i in range(0, len(unique_ids), batch_size):
batch = unique_ids[i : i + batch_size]
# Format node IDs for the query
formatted_ids = ", ".join([f'"{n}"' for n in batch])
cy_params = {"params": json.dumps({"node_ids": batch}, ensure_ascii=False)}
outgoing_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND $node_ids AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n:base)-[]->(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
$$, $1::agtype) AS (node_id text, connected_id text)"""
incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND $node_ids AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n:base)<-[]-(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
$$, $1::agtype) AS (node_id text, connected_id text)"""
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
outgoing_results, incoming_results = await asyncio.gather(
self._query(outgoing_query, params=cy_params),
self._query(incoming_query, params=cy_params)
)
for result in outgoing_results:
if result["node_id"] and result["connected_id"]:
@ -4428,15 +4563,14 @@ class PGGraphStorage(BaseGraphStorage):
Returns:
list[str]: A list of all labels in the graph.
"""
query = (
"""SELECT * FROM cypher('%s', $$
MATCH (n:base)
WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY n.entity_id
$$) AS (label text)"""
% self.graph_name
)
# Use native SQL for better performance
query = f"""
SELECT DISTINCT
trim(both '"' from (ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text) AS label
FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]) IS NOT NULL
ORDER BY label
"""
results = await self._query(query)
labels = []
@ -4471,12 +4605,12 @@ class PGGraphStorage(BaseGraphStorage):
# Get starting node data
label = self._normalize_node_id(node_label)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: $node_id}})
RETURN id(n) as node_id, n
$$) AS (node_id bigint, n agtype)""" % (self.graph_name, label)
$$, $1::agtype) AS (node_id bigint, n agtype)"""
node_result = await self._query(query)
node_result = await self._query(
query, params={"params": json.dumps({"node_id": label}, ensure_ascii=False)}
)
if not node_result or not node_result[0].get("n"):
return result
@ -4525,14 +4659,12 @@ class PGGraphStorage(BaseGraphStorage):
continue
# Prepare node IDs list
node_ids = [node.labels[0] for node in current_level_nodes]
formatted_ids = ", ".join(
[f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
)
node_ids = [self._normalize_node_id(node.labels[0]) for node in current_level_nodes]
cy_params = {"params": json.dumps({"node_ids": node_ids}, ensure_ascii=False)}
# Construct batch query for outgoing edges
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id
UNWIND $node_ids AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)-[r]->(neighbor:base)
RETURN node_id AS current_id,
@ -4543,12 +4675,12 @@ class PGGraphStorage(BaseGraphStorage):
r,
neighbor,
true AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
$$, $1::agtype) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Construct batch query for incoming edges
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id
UNWIND $node_ids AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
RETURN node_id AS current_id,
@ -4559,12 +4691,14 @@ class PGGraphStorage(BaseGraphStorage):
r,
neighbor,
false AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
$$, $1::agtype) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Execute queries
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
# Execute queries concurrently
outgoing_results, incoming_results = await asyncio.gather(
self._query(outgoing_query, params=cy_params),
self._query(incoming_query, params=cy_params)
)
# Combine results
neighbors = outgoing_results + incoming_results
@ -4654,17 +4788,15 @@ class PGGraphStorage(BaseGraphStorage):
# Add db_degree to all nodes via bulk query
if result.nodes:
entity_ids = [node.labels[0] for node in result.nodes]
formatted_ids = ", ".join(
[f'"{self._normalize_node_id(eid)}"' for eid in entity_ids]
)
entity_ids = [self._normalize_node_id(node.labels[0]) for node in result.nodes]
degree_params = {"params": json.dumps({"node_ids": entity_ids}, ensure_ascii=False)}
degree_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS entity_id
UNWIND $node_ids AS entity_id
MATCH (n:base {{entity_id: entity_id}})
OPTIONAL MATCH (n)-[r]-()
RETURN entity_id, count(r) as degree
$$) AS (entity_id text, degree bigint)"""
degree_results = await self._query(degree_query)
$$, $1::agtype) AS (entity_id text, degree bigint)"""
degree_results = await self._query(degree_query, params=degree_params)
degree_map = {
row["entity_id"]: int(row["degree"]) for row in degree_results
}
@ -4753,17 +4885,17 @@ class PGGraphStorage(BaseGraphStorage):
)
if node_ids:
formatted_ids = ", ".join(node_ids)
cy_params = {"params": json.dumps({"node_ids": [int(n) for n in node_ids]}, ensure_ascii=False)}
# Construct batch query for subgraph within max_nodes
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{formatted_ids}] AS node_ids
WITH $node_ids AS node_ids
MATCH (a)
WHERE id(a) IN node_ids
OPTIONAL MATCH (a)-[r]->(b)
WHERE id(b) IN node_ids
RETURN a, r, b
$$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
results = await self._query(query)
$$, $1::agtype) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
results = await self._query(query, params=cy_params)
# Process query results, deduplicate nodes and edges
nodes_dict = {}

View file

@ -629,6 +629,9 @@ async def _handle_single_relationship_extraction(
)
edge_keywords = edge_keywords.replace("，", ",")
# Derive a relationship label from the first keyword (fallback handled again at merge time)
relationship_label = edge_keywords.split(",")[0].strip() if edge_keywords else ""
# Process relationship description with same cleaning pipeline
edge_description = sanitize_and_normalize_extracted_text(record_attributes[4])
@ -645,6 +648,8 @@ async def _handle_single_relationship_extraction(
weight=weight,
description=edge_description,
keywords=edge_keywords,
relationship=relationship_label,
type=relationship_label,
source_id=edge_source_id,
file_path=file_path,
timestamp=timestamp,
@ -2337,6 +2342,8 @@ async def _merge_edges_then_upsert(
already_source_ids = []
already_description = []
already_keywords = []
already_relationships = []
already_types = []
already_file_paths = []
# 1. Get existing edge data from graph storage
@ -2373,6 +2380,20 @@ async def _merge_edges_then_upsert(
)
)
if already_edge.get("relationship") is not None:
already_relationships.extend(
split_string_by_multi_markers(
already_edge["relationship"], [GRAPH_FIELD_SEP, ","]
)
)
if already_edge.get("type") is not None:
already_types.extend(
split_string_by_multi_markers(
already_edge["type"], [GRAPH_FIELD_SEP, ","]
)
)
new_source_ids = [dp["source_id"] for dp in edges_data if dp.get("source_id")]
storage_key = make_relation_chunk_key(src_id, tgt_id)
@ -2474,6 +2495,23 @@ async def _merge_edges_then_upsert(
# Join all unique keywords with commas
keywords = ",".join(sorted(all_keywords))
# 6.3 Finalize relationship/type labels from the explicit fields, falling back to keywords
rel_labels = set()
for edge in edges_data:
if edge.get("relationship"):
rel_labels.update(k.strip() for k in edge["relationship"].split(",") if k.strip())
rel_labels.update(k.strip() for k in already_relationships if k.strip())
relationship_label = ",".join(sorted(rel_labels))
if not relationship_label and keywords:
relationship_label = keywords.split(",")[0]
type_labels = set()
for edge in edges_data:
if edge.get("type"):
type_labels.update(k.strip() for k in edge["type"].split(",") if k.strip())
type_labels.update(k.strip() for k in already_types if k.strip())
type_label = ",".join(sorted(type_labels)) or relationship_label
# 7. Deduplicate by description, keeping first occurrence in the same document
unique_edges = {}
for dp in edges_data:
@ -2785,6 +2823,8 @@ async def _merge_edges_then_upsert(
weight=weight,
description=description,
keywords=keywords,
relationship=relationship_label,
type=type_label,
source_id=source_id,
file_path=file_path,
created_at=edge_created_at,
@ -2797,6 +2837,8 @@ async def _merge_edges_then_upsert(
tgt_id=tgt_id,
description=description,
keywords=keywords,
relationship=relationship_label,
type=type_label,
source_id=source_id,
file_path=file_path,
created_at=edge_created_at,
@ -2822,6 +2864,8 @@ async def _merge_edges_then_upsert(
rel_vdb_id: {
"src_id": src_id,
"tgt_id": tgt_id,
"relationship": relationship_label,
"type": type_label,
"source_id": source_id,
"content": rel_content,
"keywords": keywords,

View file

@ -1332,6 +1332,8 @@ async def _merge_entities_impl(
{
"description": "concatenate",
"keywords": "join_unique_comma",
"relationship": "join_unique_comma",
"type": "join_unique_comma",
"source_id": "join_unique",
"file_path": "join_unique",
"weight": "max",