feat: add batch size parameter to node and edge retrieval methods
This commit is contained in:
parent
dc7a6e1c5b
commit
a7da48e05c
1 changed files with 87 additions and 40 deletions
|
|
@ -3384,12 +3384,13 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
|
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
async def get_nodes_batch(self, node_ids: list[str], batch_size: int = 1000) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve multiple nodes in one query using UNWIND.
|
Retrieve multiple nodes in one query using UNWIND.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_ids: List of node entity IDs to fetch.
|
node_ids: List of node entity IDs to fetch.
|
||||||
|
batch_size: Batch size for the query
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping each node_id to its node data (or None if not found).
|
A dictionary mapping each node_id to its node data (or None if not found).
|
||||||
|
|
@ -3435,7 +3436,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
return nodes_dict
|
return nodes_dict
|
||||||
|
|
||||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
async def node_degrees_batch(self, node_ids: list[str], batch_size: int = 500) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||||
Calculates the total degree by counting distinct relationships.
|
Calculates the total degree by counting distinct relationships.
|
||||||
|
|
@ -3443,6 +3444,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_ids: List of node labels (entity_id values) to look up.
|
node_ids: List of node labels (entity_id values) to look up.
|
||||||
|
batch_size: Batch size for the query
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping each node_id to its degree (total number of relationships).
|
A dictionary mapping each node_id to its degree (total number of relationships).
|
||||||
|
|
@ -3532,7 +3534,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
return edge_degrees_dict
|
return edge_degrees_dict
|
||||||
|
|
||||||
async def get_edges_batch(
|
async def get_edges_batch(
|
||||||
self, pairs: list[dict[str, str]]
|
self, pairs: list[dict[str, str]], batch_size: int = 500
|
||||||
) -> dict[tuple[str, str], dict]:
|
) -> dict[tuple[str, str], dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||||
|
|
@ -3540,6 +3542,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||||
|
batch_size: Batch size for the query
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping (src, tgt) tuples to their edge properties.
|
A dictionary mapping (src, tgt) tuples to their edge properties.
|
||||||
|
|
@ -3547,53 +3550,87 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
if not pairs:
|
if not pairs:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
src_nodes = []
|
seen = set()
|
||||||
tgt_nodes = []
|
uniq_pairs: list[dict[str, str]] = []
|
||||||
for pair in pairs:
|
for p in pairs:
|
||||||
src_nodes.append(self._normalize_node_id(pair["src"]))
|
s = self._normalize_node_id(p["src"])
|
||||||
tgt_nodes.append(self._normalize_node_id(pair["tgt"]))
|
t = self._normalize_node_id(p["tgt"])
|
||||||
|
key = (s, t)
|
||||||
|
if s and t and key not in seen:
|
||||||
|
seen.add(key)
|
||||||
|
uniq_pairs.append(p)
|
||||||
|
|
||||||
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
edges_dict: dict[tuple[str, str], dict] = {}
|
||||||
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
|
||||||
|
|
||||||
forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
for i in range(0, len(uniq_pairs), batch_size):
|
||||||
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
batch = uniq_pairs[i:i + batch_size]
|
||||||
UNWIND range(0, size(sources)-1) AS i
|
|
||||||
MATCH (a:base {{entity_id: sources[i]}})-[r]->(b:base {{entity_id: targets[i]}})
|
|
||||||
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
|
||||||
$$) AS (source text, target text, edge_properties agtype)"""
|
|
||||||
|
|
||||||
backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
pairs = [{"src": p["src"], "tgt": p["tgt"]} for p in batch]
|
||||||
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
|
||||||
UNWIND range(0, size(sources)-1) AS i
|
|
||||||
MATCH (a:base {{entity_id: sources[i]}})<-[r]-(b:base {{entity_id: targets[i]}})
|
|
||||||
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
|
||||||
$$) AS (source text, target text, edge_properties agtype)"""
|
|
||||||
|
|
||||||
forward_results = await self._query(forward_query)
|
forward_cypher = """
|
||||||
backward_results = await self._query(backward_query)
|
UNWIND $pairs AS p
|
||||||
|
WITH p.src AS src_eid, p.tgt AS tgt_eid
|
||||||
|
MATCH (a:base {entity_id: src_eid})
|
||||||
|
MATCH (b:base {entity_id: tgt_eid})
|
||||||
|
MATCH (a)-[r]->(b)
|
||||||
|
RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
|
||||||
|
backward_cypher = """
|
||||||
|
UNWIND $pairs AS p
|
||||||
|
WITH p.src AS src_eid, p.tgt AS tgt_eid
|
||||||
|
MATCH (a:base {entity_id: src_eid})
|
||||||
|
MATCH (b:base {entity_id: tgt_eid})
|
||||||
|
MATCH (a)<-[r]-(b)
|
||||||
|
RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
|
||||||
|
|
||||||
edges_dict = {}
|
def dollar_quote(s: str, tag_prefix="AGE"):
|
||||||
|
s = "" if s is None else str(s)
|
||||||
|
for i in itertools.count(1):
|
||||||
|
tag = f"{tag_prefix}{i}"
|
||||||
|
wrapper = f"${tag}$"
|
||||||
|
if wrapper not in s:
|
||||||
|
return f"{wrapper}{s}{wrapper}"
|
||||||
|
|
||||||
for result in forward_results:
|
sql_fwd = f"""
|
||||||
if result["source"] and result["target"] and result["edge_properties"]:
|
SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
|
||||||
edge_props = result["edge_properties"]
|
{dollar_quote(forward_cypher)}::cstring,
|
||||||
|
$1::agtype)
|
||||||
|
AS (source text, target text, edge_properties agtype)
|
||||||
|
"""
|
||||||
|
|
||||||
# Process string result, parse it to JSON dictionary
|
sql_bwd = f"""
|
||||||
if isinstance(edge_props, str):
|
SELECT * FROM cypher({dollar_quote(self.graph_name)}::name,
|
||||||
try:
|
{dollar_quote(backward_cypher)}::cstring,
|
||||||
edge_props = json.loads(edge_props)
|
$1::agtype)
|
||||||
except json.JSONDecodeError:
|
AS (source text, target text, edge_properties agtype)
|
||||||
logger.warning(
|
"""
|
||||||
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
edges_dict[(result["source"], result["target"])] = edge_props
|
pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)}
|
||||||
|
|
||||||
|
forward_results = await self._query(sql_fwd, params=pg_params)
|
||||||
|
backward_results = await self._query(sql_bwd, params=pg_params)
|
||||||
|
|
||||||
|
for result in forward_results:
|
||||||
|
if result["source"] and result["target"] and result["edge_properties"]:
|
||||||
|
edge_props = result["edge_properties"]
|
||||||
|
|
||||||
|
# Process string result, parse it to JSON dictionary
|
||||||
|
if isinstance(edge_props, str):
|
||||||
|
try:
|
||||||
|
edge_props = json.loads(edge_props)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse edge properties string: {edge_props}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
edges_dict[(result["source"], result["target"])] = edge_props
|
||||||
|
|
||||||
for result in backward_results:
|
for result in backward_results:
|
||||||
if result["source"] and result["target"] and result["edge_properties"]:
|
if result["source"] and result["target"] and result["edge_properties"]:
|
||||||
edge_props = result["edge_properties"]
|
edge_props = result["edge_properties"]
|
||||||
|
for result in backward_results:
|
||||||
|
if result["source"] and result["target"] and result["edge_properties"]:
|
||||||
|
edge_props = result["edge_properties"]
|
||||||
|
|
||||||
# Process string result, parse it to JSON dictionary
|
# Process string result, parse it to JSON dictionary
|
||||||
if isinstance(edge_props, str):
|
if isinstance(edge_props, str):
|
||||||
|
|
@ -3604,19 +3641,29 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
|
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
# Process string result, parse it to JSON dictionary
|
||||||
|
if isinstance(edge_props, str):
|
||||||
|
try:
|
||||||
|
edge_props = json.loads(edge_props)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse edge properties string: {edge_props}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
edges_dict[(result["source"], result["target"])] = edge_props
|
edges_dict[(result["source"], result["target"])] = edge_props
|
||||||
|
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
async def get_nodes_edges_batch(
|
async def get_nodes_edges_batch(
|
||||||
self, node_ids: list[str]
|
self, node_ids: list[str], batch_size: int = 500
|
||||||
) -> dict[str, list[tuple[str, str]]]:
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
"""
|
"""
|
||||||
Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
|
Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_ids: List of node IDs to get edges for
|
node_ids: List of node IDs to get edges for
|
||||||
|
batch_size: Batch size for the query
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping node IDs to lists of (source, target) edge tuples
|
Dictionary mapping node IDs to lists of (source, target) edge tuples
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue