feat(postgres): add bulk operations and health check
- Implement bulk upsert_nodes/edges via UNWIND reducing round trips
- Add health_check for graph connectivity and AGE catalog status
- Switch to parameterized queries preventing Cypher injection
- Fix node ID sanitization: strip control chars, escape quotes
This commit is contained in:
parent c5f230a30c
commit a6b87df758
7 changed files with 388 additions and 172 deletions
@@ -68,6 +68,11 @@ FROM python:3.12-slim
WORKDIR /app

# Add curl for runtime healthchecks and simple diagnostics
RUN apt-get update \
    && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*

# Install uv for package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
@@ -54,8 +54,12 @@ services:
    volumes:
      - ./data/rag_storage_test:/app/data/rag_storage
      - ./data/inputs_test:/app/data/inputs
      - ./lightrag:/app/lightrag  # Mount source for live reload
      # Live reload: Use absolute host path for Docker-in-Docker compatibility (Coder)
      - /var/lib/docker/volumes/coder-shared-projects-optimized/_data/LightRAG/lightrag:/app/lightrag
    environment:
      # Live reload: PYTHONPATH makes mounted /app/lightrag take precedence over site-packages
      - PYTHONPATH=/app

      # Server
      - HOST=0.0.0.0
      - PORT=9621
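To check that the mounted source tree is what Python actually imports (rather than the site-packages copy), a minimal hedged verification from inside the container could look like this; the service invocation is an assumption and should be adjusted to the local compose setup.

# check_import.py - run inside the container, e.g. `docker compose exec <service> python check_import.py`
import lightrag

print(lightrag.__file__)
# Expected to start with /app/lightrag/ when PYTHONPATH=/app is in effect;
# a site-packages path would mean the live-reload mount is not taking precedence.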
@@ -120,7 +124,8 @@ services:
    entrypoint: []
    command:
      - python
      - /app/lightrag/api/run_with_gunicorn.py
      - -m
      - lightrag.api.run_with_gunicorn
      - --workers
      - "8"
      - --llm-binding
@@ -1266,6 +1266,9 @@ def create_app(args):
        else:
            auth_mode = "enabled"

        # Optional graph health probe (lightweight) - Using unified health_check interface
        graph_health = await rag.chunk_entity_relation_graph.health_check()

        # Cleanup expired keyed locks and get status
        keyed_lock_info = cleanup_keyed_lock()
@@ -1319,6 +1322,7 @@ def create_app(args):
                "api_version": api_version_display,
                "webui_title": webui_title,
                "webui_description": webui_description,
                "graph": graph_health,
            }
        except Exception as e:
            logger.error(f"Error getting health status: {str(e)}")
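For reference, a hedged sketch of the shape the health payload takes with the new "graph" field; only fields touched by this diff are shown and all values are illustrative.

# Illustrative only: shape of the /health payload after this change.
example_health_payload = {
    "webui_title": "LightRAG",            # illustrative value
    "webui_description": "...",           # illustrative value
    "graph": {                            # result of chunk_entity_relation_graph.health_check()
        "status": "healthy",
        "graphs": [{"name": "some_graph"}],
        "workspace_graph": "some_graph",
    },
}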
@@ -187,6 +187,14 @@ class StorageNameSpace(ABC):
    async def index_done_callback(self) -> None:
        """Commit the storage operations after indexing"""

    async def health_check(self, max_retries: int = 3) -> dict[str, Any]:
        """Check the health status of the storage

        Returns:
            dict[str, Any]: Health status dictionary with at least 'status' field
        """
        return {"status": "healthy"}

    @abstractmethod
    async def drop(self) -> dict[str, str]:
        """Drop all data from storage and clean up resources
@@ -527,6 +535,22 @@ class BaseGraphStorage(StorageNameSpace, ABC):
            result[node_id] = edges if edges is not None else []
        return result

    async def upsert_nodes_bulk(
        self, nodes: list[tuple[str, dict[str, str]]], batch_size: int = 500
    ) -> None:
        """Default bulk helper; storage backends can override for batching."""
        for node_id, node_data in nodes:
            await self.upsert_node(node_id, node_data)

    async def upsert_edges_bulk(
        self,
        edges: list[tuple[str, str, dict[str, str]]],
        batch_size: int = 500,
    ) -> None:
        """Default bulk helper; storage backends can override for batching."""
        for src, tgt, edge_data in edges:
            await self.upsert_edge(src, tgt, edge_data)

    @abstractmethod
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """Insert a new node or update an existing node in the graph.
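As a rough illustration of the new interface contract: every BaseGraphStorage subclass now gets the bulk helpers for free, with the defaults simply looping over upsert_node/upsert_edge (batch_size is accepted but unused in the fallback). The sketch below uses a hypothetical in-memory backend, not a class from this repository.

# Hypothetical backend relying on the default bulk helper shape added above.
import asyncio


class InMemoryGraph:
    def __init__(self):
        self.nodes: dict[str, dict] = {}

    async def upsert_node(self, node_id: str, node_data: dict) -> None:
        self.nodes.setdefault(node_id, {}).update(node_data)

    # Same signature as the default helper in this diff; the naive fallback just loops,
    # so overriding only pays off for backends that can batch natively (e.g. Postgres/AGE).
    async def upsert_nodes_bulk(self, nodes, batch_size: int = 500) -> None:
        for node_id, node_data in nodes:
            await self.upsert_node(node_id, node_data)


async def main():
    g = InMemoryGraph()
    await g.upsert_nodes_bulk([("a", {"entity_id": "a"}), ("b", {"entity_id": "b"})])
    print(g.nodes)


asyncio.run(main())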
@@ -3389,32 +3389,31 @@ class PGDocStatusStorage(DocStatusStorage):
                error_msg = EXCLUDED.error_msg,
                created_at = EXCLUDED.created_at,
                updated_at = EXCLUDED.updated_at"""

        batch_data = []
        for k, v in data.items():
            # Remove timezone information, store utc time in db
            created_at = parse_datetime(v.get("created_at"))
            updated_at = parse_datetime(v.get("updated_at"))

            # chunks_count, chunks_list, track_id, metadata, and error_msg are optional
            await self.db.execute(
                sql,
                {
                    "workspace": self.workspace,
                    "id": k,
                    "content_summary": v["content_summary"],
                    "content_length": v["content_length"],
                    "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
                    "status": v["status"],
                    "file_path": v["file_path"],
                    "chunks_list": json.dumps(v.get("chunks_list", [])),
                    "track_id": v.get("track_id"),  # Add track_id support
                    "metadata": json.dumps(
                        v.get("metadata", {})
                    ),  # Add metadata support
                    "error_msg": v.get("error_msg"),  # Add error_msg support
                    "created_at": created_at,  # Use the converted datetime object
                    "updated_at": updated_at,  # Use the converted datetime object
                },
            )
            batch_data.append((
                self.workspace,
                k,
                v["content_summary"],
                v["content_length"],
                v["chunks_count"] if "chunks_count" in v else -1,
                v["status"],
                v["file_path"],
                json.dumps(v.get("chunks_list", [])),
                v.get("track_id"),
                json.dumps(v.get("metadata", {})),
                v.get("error_msg"),
                created_at,
                updated_at
            ))

        if batch_data:
            await self.db.executemany(sql, batch_data)

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
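The switch from per-row execute to a single executemany call is the standard batched-insert pattern; a minimal standalone sketch of the idea is below. It uses asyncpg directly rather than this repo's db wrapper, and the DSN and table name are assumptions for illustration only.

# Illustrative batching with asyncpg.executemany (not this repo's PostgreSQLDB wrapper).
# Assumes asyncpg is installed and a Postgres instance is reachable at the DSN shown.
import asyncio
import asyncpg


async def main():
    conn = await asyncpg.connect("postgresql://user:pass@localhost/db")  # assumed DSN
    rows = [("doc-1", 3), ("doc-2", 7)]
    # One batched call instead of one execute() per row.
    await conn.executemany(
        "INSERT INTO doc_status_demo (id, chunks_count) VALUES ($1, $2) "
        "ON CONFLICT (id) DO UPDATE SET chunks_count = EXCLUDED.chunks_count",
        rows,
    )
    await conn.close()


asyncio.run(main())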
@@ -3487,19 +3486,23 @@ class PGGraphStorage(BaseGraphStorage):

    @staticmethod
    def _normalize_node_id(node_id: str) -> str:
        """
        Normalize node ID to ensure special characters are properly handled in Cypher queries.
        """Best-effort sanitization for identifiers we interpolate into Cypher.

        Args:
            node_id: The original node ID

        Returns:
            Normalized node ID suitable for Cypher queries
        This avoids common parse errors without altering the semantic value.
        Control chars are stripped, quotes/backticks are escaped, and we keep
        the result ASCII-only to match server expectations.
        """
        # Escape backslashes
        normalized_id = node_id
        normalized_id = normalized_id.replace("\\", "\\\\")
        normalized_id = normalized_id.replace('"', '\\"')

        # Drop control characters that can break AGE parsing
        normalized_id = re.sub(r"[\x00-\x1F]", "", node_id)

        # Escape characters that matter for the interpolated Cypher literal
        normalized_id = normalized_id.replace("\\", "\\\\")  # backslash
        normalized_id = normalized_id.replace('"', '\\"')  # double quote
        normalized_id = normalized_id.replace("`", "\\`")  # backtick

        # Keep it compact and ASCII to avoid encoding surprises
        normalized_id = normalized_id.encode("ascii", "ignore").decode("ascii")
        return normalized_id

    async def initialize(self):
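To make the new sanitization concrete, here is a standalone rendition of the same steps outside the class, with one example input; it is for illustration only and mirrors the method above.

# Standalone copy of the sanitization steps added above, for illustration.
import re


def normalize_node_id(node_id: str) -> str:
    normalized_id = re.sub(r"[\x00-\x1F]", "", node_id)   # strip control chars
    normalized_id = normalized_id.replace("\\", "\\\\")    # escape backslash
    normalized_id = normalized_id.replace('"', '\\"')      # escape double quote
    normalized_id = normalized_id.replace("`", "\\`")      # escape backtick
    return normalized_id.encode("ascii", "ignore").decode("ascii")


print(normalize_node_id('Acme "Corp"\n'))   # -> Acme \"Corp\"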
@@ -3549,6 +3552,7 @@ class PGGraphStorage(BaseGraphStorage):
                f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
                f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
                f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
                f'CREATE UNIQUE INDEX CONCURRENTLY {self.graph_name}_entity_id_unique ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
                f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
            ]
@@ -3563,10 +3567,44 @@ class PGGraphStorage(BaseGraphStorage):
                graph_name=self.graph_name,
            )

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    async def health_check(self, max_retries: int = 3) -> dict[str, Any]:
        """Check Postgres graph connectivity and status"""
        graph_health = {"status": "unknown"}
        try:
            # Base connectivity
            await self.db.execute("SELECT 1")

            # AGE catalog available?
            graphs = await self.db.query(
                "SELECT name FROM ag_catalog.ag_graph",
                multirows=True,
                with_age=False,
            )

            graph_exists = False
            if graphs:
                names = [g.get("name") for g in graphs if "name" in g]
                if self.graph_name in names:
                    graph_exists = True

            # Basic Cypher query test on this workspace's graph
            if graph_exists:
                await self.db.query(
                    f"SELECT * FROM cypher('{self.graph_name}', $$ RETURN 1 $$) AS (one agtype)",
                    with_age=True,
                    graph_name=self.graph_name,
                )

            graph_health = {
                "status": "healthy",
                "graphs": graphs if graphs is not None else [],
                "workspace_graph": self.graph_name if graph_exists else None,
            }
        except Exception as exc:
            graph_health = {"status": "unhealthy", "detail": str(exc)}
            logger.debug(f"Graph health check failed: {exc}")

        return graph_health

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
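Illustrative return values for the new health_check; the keys come from the code above, while the graph name and error text are invented and the contents of "graphs" depend on what ag_catalog.ag_graph reports.

# What a caller might see from PGGraphStorage.health_check(); values are illustrative.
healthy = {
    "status": "healthy",
    "graphs": [{"name": "chunk_entity_relation"}],   # rows from ag_catalog.ag_graph
    "workspace_graph": "chunk_entity_relation",      # set when this workspace's graph exists
}

unhealthy = {
    "status": "unhealthy",
    "detail": "connection refused",                  # stringified exception
}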
@@ -3685,31 +3723,9 @@ class PGGraphStorage(BaseGraphStorage):

        return d

    @staticmethod
    def _format_properties(
        properties: dict[str, Any], _id: Union[str, None] = None
    ) -> str:
        """
        Convert a dictionary of properties to a string representation that
        can be used in a cypher query insert/merge statement.

        Args:
            properties (dict[str,str]): a dictionary containing node/edge properties
            _id (Union[str, None]): the id of the node or None if none exists

        Returns:
            str: the properties dictionary as a properly formatted string
        """
        props = []
        # wrap property key in backticks to escape
        for k, v in properties.items():
            prop = f"`{k}`: {json.dumps(v)}"
            props.append(prop)
        if _id is not None and "id" not in properties:
            props.append(
                f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
            )
        return "{" + ", ".join(props) + "}"
    def _format_properties(self, data: dict[str, Any]) -> str:
        # Deprecated: Use parameterized queries instead
        return ""

    async def _query(
        self,
@@ -3717,6 +3733,7 @@ class PGGraphStorage(BaseGraphStorage):
        readonly: bool = True,
        upsert: bool = False,
        params: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> list[dict[str, Any]]:
        """
        Query the graph by taking a cypher query, converting it to an
@@ -3740,6 +3757,7 @@ class PGGraphStorage(BaseGraphStorage):
            else:
                data = await self.db.execute(
                    query,
                    data=params,
                    upsert=upsert,
                    with_age=True,
                    graph_name=self.graph_name,
@@ -3853,16 +3871,13 @@ class PGGraphStorage(BaseGraphStorage):
        """
        label = self._normalize_node_id(source_node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (n:base {{entity_id: $1}})
                     OPTIONAL MATCH (n)-[]-(connected:base)
                     RETURN n.entity_id AS source_id, connected.entity_id AS connected_id
                   $$) AS (source_id text, connected_id text)""" % (
            self.graph_name,
            label,
        )
                   $$, $1) AS (source_id text, connected_id text)"""

        results = await self._query(query)
        results = await self._query(query, params={"node_id": label})
        edges = []
        for record in results:
            source_id = record["source_id"]
@@ -3892,20 +3907,24 @@ class PGGraphStorage(BaseGraphStorage):
            )

        label = self._normalize_node_id(node_id)
        properties = self._format_properties(node_data)

        # Use parameterized query to prevent injection
        cy_params = {
            "params": json.dumps(
                {"node_id": label, "properties": node_data}, ensure_ascii=False
            )
        }

        query = """SELECT * FROM cypher('%s', $$
                     MERGE (n:base {entity_id: "%s"})
                     SET n += %s
        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MERGE (n:base {{entity_id: $node_id}})
                     SET n += $properties
                     RETURN n
                   $$) AS (n agtype)""" % (
            self.graph_name,
            label,
            properties,
        )
                   $$, $1::agtype) AS (n agtype)"""

        try:
            await self._query(query, readonly=False, upsert=True)
            await self._query(
                query, readonly=False, upsert=True, params=cy_params
            )

        except Exception:
            logger.error(
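A small sketch of what the single agtype parameter ends up carrying: the whole payload is JSON-encoded once and bound as $1::agtype, so quotes or Cypher syntax inside values travel as data instead of being spliced into the query text. The node below is invented; in the method above the id also passes through _normalize_node_id first.

# Build the same cy_params structure as upsert_node above, for an awkward node id.
import json

node_id = 'Widget "X" `beta`'
node_data = {"entity_id": node_id, "description": 'contains "quotes" and a comma, too'}

cy_params = {
    "params": json.dumps({"node_id": node_id, "properties": node_data}, ensure_ascii=False)
}
print(cy_params["params"])
# The MERGE statement only ever references $node_id / $properties; its literal text
# never changes with the data, which is what blocks Cypher injection.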
@@ -3931,26 +3950,32 @@ class PGGraphStorage(BaseGraphStorage):
        """
        src_label = self._normalize_node_id(source_node_id)
        tgt_label = self._normalize_node_id(target_node_id)
        edge_properties = self._format_properties(edge_data)

        # Use parameterized query for security
        cy_params = {
            "params": json.dumps(
                {
                    "src_id": src_label,
                    "tgt_id": tgt_label,
                    "properties": edge_data
                },
                ensure_ascii=False
            )
        }

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (source:base {entity_id: "%s"})
        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (source:base {{entity_id: $src_id}})
                     WITH source
                     MATCH (target:base {entity_id: "%s"})
                     MATCH (target:base {{entity_id: $tgt_id}})
                     MERGE (source)-[r:DIRECTED]-(target)
                     SET r += %s
                     SET r += %s
                     SET r += $properties
                     RETURN r
                   $$) AS (r agtype)""" % (
            self.graph_name,
            src_label,
            tgt_label,
            edge_properties,
            edge_properties,  # https://github.com/HKUDS/LightRAG/issues/1438#issuecomment-2826000195
        )
                   $$, $1::agtype) AS (r agtype)"""

        try:
            await self._query(query, readonly=False, upsert=True)
            await self._query(
                query, readonly=False, upsert=True, params=cy_params
            )

        except Exception:
            logger.error(
@@ -3958,6 +3983,99 @@ class PGGraphStorage(BaseGraphStorage):
            )
            raise

    async def upsert_nodes_bulk(
        self, nodes: list[tuple[str, dict[str, str]]], batch_size: int = 500
    ) -> None:
        """Bulk upsert nodes using UNWIND for fewer round-trips."""
        if not nodes:
            return

        if batch_size <= 0 or batch_size > 5000:
            raise ValueError("batch_size must be between 1 and 5000 for bulk node upserts")

        cleaned = []
        for node_id, data in nodes:
            label = self._normalize_node_id(node_id)
            if "entity_id" not in data:
                raise ValueError("PostgreSQL: node properties must contain an 'entity_id' field")
            node_props = dict(data)
            node_props["entity_id"] = label
            cleaned.append({"entity_id": label, "properties": node_props})

        for i in range(0, len(cleaned), batch_size):
            batch = cleaned[i : i + batch_size]
            cy_params = {"params": json.dumps({"nodes": batch}, ensure_ascii=False)}

            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                         UNWIND $nodes AS n
                         MERGE (v:base {{entity_id: n.entity_id}})
                         SET v += n.properties
                         RETURN count(v) AS upserted
                       $$, $1::agtype) AS (upserted bigint)"""

            try:
                await self._query(
                    query, readonly=False, upsert=True, params=cy_params, timeout=DEFAULT_LLM_TIMEOUT
                )
            except Exception:
                logger.error(
                    f"[{self.workspace}] POSTGRES, bulk upsert_node failed on batch starting with `{batch[0]['entity_id']}`"
                )
                raise

    async def upsert_edges_bulk(
        self,
        edges: list[tuple[str, str, dict[str, str]]],
        batch_size: int = 500,
    ) -> None:
        """Bulk upsert edges using UNWIND; keeps current undirected semantics."""
        if not edges:
            return

        if batch_size <= 0 or batch_size > 5000:
            raise ValueError("batch_size must be between 1 and 5000 for bulk edge upserts")

        cleaned = []
        for src, tgt, props in edges:
            src_label = self._normalize_node_id(src)
            tgt_label = self._normalize_node_id(tgt)
            cleaned.append(
                {
                    "src": src_label,
                    "tgt": tgt_label,
                    "properties": dict(props),
                }
            )

        for i in range(0, len(cleaned), batch_size):
            batch = cleaned[i : i + batch_size]
            cy_params = {"params": json.dumps({"edges": batch}, ensure_ascii=False)}

            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                         UNWIND $edges AS e
                         MATCH (source:base {{entity_id: e.src}})
                         WITH source, e
                         MATCH (target:base {{entity_id: e.tgt}})
                         MERGE (source)-[r:DIRECTED]-(target)
                         SET r += e.properties
                         SET r += e.properties
                         RETURN count(r) AS upserted
                       $$, $1::agtype) AS (upserted bigint)"""

            try:
                await self._query(
                    query,
                    readonly=False,
                    upsert=True,
                    params=cy_params,
                    timeout=DEFAULT_LLM_TIMEOUT,
                )
            except Exception:
                logger.error(
                    f"[{self.workspace}] POSTGRES, bulk upsert_edge failed on batch starting with `{batch[0]['src']}`-`{batch[0]['tgt']}`"
                )
                raise

    async def delete_node(self, node_id: str) -> None:
        """
        Delete a node from the graph.
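Caller-side, the bulk API takes the same (id, properties) shapes as the single-row methods; a hedged usage sketch follows. The storage construction is omitted because it depends on deployment config, and the property names other than entity_id are assumptions for illustration.

# Sketch of feeding the new bulk helpers; `graph_storage` is assumed to be an
# initialized PGGraphStorage instance from the surrounding application.
async def load_demo_graph(graph_storage):
    nodes = [
        ("Alice", {"entity_id": "Alice", "entity_type": "person"}),       # entity_id is required
        ("Acme", {"entity_id": "Acme", "entity_type": "organization"}),
    ]
    edges = [
        ("Alice", "Acme", {"weight": 1.0, "description": "works at"}),
    ]

    # Each batch becomes one UNWIND statement, so 1,000 nodes at batch_size=500
    # means two round trips instead of 1,000 individual MERGEs.
    await graph_storage.upsert_nodes_bulk(nodes, batch_size=500)
    await graph_storage.upsert_edges_bulk(edges, batch_size=500)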
@@ -3967,13 +4085,13 @@ class PGGraphStorage(BaseGraphStorage):
        """
        label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (n:base {{entity_id: $1}})
                     DETACH DELETE n
                   $$) AS (n agtype)""" % (self.graph_name, label)
                   $$, $1) AS (n agtype)"""

        try:
            await self._query(query, readonly=False)
            await self._query(query, readonly=False, params={"node_id": label})
        except Exception as e:
            logger.error(f"[{self.workspace}] Error during node deletion: {e}")
            raise
@@ -3986,16 +4104,21 @@ class PGGraphStorage(BaseGraphStorage):
            node_ids (list[str]): A list of node IDs to remove.
        """
        node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
        node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
        if not node_ids:
            return

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base)
                     WHERE n.entity_id IN [%s]
        unique_ids = list(dict.fromkeys(node_ids))
        cy_params = {"params": json.dumps({"node_ids": unique_ids}, ensure_ascii=False)}

        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     UNWIND $node_ids AS node_id
                     MATCH (n:base {{entity_id: node_id}})
                     DETACH DELETE n
                   $$) AS (n agtype)""" % (self.graph_name, node_id_list)
                   $$, $1::agtype) AS (n agtype)"""

        try:
            await self._query(query, readonly=False)
            await self._query(query, readonly=False, params=cy_params)
            logger.debug(f"[{self.workspace}] Removed {len(unique_ids)} nodes from graph")
        except Exception as e:
            logger.error(f"[{self.workspace}] Error during node removal: {e}")
            raise
@@ -4007,23 +4130,38 @@ class PGGraphStorage(BaseGraphStorage):
        Args:
            edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
        """
        if not edges:
            return

        cleaned_edges: list[tuple[str, str]] = []
        seen: set[tuple[str, str]] = set()
        for source, target in edges:
            src_label = self._normalize_node_id(source)
            tgt_label = self._normalize_node_id(target)
            pair = (self._normalize_node_id(source), self._normalize_node_id(target))
            if pair not in seen:
                seen.add(pair)
                cleaned_edges.append(pair)

            query = """SELECT * FROM cypher('%s', $$
                         MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
        if not cleaned_edges:
            return

        literal_pairs = ", ".join(
            [f"{{src: {json.dumps(src)}, tgt: {json.dumps(tgt)}}}" for src, tgt in cleaned_edges]
        )

        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     UNWIND [{literal_pairs}] AS pair
                     MATCH (a:base {{entity_id: pair.src}})-[r]-(b:base {{entity_id: pair.tgt}})
                     DELETE r
                   $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
                   $$) AS (r agtype)"""

            try:
                await self._query(query, readonly=False)
                logger.debug(
                    f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
                )
            except Exception as e:
                logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
                raise
        try:
            await self._query(query, readonly=False)
            logger.debug(
                f"[{self.workspace}] Deleted {len(cleaned_edges)} edges (undirected)"
            )
        except Exception as e:
            logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
            raise

    async def get_nodes_batch(
        self, node_ids: list[str], batch_size: int = 1000
@@ -4311,8 +4449,10 @@ class PGGraphStorage(BaseGraphStorage):

        pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)}

        forward_results = await self._query(sql_fwd, params=pg_params)
        backward_results = await self._query(sql_bwd, params=pg_params)
        forward_results, backward_results = await asyncio.gather(
            self._query(sql_fwd, params=pg_params),
            self._query(sql_bwd, params=pg_params)
        )

        for result in forward_results:
            if result["source"] and result["target"] and result["edge_properties"]:
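The change above is the usual sequential-awaits-to-gather refactor; a self-contained illustration of why it roughly halves wall-clock time for two independent queries is below, with sleeps standing in for the forward/backward round trips.

# Two independent awaits run concurrently under asyncio.gather.
import asyncio
import time


async def fake_query(name: str) -> str:
    await asyncio.sleep(0.2)   # stands in for a round trip to Postgres
    return name


async def main():
    start = time.perf_counter()
    fwd, bwd = await asyncio.gather(fake_query("forward"), fake_query("backward"))
    print(fwd, bwd, f"{time.perf_counter() - start:.2f}s")  # ~0.2s total, not ~0.4s


asyncio.run(main())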
@@ -4376,31 +4516,26 @@ class PGGraphStorage(BaseGraphStorage):

        for i in range(0, len(unique_ids), batch_size):
            batch = unique_ids[i : i + batch_size]
            # Format node IDs for the query
            formatted_ids = ", ".join([f'"{n}"' for n in batch])
            cy_params = {"params": json.dumps({"node_ids": batch}, ensure_ascii=False)}

            outgoing_query = """SELECT * FROM cypher('%s', $$
                UNWIND [%s] AS node_id
                MATCH (n:base {entity_id: node_id})
            outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                UNWIND $node_ids AS node_id
                MATCH (n:base {{entity_id: node_id}})
                OPTIONAL MATCH (n:base)-[]->(connected:base)
                RETURN node_id, connected.entity_id AS connected_id
            $$) AS (node_id text, connected_id text)""" % (
                self.graph_name,
                formatted_ids,
            )
            $$, $1::agtype) AS (node_id text, connected_id text)"""

            incoming_query = """SELECT * FROM cypher('%s', $$
                UNWIND [%s] AS node_id
                MATCH (n:base {entity_id: node_id})
            incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                UNWIND $node_ids AS node_id
                MATCH (n:base {{entity_id: node_id}})
                OPTIONAL MATCH (n:base)<-[]-(connected:base)
                RETURN node_id, connected.entity_id AS connected_id
            $$) AS (node_id text, connected_id text)""" % (
                self.graph_name,
                formatted_ids,
            )
            $$, $1::agtype) AS (node_id text, connected_id text)"""

            outgoing_results = await self._query(outgoing_query)
            incoming_results = await self._query(incoming_query)
            outgoing_results, incoming_results = await asyncio.gather(
                self._query(outgoing_query, params=cy_params),
                self._query(incoming_query, params=cy_params)
            )

            for result in outgoing_results:
                if result["node_id"] and result["connected_id"]:
@@ -4428,15 +4563,14 @@ class PGGraphStorage(BaseGraphStorage):
        Returns:
            list[str]: A list of all labels in the graph.
        """
        query = (
            """SELECT * FROM cypher('%s', $$
                 MATCH (n:base)
                 WHERE n.entity_id IS NOT NULL
                 RETURN DISTINCT n.entity_id AS label
                 ORDER BY n.entity_id
               $$) AS (label text)"""
            % self.graph_name
        )
        # Use native SQL for better performance
        query = f"""
            SELECT DISTINCT
                (ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text AS label
            FROM {self.graph_name}.base
            WHERE ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]) IS NOT NULL
            ORDER BY label
        """

        results = await self._query(query)
        labels = []
@@ -4471,12 +4605,12 @@ class PGGraphStorage(BaseGraphStorage):

        # Get starting node data
        label = self._normalize_node_id(node_label)
        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (n:base {{entity_id: $1}})
                     RETURN id(n) as node_id, n
                   $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label)
                   $$, $1) AS (node_id bigint, n agtype)"""

        node_result = await self._query(query)
        node_result = await self._query(query, params={"node_id": label})
        if not node_result or not node_result[0].get("n"):
            return result
@@ -4525,14 +4659,12 @@ class PGGraphStorage(BaseGraphStorage):
                continue

            # Prepare node IDs list
            node_ids = [node.labels[0] for node in current_level_nodes]
            formatted_ids = ", ".join(
                [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
            )
            node_ids = [self._normalize_node_id(node.labels[0]) for node in current_level_nodes]
            cy_params = {"params": json.dumps({"node_ids": node_ids}, ensure_ascii=False)}

            # Construct batch query for outgoing edges
            outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                UNWIND [{formatted_ids}] AS node_id
                UNWIND $node_ids AS node_id
                MATCH (n:base {{entity_id: node_id}})
                OPTIONAL MATCH (n)-[r]->(neighbor:base)
                RETURN node_id AS current_id,
@@ -4543,12 +4675,12 @@ class PGGraphStorage(BaseGraphStorage):
                       r,
                       neighbor,
                       true AS is_outgoing
            $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
            $$, $1::agtype) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
                    neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""

            # Construct batch query for incoming edges
            incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                UNWIND [{formatted_ids}] AS node_id
                UNWIND $node_ids AS node_id
                MATCH (n:base {{entity_id: node_id}})
                OPTIONAL MATCH (n)<-[r]-(neighbor:base)
                RETURN node_id AS current_id,
@@ -4559,12 +4691,14 @@ class PGGraphStorage(BaseGraphStorage):
                       r,
                       neighbor,
                       false AS is_outgoing
            $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
            $$, $1::agtype) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
                    neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""

            # Execute queries
            outgoing_results = await self._query(outgoing_query)
            incoming_results = await self._query(incoming_query)
            # Execute queries concurrently
            outgoing_results, incoming_results = await asyncio.gather(
                self._query(outgoing_query, params=cy_params),
                self._query(incoming_query, params=cy_params)
            )

            # Combine results
            neighbors = outgoing_results + incoming_results
@@ -4654,17 +4788,15 @@ class PGGraphStorage(BaseGraphStorage):

        # Add db_degree to all nodes via bulk query
        if result.nodes:
            entity_ids = [node.labels[0] for node in result.nodes]
            formatted_ids = ", ".join(
                [f'"{self._normalize_node_id(eid)}"' for eid in entity_ids]
            )
            entity_ids = [self._normalize_node_id(node.labels[0]) for node in result.nodes]
            degree_params = {"params": json.dumps({"node_ids": entity_ids}, ensure_ascii=False)}
            degree_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                UNWIND [{formatted_ids}] AS entity_id
                UNWIND $node_ids AS entity_id
                MATCH (n:base {{entity_id: entity_id}})
                OPTIONAL MATCH (n)-[r]-()
                RETURN entity_id, count(r) as degree
            $$) AS (entity_id text, degree bigint)"""
            degree_results = await self._query(degree_query)
            $$, $1::agtype) AS (entity_id text, degree bigint)"""
            degree_results = await self._query(degree_query, params=degree_params)
            degree_map = {
                row["entity_id"]: int(row["degree"]) for row in degree_results
            }
@@ -4753,17 +4885,17 @@ class PGGraphStorage(BaseGraphStorage):
            )

        if node_ids:
            formatted_ids = ", ".join(node_ids)
            cy_params = {"params": json.dumps({"node_ids": [int(n) for n in node_ids]}, ensure_ascii=False)}
            # Construct batch query for subgraph within max_nodes
            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                WITH [{formatted_ids}] AS node_ids
                WITH $node_ids AS node_ids
                MATCH (a)
                WHERE id(a) IN node_ids
                OPTIONAL MATCH (a)-[r]->(b)
                WHERE id(b) IN node_ids
                RETURN a, r, b
            $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
            results = await self._query(query)
            $$, $1::agtype) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
            results = await self._query(query, params=cy_params)

            # Process query results, deduplicate nodes and edges
            nodes_dict = {}
@@ -629,6 +629,9 @@ async def _handle_single_relationship_extraction(
    )
    edge_keywords = edge_keywords.replace("，", ",")

    # Derive a relationship label from the first keyword (fallback to description later)
    relationship_label = edge_keywords.split(",")[0].strip() if edge_keywords else ""

    # Process relationship description with same cleaning pipeline
    edge_description = sanitize_and_normalize_extracted_text(record_attributes[4])
@@ -645,6 +648,8 @@ async def _handle_single_relationship_extraction(
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        relationship=relationship_label,
        type=relationship_label,
        source_id=edge_source_id,
        file_path=file_path,
        timestamp=timestamp,
@@ -2337,6 +2342,8 @@ async def _merge_edges_then_upsert(
    already_source_ids = []
    already_description = []
    already_keywords = []
    already_relationships = []
    already_types = []
    already_file_paths = []

    # 1. Get existing edge data from graph storage
@@ -2373,6 +2380,20 @@ async def _merge_edges_then_upsert(
                )
            )

        if already_edge.get("relationship") is not None:
            already_relationships.extend(
                split_string_by_multi_markers(
                    already_edge["relationship"], [GRAPH_FIELD_SEP, ","]
                )
            )

        if already_edge.get("type") is not None:
            already_types.extend(
                split_string_by_multi_markers(
                    already_edge["type"], [GRAPH_FIELD_SEP, ","]
                )
            )

    new_source_ids = [dp["source_id"] for dp in edges_data if dp.get("source_id")]

    storage_key = make_relation_chunk_key(src_id, tgt_id)
@@ -2474,6 +2495,23 @@ async def _merge_edges_then_upsert(
    # Join all unique keywords with commas
    keywords = ",".join(sorted(all_keywords))

    # 6.3 Finalize relationship/type labels from explicit field or fallback to keywords
    rel_labels = set()
    for edge in edges_data:
        if edge.get("relationship"):
            rel_labels.update(k.strip() for k in edge["relationship"].split(",") if k.strip())
    rel_labels.update(k.strip() for k in already_relationships if k.strip())
    relationship_label = ",".join(sorted(rel_labels))
    if not relationship_label and keywords:
        relationship_label = keywords.split(",")[0]

    type_labels = set()
    for edge in edges_data:
        if edge.get("type"):
            type_labels.update(k.strip() for k in edge["type"].split(",") if k.strip())
    type_labels.update(k.strip() for k in already_types if k.strip())
    type_label = ",".join(sorted(type_labels)) or relationship_label

    # 7. Deduplicate by description, keeping first occurrence in the same document
    unique_edges = {}
    for dp in edges_data:
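A toy run of the label-merging rules above, with invented edge dicts: explicit relationship values from new and existing edges are deduplicated and sorted, and only when nothing explicit exists does the first keyword serve as a fallback.

# Toy reproduction of the relationship-label merge above; edge data is invented.
edges_data = [
    {"relationship": "employs, funds", "type": "employs", "keywords": "employment"},
    {"relationship": "funds", "keywords": "funding,grants"},
]
already_relationships: list[str] = []   # nothing previously stored in this toy case
keywords = "employment,funding,grants"

rel_labels = set()
for edge in edges_data:
    if edge.get("relationship"):
        rel_labels.update(k.strip() for k in edge["relationship"].split(",") if k.strip())
rel_labels.update(k.strip() for k in already_relationships if k.strip())
relationship_label = ",".join(sorted(rel_labels))
if not relationship_label and keywords:
    relationship_label = keywords.split(",")[0]

print(relationship_label)   # -> employs,funds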
@@ -2785,6 +2823,8 @@ async def _merge_edges_then_upsert(
        weight=weight,
        description=description,
        keywords=keywords,
        relationship=relationship_label,
        type=type_label,
        source_id=source_id,
        file_path=file_path,
        created_at=edge_created_at,
@@ -2797,6 +2837,8 @@ async def _merge_edges_then_upsert(
        tgt_id=tgt_id,
        description=description,
        keywords=keywords,
        relationship=relationship_label,
        type=type_label,
        source_id=source_id,
        file_path=file_path,
        created_at=edge_created_at,
@@ -2822,6 +2864,8 @@ async def _merge_edges_then_upsert(
        rel_vdb_id: {
            "src_id": src_id,
            "tgt_id": tgt_id,
            "relationship": relationship_label,
            "type": type_label,
            "source_id": source_id,
            "content": rel_content,
            "keywords": keywords,
@@ -1332,6 +1332,8 @@ async def _merge_entities_impl(
        {
            "description": "concatenate",
            "keywords": "join_unique_comma",
            "relationship": "join_unique_comma",
            "type": "join_unique_comma",
            "source_id": "join_unique",
            "file_path": "join_unique",
            "weight": "max",