feat: optimize database query methods for improved performance and readability

This commit is contained in:
Matt23-star 2025-08-28 16:18:15 -07:00
parent 9804a1885b
commit aa1ef3f053

View file

@ -3100,114 +3100,93 @@ class PGGraphStorage(BaseGraphStorage):
return result return result
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = self._normalize_node_id(node_id) query = f"""
SELECT EXISTS (
SELECT 1
FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(
VARIADIC ARRAY[properties, '"entity_id"'::agtype]
) = (to_json($1::text)::text)::agtype
LIMIT 1
) AS node_exists;
"""
query = """SELECT * FROM cypher('%s', $$ params = {"node_id": node_id}
MATCH (n:base {entity_id: "%s"}) row = (await self._query(query, params=params))[0]
RETURN count(n) > 0 AS node_exists return bool(row["node_exists"])
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
single_result = (await self._query(query))[0]
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
src_label = self._normalize_node_id(source_node_id) query = f"""
tgt_label = self._normalize_node_id(target_node_id) WITH a AS (
SELECT id AS vid
query = """SELECT * FROM cypher('%s', $$ FROM {self.graph_name}.base
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) WHERE ag_catalog.agtype_access_operator(
RETURN COUNT(r) > 0 AS edge_exists VARIADIC ARRAY[properties, '"entity_id"'::agtype]
$$) AS (edge_exists bool)""" % ( ) = (to_json($1::text)::text)::agtype
self.graph_name, ),
src_label, b AS (
tgt_label, SELECT id AS vid
) FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(
single_result = (await self._query(query))[0] VARIADIC ARRAY[properties, '"entity_id"'::agtype]
) = (to_json($2::text)::text)::agtype
return single_result["edge_exists"] )
SELECT EXISTS (
SELECT 1
FROM {self.graph_name}."DIRECTED" d
JOIN a ON d.start_id = a.vid
JOIN b ON d.end_id = b.vid
LIMIT 1
)
OR EXISTS (
SELECT 1
FROM {self.graph_name}."DIRECTED" d
JOIN a ON d.end_id = a.vid
JOIN b ON d.start_id = b.vid
LIMIT 1
) AS edge_exists;
"""
params = {
"source_node_id": source_node_id,
"target_node_id": target_node_id,
}
row = (await self._query(query, params=params))[0]
return bool(row["edge_exists"])
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties""" """Get node by its label identifier, return only node properties"""
label = self._normalize_node_id(node_id) label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
RETURN n
$$) AS (n agtype)""" % (self.graph_name, label)
record = await self._query(query)
if record:
node = record[0]
node_dict = node["n"]["properties"]
# Process string result, parse it to JSON dictionary result = await self.get_nodes_batch(node_ids=[label])
if isinstance(node_dict, str): if result and node_id in result:
try: return result[node_id]
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string: {node_dict}"
)
return node_dict
return None return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
label = self._normalize_node_id(node_id) label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$ result = await self.node_degrees_batch(node_ids=[label])
MATCH (n:base {entity_id: "%s"})-[r]-() if result and node_id in result:
RETURN count(r) AS total_edge_count return result[node_id]
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0]
if record:
edge_count = int(record["total_edge_count"])
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)])
src_degree = 0 if src_degree is None else src_degree if result and (src_id, tgt_id) in result:
trg_degree = 0 if trg_degree is None else trg_degree return result[(src_id, tgt_id)]
degrees = int(src_degree) + int(trg_degree)
return degrees
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
"""Get edge properties between two nodes""" """Get edge properties between two nodes"""
src_label = self._normalize_node_id(source_node_id) src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id) tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$ result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) if result and (src_label, tgt_label) in result:
RETURN properties(r) as edge_properties return result[(src_label, tgt_label)]
LIMIT 1 return None
$$) AS (edge_properties agtype)""" % (
self.graph_name,
src_label,
tgt_label,
)
record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
# Process string result, parse it to JSON dictionary
if isinstance(result, str):
try:
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge string: {result}"
)
return result
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
""" """