updated fetch with db side processing

This commit is contained in:
Alexander Belikov 2025-11-07 14:16:53 +01:00
parent 17faf065e5
commit af33abee40

View file

@ -202,8 +202,95 @@ class TigerGraphStorage(BaseGraphStorage):
f"[{self.workspace}] Could not create edge type '{edge_type}': {str(e)}"
)
# Install GSQL queries for efficient operations
self._install_queries(workspace_label)
await asyncio.to_thread(_create_schema)
def _install_queries(self, workspace_label: str):
"""Install GSQL queries for efficient graph operations."""
try:
# Query to get popular labels by degree
# This query counts edges per vertex and returns sorted by degree
popular_labels_query = f"""
CREATE QUERY get_popular_labels_{workspace_label}(INT limit) FOR GRAPH {self._graph_name} {{
MapAccum<STRING, INT> @@degree_map;
HeapAccum<Tuple2<INT, STRING>>(limit, f0 DESC, f1 ASC) @@top_labels;
# Initialize all vertices with degree 0
Start = {{{workspace_label}}};
Start = SELECT v FROM Start:v
WHERE v.entity_id != ""
ACCUM @@degree_map += (v.entity_id -> 0);
# Count edges (both directions for undirected graph)
Start = SELECT v FROM Start:v - (DIRECTED:e) - {workspace_label}:t
WHERE v.entity_id != "" AND t.entity_id != ""
ACCUM @@degree_map += (v.entity_id -> 1);
# Build heap with degree and label, sorted by degree DESC, label ASC
Start = SELECT v FROM Start:v
WHERE v.entity_id != ""
POST-ACCUM
INT degree = @@degree_map.get(v.entity_id),
@@top_labels += Tuple2(degree, v.entity_id);
# Extract labels from heap (already sorted)
ListAccum<STRING> @@result;
FOREACH item IN @@top_labels DO
@@result += item.f1;
END;
PRINT @@result;
}}
"""
# Query to search labels with fuzzy matching
# This query filters vertices by entity_id containing the search query
search_labels_query = f"""
CREATE QUERY search_labels_{workspace_label}(STRING search_query, INT limit) FOR GRAPH {self._graph_name} {{
ListAccum<STRING> @@matches;
STRING query_lower = lower(search_query);
Start = {{{workspace_label}}};
Start = SELECT v FROM Start:v
WHERE v.entity_id != "" AND str_contains(lower(v.entity_id), query_lower)
ACCUM @@matches += v.entity_id;
PRINT @@matches;
}}
"""
# Try to install queries (drop first if they exist)
try:
# Drop existing queries if they exist
try:
self._conn.gsql(f"DROP QUERY get_popular_labels_{workspace_label}")
except Exception:
pass # Query doesn't exist, which is fine
try:
self._conn.gsql(f"DROP QUERY search_labels_{workspace_label}")
except Exception:
pass # Query doesn't exist, which is fine
# Install new queries
self._conn.gsql(popular_labels_query)
self._conn.gsql(search_labels_query)
logger.info(
f"[{self.workspace}] Installed GSQL queries for workspace '{workspace_label}'"
)
except Exception as e:
logger.warning(
f"[{self.workspace}] Could not install GSQL queries: {str(e)}. "
"Will fall back to traversal-based methods."
)
except Exception as e:
logger.warning(
f"[{self.workspace}] Error installing GSQL queries: {str(e)}. "
"Will fall back to traversal-based methods."
)
async def finalize(self):
"""Close the TigerGraph connection and release all resources"""
async with get_graph_db_lock():
@ -1028,6 +1115,31 @@ class TigerGraphStorage(BaseGraphStorage):
def _get_popular_labels():
try:
# Try to use installed GSQL query first
query_name = f"get_popular_labels_{workspace_label}"
try:
result = self._conn.runInstalledQuery(
query_name, params={"limit": limit}
)
if result and len(result) > 0:
# Extract labels from query result
# Result format: [{"@@result": ["label1", "label2", ...]}]
labels = []
for record in result:
if "@@result" in record:
labels.extend(record["@@result"])
# GSQL query already returns sorted labels (by degree DESC, label ASC)
# Just return the limited results
if labels:
return labels[:limit]
except Exception as query_error:
logger.debug(
f"[{self.workspace}] GSQL query '{query_name}' not available or failed: {str(query_error)}. "
"Falling back to traversal method."
)
# Fallback to traversal method if GSQL query fails
# Get all vertices and calculate degrees
vertices = self._conn.getVertices(workspace_label, limit=100000)
node_degrees = {}
@ -1084,6 +1196,68 @@ class TigerGraphStorage(BaseGraphStorage):
def _search_labels():
try:
# Try to use installed GSQL query first
query_name = f"search_labels_{workspace_label}"
try:
result = self._conn.runInstalledQuery(
query_name, params={"search_query": query_strip, "limit": limit}
)
if result and len(result) > 0:
# Extract labels from query result
labels = []
for record in result:
if "@@matches" in record:
labels.extend(record["@@matches"])
if labels:
# GSQL query does basic filtering, we still need to score and sort
# Score the results (exact match, prefix match, contains match)
matches = []
for entity_id_str in labels:
if is_chinese:
# For Chinese, use direct contains
if query_strip not in entity_id_str:
continue
# Calculate relevance score
if entity_id_str == query_strip:
score = 1000
elif entity_id_str.startswith(query_strip):
score = 500
else:
score = 100 - len(entity_id_str)
else:
# For non-Chinese, use case-insensitive contains
entity_id_lower = entity_id_str.lower()
if query_lower not in entity_id_lower:
continue
# Calculate relevance score
if entity_id_lower == query_lower:
score = 1000
elif entity_id_lower.startswith(query_lower):
score = 500
else:
score = 100 - len(entity_id_str)
# Bonus for word boundary matches
if (
f" {query_lower}" in entity_id_lower
or f"_{query_lower}" in entity_id_lower
):
score += 50
matches.append((entity_id_str, score))
# Sort by relevance score (desc) then alphabetically
matches.sort(key=lambda x: (-x[1], x[0]))
# Return top matches
return [match[0] for match in matches[:limit]]
except Exception as query_error:
logger.debug(
f"[{self.workspace}] GSQL query '{query_name}' not available or failed: {str(query_error)}. "
"Falling back to traversal method."
)
# Fallback to traversal method if GSQL query fails
# Get all vertices and filter
vertices = self._conn.getVertices(workspace_label, limit=100000)
matches = []