Add high-performance label search methods to PostgreSQL graph storage
- Add get_popular_labels() method - Add search_labels() with fuzzy matching - Use native SQL for better performance - Include proper scoring and ranking
This commit is contained in:
parent
6f85bd6b19
commit
3296bcb553
1 changed files with 107 additions and 0 deletions
|
|
@ -4259,6 +4259,113 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
edges.append(edge_properties)
|
edges.append(edge_properties)
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
async def get_popular_labels(self, limit: int = 300) -> list[str]:
|
||||||
|
"""Get popular labels by node degree (most connected entities) using native SQL for performance."""
|
||||||
|
try:
|
||||||
|
# Native SQL query to calculate node degrees directly from AGE's underlying tables
|
||||||
|
# This is significantly faster than using the cypher() function wrapper
|
||||||
|
query = f"""
|
||||||
|
WITH node_degrees AS (
|
||||||
|
SELECT
|
||||||
|
node_id,
|
||||||
|
COUNT(*) AS degree
|
||||||
|
FROM (
|
||||||
|
SELECT start_id AS node_id FROM {self.graph_name}._ag_label_edge
|
||||||
|
UNION ALL
|
||||||
|
SELECT end_id AS node_id FROM {self.graph_name}._ag_label_edge
|
||||||
|
) AS all_edges
|
||||||
|
GROUP BY node_id
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
(ag_catalog.agtype_access_operator(VARIADIC ARRAY[v.properties, '"entity_id"'::agtype]))::text AS label
|
||||||
|
FROM
|
||||||
|
node_degrees d
|
||||||
|
JOIN
|
||||||
|
{self.graph_name}._ag_label_vertex v ON d.node_id = v.id
|
||||||
|
WHERE
|
||||||
|
ag_catalog.agtype_access_operator(VARIADIC ARRAY[v.properties, '"entity_id"'::agtype]) IS NOT NULL
|
||||||
|
ORDER BY
|
||||||
|
d.degree DESC,
|
||||||
|
label ASC
|
||||||
|
LIMIT $1;
|
||||||
|
"""
|
||||||
|
results = await self._query(query, params={"limit": limit})
|
||||||
|
labels = [
|
||||||
|
result["label"] for result in results if result and "label" in result
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
|
||||||
|
)
|
||||||
|
return labels
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||||
|
"""Search labels with fuzzy matching using native, parameterized SQL for performance and security."""
|
||||||
|
query_lower = query.lower().strip()
|
||||||
|
if not query_lower:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Re-implementing with the correct agtype access operator and full scoring logic.
|
||||||
|
sql_query = f"""
|
||||||
|
WITH ranked_labels AS (
|
||||||
|
SELECT
|
||||||
|
(ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text AS label,
|
||||||
|
LOWER((ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text) AS label_lower
|
||||||
|
FROM
|
||||||
|
{self.graph_name}._ag_label_vertex
|
||||||
|
WHERE
|
||||||
|
ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]) IS NOT NULL
|
||||||
|
AND LOWER((ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text) ILIKE $1
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
label
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
label,
|
||||||
|
CASE
|
||||||
|
WHEN label_lower = $2 THEN 1000
|
||||||
|
WHEN label_lower LIKE $3 THEN 500
|
||||||
|
ELSE (100 - LENGTH(label))
|
||||||
|
END +
|
||||||
|
CASE
|
||||||
|
WHEN label_lower LIKE $4 OR label_lower LIKE $5 THEN 50
|
||||||
|
ELSE 0
|
||||||
|
END AS score
|
||||||
|
FROM
|
||||||
|
ranked_labels
|
||||||
|
) AS scored_labels
|
||||||
|
ORDER BY
|
||||||
|
score DESC,
|
||||||
|
label ASC
|
||||||
|
LIMIT $6;
|
||||||
|
"""
|
||||||
|
params = (
|
||||||
|
f"%{query_lower}%", # For the main ILIKE clause ($1)
|
||||||
|
query_lower, # For exact match ($2)
|
||||||
|
f"{query_lower}%", # For prefix match ($3)
|
||||||
|
f"% {query_lower}%", # For word boundary (space) ($4)
|
||||||
|
f"%_{query_lower}%", # For word boundary (underscore) ($5)
|
||||||
|
limit, # For LIMIT ($6)
|
||||||
|
)
|
||||||
|
results = await self._query(sql_query, params=dict(enumerate(params, 1)))
|
||||||
|
labels = [
|
||||||
|
result["label"] for result in results if result and "label" in result
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
|
||||||
|
)
|
||||||
|
return labels
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error searching labels with query '{query}': {str(e)}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage"""
|
"""Drop the storage"""
|
||||||
async with get_graph_db_lock():
|
async with get_graph_db_lock():
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue