From a6b87df75835255dbfd5d12c03730766a375ccae Mon Sep 17 00:00:00 2001 From: clssck Date: Wed, 3 Dec 2025 18:19:26 +0000 Subject: [PATCH] feat(postgres): add bulk operations and health check - Implement bulk upsert_nodes/edges via UNWIND reducing round trips - Add health_check for graph connectivity and AGE catalog status - Switch to parameterized queries preventing Cypher injection - Fix node ID sanitization: strip control chars, escape quotes --- Dockerfile | 5 + docker-compose.test.yml | 9 +- lightrag/api/lightrag_server.py | 4 + lightrag/base.py | 24 ++ lightrag/kg/postgres_impl.py | 472 ++++++++++++++++++++------------ lightrag/operate.py | 44 +++ lightrag/utils_graph.py | 2 + 7 files changed, 388 insertions(+), 172 deletions(-) diff --git a/Dockerfile b/Dockerfile index aaa3c84b..92e7f41f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -68,6 +68,11 @@ FROM python:3.12-slim WORKDIR /app +# Add curl for runtime healthchecks and simple diagnostics +RUN apt-get update \ + && apt-get install -y --no-install-recommends curl \ + && rm -rf /var/lib/apt/lists/* + # Install uv for package management COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 39da4319..8286161f 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -54,8 +54,12 @@ services: volumes: - ./data/rag_storage_test:/app/data/rag_storage - ./data/inputs_test:/app/data/inputs - - ./lightrag:/app/lightrag # Mount source for live reload + # Live reload: Use absolute host path for Docker-in-Docker compatibility (Coder) + - /var/lib/docker/volumes/coder-shared-projects-optimized/_data/LightRAG/lightrag:/app/lightrag environment: + # Live reload: PYTHONPATH makes mounted /app/lightrag take precedence over site-packages + - PYTHONPATH=/app + # Server - HOST=0.0.0.0 - PORT=9621 @@ -120,7 +124,8 @@ services: entrypoint: [] command: - python - - /app/lightrag/api/run_with_gunicorn.py + - -m + - lightrag.api.run_with_gunicorn - --workers - "8" - --llm-binding diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 73a82293..6fb50dfb 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1266,6 +1266,9 @@ def create_app(args): else: auth_mode = "enabled" + # Optional graph health probe (lightweight) - Using unified health_check interface + graph_health = await rag.chunk_entity_relation_graph.health_check() + # Cleanup expired keyed locks and get status keyed_lock_info = cleanup_keyed_lock() @@ -1319,6 +1322,7 @@ def create_app(args): "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, + "graph": graph_health, } except Exception as e: logger.error(f"Error getting health status: {str(e)}") diff --git a/lightrag/base.py b/lightrag/base.py index 4ae1e4b4..19b44b0b 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -187,6 +187,14 @@ class StorageNameSpace(ABC): async def index_done_callback(self) -> None: """Commit the storage operations after indexing""" + async def health_check(self, max_retries: int = 3) -> dict[str, Any]: + """Check the health status of the storage + + Returns: + dict[str, Any]: Health status dictionary with at least 'status' field + """ + return {"status": "healthy"} + @abstractmethod async def drop(self) -> dict[str, str]: """Drop all data from storage and clean up resources @@ -527,6 +535,22 @@ class BaseGraphStorage(StorageNameSpace, ABC): result[node_id] = edges if edges is not None else [] return result + async 
def upsert_nodes_bulk( + self, nodes: list[tuple[str, dict[str, str]]], batch_size: int = 500 + ) -> None: + """Default bulk helper; storage backends can override for batching.""" + for node_id, node_data in nodes: + await self.upsert_node(node_id, node_data) + + async def upsert_edges_bulk( + self, + edges: list[tuple[str, str, dict[str, str]]], + batch_size: int = 500, + ) -> None: + """Default bulk helper; storage backends can override for batching.""" + for src, tgt, edge_data in edges: + await self.upsert_edge(src, tgt, edge_data) + @abstractmethod async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """Insert a new node or update an existing node in the graph. diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index b195043c..030de184 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -3389,32 +3389,31 @@ class PGDocStatusStorage(DocStatusStorage): error_msg = EXCLUDED.error_msg, created_at = EXCLUDED.created_at, updated_at = EXCLUDED.updated_at""" + + batch_data = [] for k, v in data.items(): # Remove timezone information, store utc time in db created_at = parse_datetime(v.get("created_at")) updated_at = parse_datetime(v.get("updated_at")) - # chunks_count, chunks_list, track_id, metadata, and error_msg are optional - await self.db.execute( - sql, - { - "workspace": self.workspace, - "id": k, - "content_summary": v["content_summary"], - "content_length": v["content_length"], - "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, - "status": v["status"], - "file_path": v["file_path"], - "chunks_list": json.dumps(v.get("chunks_list", [])), - "track_id": v.get("track_id"), # Add track_id support - "metadata": json.dumps( - v.get("metadata", {}) - ), # Add metadata support - "error_msg": v.get("error_msg"), # Add error_msg support - "created_at": created_at, # Use the converted datetime object - "updated_at": updated_at, # Use the converted datetime object - }, - ) + batch_data.append(( + self.workspace, + k, + v["content_summary"], + v["content_length"], + v["chunks_count"] if "chunks_count" in v else -1, + v["status"], + v["file_path"], + json.dumps(v.get("chunks_list", [])), + v.get("track_id"), + json.dumps(v.get("metadata", {})), + v.get("error_msg"), + created_at, + updated_at + )) + + if batch_data: + await self.db.executemany(sql, batch_data) async def drop(self) -> dict[str, str]: """Drop the storage""" @@ -3487,19 +3486,23 @@ class PGGraphStorage(BaseGraphStorage): @staticmethod def _normalize_node_id(node_id: str) -> str: - """ - Normalize node ID to ensure special characters are properly handled in Cypher queries. + """Best-effort sanitization for identifiers we interpolate into Cypher. - Args: - node_id: The original node ID - - Returns: - Normalized node ID suitable for Cypher queries + This avoids common parse errors without altering the semantic value. + Control chars are stripped, quotes/backticks are escaped, and we keep + the result ASCII-only to match server expectations. 
""" - # Escape backslashes - normalized_id = node_id - normalized_id = normalized_id.replace("\\", "\\\\") - normalized_id = normalized_id.replace('"', '\\"') + + # Drop control characters that can break AGE parsing + normalized_id = re.sub(r"[\x00-\x1F]", "", node_id) + + # Escape characters that matter for the interpolated Cypher literal + normalized_id = normalized_id.replace("\\", "\\\\") # backslash + normalized_id = normalized_id.replace('"', '\\"') # double quote + normalized_id = normalized_id.replace("`", "\\`") # backtick + + # Keep it compact and ASCII to avoid encoding surprises + normalized_id = normalized_id.encode("ascii", "ignore").decode("ascii") return normalized_id async def initialize(self): @@ -3549,6 +3552,7 @@ class PGGraphStorage(BaseGraphStorage): f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)', f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))', f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)', + f'CREATE UNIQUE INDEX CONCURRENTLY {self.graph_name}_entity_id_unique ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))', f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx', ] @@ -3563,10 +3567,44 @@ class PGGraphStorage(BaseGraphStorage): graph_name=self.graph_name, ) - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async def health_check(self, max_retries: int = 3) -> dict[str, Any]: + """Check Postgres graph connectivity and status""" + graph_health = {"status": "unknown"} + try: + # Base connectivity + await self.db.execute("SELECT 1") + + # AGE catalog available? + graphs = await self.db.query( + "SELECT name FROM ag_catalog.ag_graph", + multirows=True, + with_age=False, + ) + + graph_exists = False + if graphs: + names = [g.get("name") for g in graphs if "name" in g] + if self.graph_name in names: + graph_exists = True + + # Basic Cypher query test on this workspace's graph + if graph_exists: + await self.db.query( + f"SELECT * FROM cypher('{self.graph_name}', $$ RETURN 1 $$) AS (one agtype)", + with_age=True, + graph_name=self.graph_name, + ) + + graph_health = { + "status": "healthy", + "graphs": graphs if graphs is not None else [], + "workspace_graph": self.graph_name if graph_exists else None, + } + except Exception as exc: + graph_health = {"status": "unhealthy", "detail": str(exc)} + logger.debug(f"Graph health check failed: {exc}") + + return graph_health async def index_done_callback(self) -> None: # PG handles persistence automatically @@ -3685,31 +3723,9 @@ class PGGraphStorage(BaseGraphStorage): return d - @staticmethod - def _format_properties( - properties: dict[str, Any], _id: Union[str, None] = None - ) -> str: - """ - Convert a dictionary of properties to a string representation that - can be used in a cypher query insert/merge statement. 
- - Args: - properties (dict[str,str]): a dictionary containing node/edge properties - _id (Union[str, None]): the id of the node or None if none exists - - Returns: - str: the properties dictionary as a properly formatted string - """ - props = [] - # wrap property key in backticks to escape - for k, v in properties.items(): - prop = f"`{k}`: {json.dumps(v)}" - props.append(prop) - if _id is not None and "id" not in properties: - props.append( - f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" - ) - return "{" + ", ".join(props) + "}" + def _format_properties(self, data: dict[str, Any]) -> str: + # Deprecated: Use parameterized queries instead + return "" async def _query( self, @@ -3717,6 +3733,7 @@ class PGGraphStorage(BaseGraphStorage): readonly: bool = True, upsert: bool = False, params: dict[str, Any] | None = None, + timeout: float | None = None, ) -> list[dict[str, Any]]: """ Query the graph by taking a cypher query, converting it to an @@ -3740,6 +3757,7 @@ class PGGraphStorage(BaseGraphStorage): else: data = await self.db.execute( query, + data=params, upsert=upsert, with_age=True, graph_name=self.graph_name, @@ -3853,16 +3871,13 @@ class PGGraphStorage(BaseGraphStorage): """ label = self._normalize_node_id(source_node_id) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base {{entity_id: $1}}) OPTIONAL MATCH (n)-[]-(connected:base) RETURN n.entity_id AS source_id, connected.entity_id AS connected_id - $$) AS (source_id text, connected_id text)""" % ( - self.graph_name, - label, - ) + $$, $1) AS (source_id text, connected_id text)""" - results = await self._query(query) + results = await self._query(query, params={"node_id": label}) edges = [] for record in results: source_id = record["source_id"] @@ -3892,20 +3907,24 @@ class PGGraphStorage(BaseGraphStorage): ) label = self._normalize_node_id(node_id) - properties = self._format_properties(node_data) + + # Use parameterized query to prevent injection + cy_params = { + "params": json.dumps( + {"node_id": label, "properties": node_data}, ensure_ascii=False + ) + } - query = """SELECT * FROM cypher('%s', $$ - MERGE (n:base {entity_id: "%s"}) - SET n += %s + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MERGE (n:base {{entity_id: $node_id}}) + SET n += $properties RETURN n - $$) AS (n agtype)""" % ( - self.graph_name, - label, - properties, - ) + $$, $1::agtype) AS (n agtype)""" try: - await self._query(query, readonly=False, upsert=True) + await self._query( + query, readonly=False, upsert=True, params=cy_params + ) except Exception: logger.error( @@ -3931,26 +3950,32 @@ class PGGraphStorage(BaseGraphStorage): """ src_label = self._normalize_node_id(source_node_id) tgt_label = self._normalize_node_id(target_node_id) - edge_properties = self._format_properties(edge_data) + + # Use parameterized query for security + cy_params = { + "params": json.dumps( + { + "src_id": src_label, + "tgt_id": tgt_label, + "properties": edge_data + }, + ensure_ascii=False + ) + } - query = """SELECT * FROM cypher('%s', $$ - MATCH (source:base {entity_id: "%s"}) + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (source:base {{entity_id: $src_id}}) WITH source - MATCH (target:base {entity_id: "%s"}) + MATCH (target:base {{entity_id: $tgt_id}}) MERGE (source)-[r:DIRECTED]-(target) - SET r += %s - SET r += %s + SET r += $properties RETURN r - $$) AS (r agtype)""" % ( - self.graph_name, - src_label, - tgt_label, 
- edge_properties, - edge_properties, # https://github.com/HKUDS/LightRAG/issues/1438#issuecomment-2826000195 - ) + $$, $1::agtype) AS (r agtype)""" try: - await self._query(query, readonly=False, upsert=True) + await self._query( + query, readonly=False, upsert=True, params=cy_params + ) except Exception: logger.error( @@ -3958,6 +3983,99 @@ class PGGraphStorage(BaseGraphStorage): ) raise + async def upsert_nodes_bulk( + self, nodes: list[tuple[str, dict[str, str]]], batch_size: int = 500 + ) -> None: + """Bulk upsert nodes using UNWIND for fewer round-trips.""" + if not nodes: + return + + if batch_size <= 0 or batch_size > 5000: + raise ValueError("batch_size must be between 1 and 5000 for bulk node upserts") + + cleaned = [] + for node_id, data in nodes: + label = self._normalize_node_id(node_id) + if "entity_id" not in data: + raise ValueError("PostgreSQL: node properties must contain an 'entity_id' field") + node_props = dict(data) + node_props["entity_id"] = label + cleaned.append({"entity_id": label, "properties": node_props}) + + for i in range(0, len(cleaned), batch_size): + batch = cleaned[i : i + batch_size] + cy_params = {"params": json.dumps({"nodes": batch}, ensure_ascii=False)} + + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND $nodes AS n + MERGE (v:base {{entity_id: n.entity_id}}) + SET v += n.properties + RETURN count(v) AS upserted + $$, $1::agtype) AS (upserted bigint)""" + + try: + await self._query( + query, readonly=False, upsert=True, params=cy_params, timeout=DEFAULT_LLM_TIMEOUT + ) + except Exception: + logger.error( + f"[{self.workspace}] POSTGRES, bulk upsert_node failed on batch starting with `{batch[0]['entity_id']}`" + ) + raise + + async def upsert_edges_bulk( + self, + edges: list[tuple[str, str, dict[str, str]]], + batch_size: int = 500, + ) -> None: + """Bulk upsert edges using UNWIND; keeps current undirected semantics.""" + if not edges: + return + + if batch_size <= 0 or batch_size > 5000: + raise ValueError("batch_size must be between 1 and 5000 for bulk edge upserts") + + cleaned = [] + for src, tgt, props in edges: + src_label = self._normalize_node_id(src) + tgt_label = self._normalize_node_id(tgt) + cleaned.append( + { + "src": src_label, + "tgt": tgt_label, + "properties": dict(props), + } + ) + + for i in range(0, len(cleaned), batch_size): + batch = cleaned[i : i + batch_size] + cy_params = {"params": json.dumps({"edges": batch}, ensure_ascii=False)} + + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND $edges AS e + MATCH (source:base {{entity_id: e.src}}) + WITH source, e + MATCH (target:base {{entity_id: e.tgt}}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += e.properties + SET r += e.properties + RETURN count(r) AS upserted + $$, $1::agtype) AS (upserted bigint)""" + + try: + await self._query( + query, + readonly=False, + upsert=True, + params=cy_params, + timeout=DEFAULT_LLM_TIMEOUT, + ) + except Exception: + logger.error( + f"[{self.workspace}] POSTGRES, bulk upsert_edge failed on batch starting with `{batch[0]['src']}`-`{batch[0]['tgt']}`" + ) + raise + async def delete_node(self, node_id: str) -> None: """ Delete a node from the graph. 
@@ -3967,13 +4085,13 @@ class PGGraphStorage(BaseGraphStorage): """ label = self._normalize_node_id(node_id) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base {{entity_id: $1}}) DETACH DELETE n - $$) AS (n agtype)""" % (self.graph_name, label) + $$, $1) AS (n agtype)""" try: - await self._query(query, readonly=False) + await self._query(query, readonly=False, params={"node_id": label}) except Exception as e: logger.error(f"[{self.workspace}] Error during node deletion: {e}") raise @@ -3986,16 +4104,21 @@ class PGGraphStorage(BaseGraphStorage): node_ids (list[str]): A list of node IDs to remove. """ node_ids = [self._normalize_node_id(node_id) for node_id in node_ids] - node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids]) + if not node_ids: + return - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base) - WHERE n.entity_id IN [%s] + unique_ids = list(dict.fromkeys(node_ids)) + cy_params = {"params": json.dumps({"node_ids": unique_ids}, ensure_ascii=False)} + + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND $node_ids AS node_id + MATCH (n:base {{entity_id: node_id}}) DETACH DELETE n - $$) AS (n agtype)""" % (self.graph_name, node_id_list) + $$, $1::agtype) AS (n agtype)""" try: - await self._query(query, readonly=False) + await self._query(query, readonly=False, params=cy_params) + logger.debug(f"[{self.workspace}] Removed {len(unique_ids)} nodes from graph") except Exception as e: logger.error(f"[{self.workspace}] Error during node removal: {e}") raise @@ -4007,23 +4130,38 @@ class PGGraphStorage(BaseGraphStorage): Args: edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). 
""" + if not edges: + return + + cleaned_edges: list[tuple[str, str]] = [] + seen: set[tuple[str, str]] = set() for source, target in edges: - src_label = self._normalize_node_id(source) - tgt_label = self._normalize_node_id(target) + pair = (self._normalize_node_id(source), self._normalize_node_id(target)) + if pair not in seen: + seen.add(pair) + cleaned_edges.append(pair) - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) + if not cleaned_edges: + return + + literal_pairs = ", ".join( + [f"{{src: {json.dumps(src)}, tgt: {json.dumps(tgt)}}}" for src, tgt in cleaned_edges] + ) + + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND [{literal_pairs}] AS pair + MATCH (a:base {{entity_id: pair.src}})-[r]-(b:base {{entity_id: pair.tgt}}) DELETE r - $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) + $$) AS (r agtype)""" - try: - await self._query(query, readonly=False) - logger.debug( - f"[{self.workspace}] Deleted edge from '{source}' to '{target}'" - ) - except Exception as e: - logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") - raise + try: + await self._query(query, readonly=False) + logger.debug( + f"[{self.workspace}] Deleted {len(cleaned_edges)} edges (undirected)" + ) + except Exception as e: + logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") + raise async def get_nodes_batch( self, node_ids: list[str], batch_size: int = 1000 @@ -4311,8 +4449,10 @@ class PGGraphStorage(BaseGraphStorage): pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)} - forward_results = await self._query(sql_fwd, params=pg_params) - backward_results = await self._query(sql_bwd, params=pg_params) + forward_results, backward_results = await asyncio.gather( + self._query(sql_fwd, params=pg_params), + self._query(sql_bwd, params=pg_params) + ) for result in forward_results: if result["source"] and result["target"] and result["edge_properties"]: @@ -4376,31 +4516,26 @@ class PGGraphStorage(BaseGraphStorage): for i in range(0, len(unique_ids), batch_size): batch = unique_ids[i : i + batch_size] - # Format node IDs for the query - formatted_ids = ", ".join([f'"{n}"' for n in batch]) + cy_params = {"params": json.dumps({"node_ids": batch}, ensure_ascii=False)} - outgoing_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) + outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND $node_ids AS node_id + MATCH (n:base {{entity_id: node_id}}) OPTIONAL MATCH (n:base)-[]->(connected:base) RETURN node_id, connected.entity_id AS connected_id - $$) AS (node_id text, connected_id text)""" % ( - self.graph_name, - formatted_ids, - ) + $$, $1::agtype) AS (node_id text, connected_id text)""" - incoming_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) + incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND $node_ids AS node_id + MATCH (n:base {{entity_id: node_id}}) OPTIONAL MATCH (n:base)<-[]-(connected:base) RETURN node_id, connected.entity_id AS connected_id - $$) AS (node_id text, connected_id text)""" % ( - self.graph_name, - formatted_ids, - ) + $$, $1::agtype) AS (node_id text, connected_id text)""" - outgoing_results = await self._query(outgoing_query) - incoming_results = await self._query(incoming_query) + outgoing_results, incoming_results = await asyncio.gather( + self._query(outgoing_query, params=cy_params), + 
self._query(incoming_query, params=cy_params) + ) for result in outgoing_results: if result["node_id"] and result["connected_id"]: @@ -4428,15 +4563,14 @@ class PGGraphStorage(BaseGraphStorage): Returns: list[str]: A list of all labels in the graph. """ - query = ( - """SELECT * FROM cypher('%s', $$ - MATCH (n:base) - WHERE n.entity_id IS NOT NULL - RETURN DISTINCT n.entity_id AS label - ORDER BY n.entity_id - $$) AS (label text)""" - % self.graph_name - ) + # Use native SQL for better performance + query = f""" + SELECT DISTINCT + (ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text AS label + FROM {self.graph_name}.base + WHERE ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]) IS NOT NULL + ORDER BY label + """ results = await self._query(query) labels = [] @@ -4471,12 +4605,12 @@ class PGGraphStorage(BaseGraphStorage): # Get starting node data label = self._normalize_node_id(node_label) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base {{entity_id: $1}}) RETURN id(n) as node_id, n - $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label) + $$, $1) AS (node_id bigint, n agtype)""" - node_result = await self._query(query) + node_result = await self._query(query, params={"node_id": label}) if not node_result or not node_result[0].get("n"): return result @@ -4525,14 +4659,12 @@ class PGGraphStorage(BaseGraphStorage): continue # Prepare node IDs list - node_ids = [node.labels[0] for node in current_level_nodes] - formatted_ids = ", ".join( - [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids] - ) + node_ids = [self._normalize_node_id(node.labels[0]) for node in current_level_nodes] + cy_params = {"params": json.dumps({"node_ids": node_ids}, ensure_ascii=False)} # Construct batch query for outgoing edges outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND [{formatted_ids}] AS node_id + UNWIND $node_ids AS node_id MATCH (n:base {{entity_id: node_id}}) OPTIONAL MATCH (n)-[r]->(neighbor:base) RETURN node_id AS current_id, @@ -4543,12 +4675,12 @@ class PGGraphStorage(BaseGraphStorage): r, neighbor, true AS is_outgoing - $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + $$, $1::agtype) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" # Construct batch query for incoming edges incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND [{formatted_ids}] AS node_id + UNWIND $node_ids AS node_id MATCH (n:base {{entity_id: node_id}}) OPTIONAL MATCH (n)<-[r]-(neighbor:base) RETURN node_id AS current_id, @@ -4559,12 +4691,14 @@ class PGGraphStorage(BaseGraphStorage): r, neighbor, false AS is_outgoing - $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + $$, $1::agtype) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" - # Execute queries - outgoing_results = await self._query(outgoing_query) - incoming_results = await self._query(incoming_query) + # Execute queries concurrently + outgoing_results, incoming_results = await asyncio.gather( + self._query(outgoing_query, params=cy_params), + self._query(incoming_query, params=cy_params) + ) # Combine results neighbors = outgoing_results + 
incoming_results @@ -4654,17 +4788,15 @@ class PGGraphStorage(BaseGraphStorage): # Add db_degree to all nodes via bulk query if result.nodes: - entity_ids = [node.labels[0] for node in result.nodes] - formatted_ids = ", ".join( - [f'"{self._normalize_node_id(eid)}"' for eid in entity_ids] - ) + entity_ids = [self._normalize_node_id(node.labels[0]) for node in result.nodes] + degree_params = {"params": json.dumps({"node_ids": entity_ids}, ensure_ascii=False)} degree_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND [{formatted_ids}] AS entity_id + UNWIND $node_ids AS entity_id MATCH (n:base {{entity_id: entity_id}}) OPTIONAL MATCH (n)-[r]-() RETURN entity_id, count(r) as degree - $$) AS (entity_id text, degree bigint)""" - degree_results = await self._query(degree_query) + $$, $1::agtype) AS (entity_id text, degree bigint)""" + degree_results = await self._query(degree_query, params=degree_params) degree_map = { row["entity_id"]: int(row["degree"]) for row in degree_results } @@ -4753,17 +4885,17 @@ class PGGraphStorage(BaseGraphStorage): ) if node_ids: - formatted_ids = ", ".join(node_ids) + cy_params = {"params": json.dumps({"node_ids": [int(n) for n in node_ids]}, ensure_ascii=False)} # Construct batch query for subgraph within max_nodes query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - WITH [{formatted_ids}] AS node_ids + WITH $node_ids AS node_ids MATCH (a) WHERE id(a) IN node_ids OPTIONAL MATCH (a)-[r]->(b) WHERE id(b) IN node_ids RETURN a, r, b - $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)""" - results = await self._query(query) + $$, $1::agtype) AS (a AGTYPE, r AGTYPE, b AGTYPE)""" + results = await self._query(query, params=cy_params) # Process query results, deduplicate nodes and edges nodes_dict = {} diff --git a/lightrag/operate.py b/lightrag/operate.py index 1bb1527f..7469448f 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -629,6 +629,9 @@ async def _handle_single_relationship_extraction( ) edge_keywords = edge_keywords.replace(",", ",") + # Derive a relationship label from the first keyword (fallback to description later) + relationship_label = edge_keywords.split(",")[0].strip() if edge_keywords else "" + # Process relationship description with same cleaning pipeline edge_description = sanitize_and_normalize_extracted_text(record_attributes[4]) @@ -645,6 +648,8 @@ async def _handle_single_relationship_extraction( weight=weight, description=edge_description, keywords=edge_keywords, + relationship=relationship_label, + type=relationship_label, source_id=edge_source_id, file_path=file_path, timestamp=timestamp, @@ -2337,6 +2342,8 @@ async def _merge_edges_then_upsert( already_source_ids = [] already_description = [] already_keywords = [] + already_relationships = [] + already_types = [] already_file_paths = [] # 1. 
Get existing edge data from graph storage @@ -2373,6 +2380,20 @@ async def _merge_edges_then_upsert( ) ) + if already_edge.get("relationship") is not None: + already_relationships.extend( + split_string_by_multi_markers( + already_edge["relationship"], [GRAPH_FIELD_SEP, ","] + ) + ) + + if already_edge.get("type") is not None: + already_types.extend( + split_string_by_multi_markers( + already_edge["type"], [GRAPH_FIELD_SEP, ","] + ) + ) + new_source_ids = [dp["source_id"] for dp in edges_data if dp.get("source_id")] storage_key = make_relation_chunk_key(src_id, tgt_id) @@ -2474,6 +2495,23 @@ async def _merge_edges_then_upsert( # Join all unique keywords with commas keywords = ",".join(sorted(all_keywords)) + # 6.3 Finalize relationship/type labels from explicit field or fallback to keywords + rel_labels = set() + for edge in edges_data: + if edge.get("relationship"): + rel_labels.update(k.strip() for k in edge["relationship"].split(",") if k.strip()) + rel_labels.update(k.strip() for k in already_relationships if k.strip()) + relationship_label = ",".join(sorted(rel_labels)) + if not relationship_label and keywords: + relationship_label = keywords.split(",")[0] + + type_labels = set() + for edge in edges_data: + if edge.get("type"): + type_labels.update(k.strip() for k in edge["type"].split(",") if k.strip()) + type_labels.update(k.strip() for k in already_types if k.strip()) + type_label = ",".join(sorted(type_labels)) or relationship_label + # 7. Deduplicate by description, keeping first occurrence in the same document unique_edges = {} for dp in edges_data: @@ -2785,6 +2823,8 @@ async def _merge_edges_then_upsert( weight=weight, description=description, keywords=keywords, + relationship=relationship_label, + type=type_label, source_id=source_id, file_path=file_path, created_at=edge_created_at, @@ -2797,6 +2837,8 @@ async def _merge_edges_then_upsert( tgt_id=tgt_id, description=description, keywords=keywords, + relationship=relationship_label, + type=type_label, source_id=source_id, file_path=file_path, created_at=edge_created_at, @@ -2822,6 +2864,8 @@ async def _merge_edges_then_upsert( rel_vdb_id: { "src_id": src_id, "tgt_id": tgt_id, + "relationship": relationship_label, + "type": type_label, "source_id": source_id, "content": rel_content, "keywords": keywords, diff --git a/lightrag/utils_graph.py b/lightrag/utils_graph.py index cedad575..621ed25d 100644 --- a/lightrag/utils_graph.py +++ b/lightrag/utils_graph.py @@ -1332,6 +1332,8 @@ async def _merge_entities_impl( { "description": "concatenate", "keywords": "join_unique_comma", + "relationship": "join_unique_comma", + "type": "join_unique_comma", "source_id": "join_unique", "file_path": "join_unique", "weight": "max",
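
Usage sketch: the snippet below shows how the health_check and bulk upsert interfaces introduced above are meant to be driven. It assumes an already-initialized graph storage instance (for example rag.chunk_entity_relation_graph backed by PGGraphStorage); the sync_graph helper and the sample data are illustrative only and are not part of this patch.

    async def sync_graph(
        graph,
        nodes: list[tuple[str, dict[str, str]]],
        edges: list[tuple[str, str, dict[str, str]]],
    ) -> None:
        # Probe connectivity first; health_check() returns at least {"status": ...},
        # and PGGraphStorage adds "graphs" and "workspace_graph" on success.
        health = await graph.health_check()
        if health.get("status") != "healthy":
            raise RuntimeError(f"graph backend unhealthy: {health}")

        # nodes are (node_id, properties) tuples; the PG backend requires an
        # 'entity_id' property and accepts a batch_size between 1 and 5000.
        await graph.upsert_nodes_bulk(nodes, batch_size=500)

        # edges are (src, tgt, properties) tuples; MERGE keeps the existing
        # undirected semantics of upsert_edge.
        await graph.upsert_edges_bulk(edges, batch_size=500)

    # Example call from an async context (sample data is hypothetical):
    # await sync_graph(
    #     rag.chunk_entity_relation_graph,
    #     nodes=[("Alice", {"entity_id": "Alice", "description": "a person"})],
    #     edges=[("Alice", "Bob", {"description": "Alice knows Bob", "keywords": "knows"})],
    # )

On PostgreSQL/AGE each batch becomes a single UNWIND round trip instead of one MERGE per node or edge; other backends fall back to the per-item loop defined in BaseGraphStorage. The /health endpoint surfaces the health_check result under the new "graph" key added in lightrag_server.py.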