diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 8848f0fd..dd1ff8ae 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -3100,114 +3100,93 @@ class PGGraphStorage(BaseGraphStorage): return result async def has_node(self, node_id: str) -> bool: - entity_name_label = self._normalize_node_id(node_id) + query = f""" + SELECT EXISTS ( + SELECT 1 + FROM {self.graph_name}.base + WHERE ag_catalog.agtype_access_operator( + VARIADIC ARRAY[properties, '"entity_id"'::agtype] + ) = (to_json($1::text)::text)::agtype + LIMIT 1 + ) AS node_exists; + """ - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) - RETURN count(n) > 0 AS node_exists - $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) - - single_result = (await self._query(query))[0] - - return single_result["node_exists"] + params = {"node_id": node_id} + row = (await self._query(query, params=params))[0] + return bool(row["node_exists"]) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - src_label = self._normalize_node_id(source_node_id) - tgt_label = self._normalize_node_id(target_node_id) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) - RETURN COUNT(r) > 0 AS edge_exists - $$) AS (edge_exists bool)""" % ( - self.graph_name, - src_label, - tgt_label, - ) - - single_result = (await self._query(query))[0] - - return single_result["edge_exists"] + query = f""" + WITH a AS ( + SELECT id AS vid + FROM {self.graph_name}.base + WHERE ag_catalog.agtype_access_operator( + VARIADIC ARRAY[properties, '"entity_id"'::agtype] + ) = (to_json($1::text)::text)::agtype + ), + b AS ( + SELECT id AS vid + FROM {self.graph_name}.base + WHERE ag_catalog.agtype_access_operator( + VARIADIC ARRAY[properties, '"entity_id"'::agtype] + ) = (to_json($2::text)::text)::agtype + ) + SELECT EXISTS ( + SELECT 1 + FROM {self.graph_name}."DIRECTED" d + JOIN a ON d.start_id = a.vid + JOIN b ON d.end_id = b.vid + LIMIT 1 + ) + OR EXISTS ( + SELECT 1 + FROM {self.graph_name}."DIRECTED" d + JOIN a ON d.end_id = a.vid + JOIN b ON d.start_id = b.vid + LIMIT 1 + ) AS edge_exists; + """ + params = { + "source_node_id": source_node_id, + "target_node_id": target_node_id, + } + row = (await self._query(query, params=params))[0] + return bool(row["edge_exists"]) async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties""" label = self._normalize_node_id(node_id) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) - RETURN n - $$) AS (n agtype)""" % (self.graph_name, label) - record = await self._query(query) - if record: - node = record[0] - node_dict = node["n"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string: {node_dict}" - ) - - return node_dict + result = await self.get_nodes_batch(node_ids=[label]) + if result and node_id in result: + return result[node_id] return None async def node_degree(self, node_id: str) -> int: label = self._normalize_node_id(node_id) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"})-[r]-() - RETURN count(r) AS total_edge_count - $$) AS (total_edge_count integer)""" % (self.graph_name, label) - record = (await self._query(query))[0] - if record: - edge_count = int(record["total_edge_count"]) - return edge_count + result = await self.node_degrees_batch(node_ids=[label]) + if result and node_id in result: + return result[node_id] async def edge_degree(self, src_id: str, tgt_id: str) -> int: - src_degree = await self.node_degree(src_id) - trg_degree = await self.node_degree(tgt_id) - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - - return degrees + result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)]) + if result and (src_id, tgt_id) in result: + return result[(src_id, tgt_id)] async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: """Get edge properties between two nodes""" - src_label = self._normalize_node_id(source_node_id) tgt_label = self._normalize_node_id(target_node_id) - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) - RETURN properties(r) as edge_properties - LIMIT 1 - $$) AS (edge_properties agtype)""" % ( - self.graph_name, - src_label, - tgt_label, - ) - record = await self._query(query) - if record and record[0] and record[0]["edge_properties"]: - result = record[0]["edge_properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(result, str): - try: - result = json.loads(result) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse edge string: {result}" - ) - - return result + result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}]) + if result and (src_label, tgt_label) in result: + return result[(src_label, tgt_label)] + return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """