def _install_queries(self, workspace_label: str):
    """Create and install the GSQL helper queries for one workspace.

    Two queries are registered on ``self._graph_name``:

    * ``get_popular_labels_<workspace_label>(INT limit)`` prints
      ``@@result``: entity ids ordered by degree (descending), then by
      label (ascending), capped at ``limit`` entries.
    * ``search_labels_<workspace_label>(STRING search_query, INT limit)``
      prints ``@@matches``: entity ids containing ``search_query``
      case-insensitively. ``limit`` is accepted but not applied inside
      the query; callers rank and truncate the matches themselves.

    This is best effort. Any failure is logged as a warning and swallowed
    so that callers can fall back to traversal-based methods.

    Args:
        workspace_label: Vertex type of the workspace. It is also used as
            the suffix of both query names.
    """
    try:
        popular_name = f"get_popular_labels_{workspace_label}"
        search_name = f"search_labels_{workspace_label}"

        # NOTE(review): LIMIT is a GSQL keyword, so a parameter named
        # `limit` may be rejected by the parser. It is left unchanged here
        # because callers pass params={"limit": ...} -- confirm against
        # the target TigerGraph version.
        popular_labels_query = f"""
        CREATE QUERY {popular_name}(INT limit) FOR GRAPH {self._graph_name} {{
            TYPEDEF TUPLE<INT f0, STRING f1> Tuple2;
            MapAccum<STRING, SumAccum<INT>> @@degree_map;
            HeapAccum<Tuple2>(limit, f0 DESC, f1 ASC) @@top_labels;
            ListAccum<STRING> @@result;

            # Seed every labelled vertex with degree 0
            AllNodes = {{{workspace_label}}};
            AllNodes = SELECT v FROM AllNodes:v
                       WHERE v.entity_id != ""
                       ACCUM @@degree_map += (v.entity_id -> 0);

            # Count edges. The result goes into a separate vertex set so
            # that isolated vertices stay in AllNodes for the ranking step.
            Connected = SELECT v FROM AllNodes:v - (DIRECTED:e) - {workspace_label}:t
                        WHERE v.entity_id != "" AND t.entity_id != ""
                        ACCUM @@degree_map += (v.entity_id -> 1);

            # Build heap sorted by degree DESC, label ASC
            AllNodes = SELECT v FROM AllNodes:v
                       WHERE v.entity_id != ""
                       POST-ACCUM
                           INT degree = @@degree_map.get(v.entity_id),
                           @@top_labels += Tuple2(degree, v.entity_id);

            # Extract labels from heap (already sorted)
            FOREACH item IN @@top_labels DO
                @@result += item.f1;
            END;

            PRINT @@result;
        }}
        """

        # NOTE(review): `str_contains` is kept from the previous version;
        # verify that it is available on the target TigerGraph release.
        search_labels_query = f"""
        CREATE QUERY {search_name}(STRING search_query, INT limit) FOR GRAPH {self._graph_name} {{
            ListAccum<STRING> @@matches;
            STRING query_lower = lower(search_query);

            Start = {{{workspace_label}}};
            Start = SELECT v FROM Start:v
                    WHERE v.entity_id != "" AND str_contains(lower(v.entity_id), query_lower)
                    ACCUM @@matches += v.entity_id;

            PRINT @@matches;
        }}
        """

        # Drop stale definitions first. A failure here normally just means
        # the query did not exist yet, so it is logged at debug level only.
        for query_name in (popular_name, search_name):
            try:
                self._conn.gsql(f"DROP QUERY {query_name}")
            except Exception as drop_error:
                logger.debug(
                    f"[{self.workspace}] Could not drop query '{query_name}' "
                    f"(probably not present): {str(drop_error)}"
                )

        self._conn.gsql(popular_labels_query)
        self._conn.gsql(search_labels_query)

        # CREATE QUERY only registers the definition. runInstalledQuery()
        # needs the query compiled and installed as well.
        self._conn.gsql(f"INSTALL QUERY {popular_name}, {search_name}")

        logger.info(
            f"[{self.workspace}] Installed GSQL queries for workspace '{workspace_label}'"
        )
    except Exception as e:
        logger.warning(
            f"[{self.workspace}] Could not install GSQL queries: {str(e)}. "
            "Will fall back to traversal-based methods."
        )
query_strip: + score = 1000 + elif entity_id_str.startswith(query_strip): + score = 500 + else: + score = 100 - len(entity_id_str) + else: + # For non-Chinese, use case-insensitive contains + entity_id_lower = entity_id_str.lower() + if query_lower not in entity_id_lower: + continue + # Calculate relevance score + if entity_id_lower == query_lower: + score = 1000 + elif entity_id_lower.startswith(query_lower): + score = 500 + else: + score = 100 - len(entity_id_str) + # Bonus for word boundary matches + if ( + f" {query_lower}" in entity_id_lower + or f"_{query_lower}" in entity_id_lower + ): + score += 50 + + matches.append((entity_id_str, score)) + + # Sort by relevance score (desc) then alphabetically + matches.sort(key=lambda x: (-x[1], x[0])) + + # Return top matches + return [match[0] for match in matches[:limit]] + except Exception as query_error: + logger.debug( + f"[{self.workspace}] GSQL query '{query_name}' not available or failed: {str(query_error)}. " + "Falling back to traversal method." + ) + + # Fallback to traversal method if GSQL query fails # Get all vertices and filter vertices = self._conn.getVertices(workspace_label, limit=100000) matches = []