"""TigerGraph implementation of a graph storage backend.

Connection settings come from environment variables (which take precedence)
or ``config.ini``.  Data isolation between workspaces is achieved by tagging
every vertex with the workspace name inside its ``labels`` SET attribute,
similar to the multi-label approach used by the Neo4j/Memgraph backends.
"""

import os
import re
import asyncio
from dataclasses import dataclass
from typing import final
from enum import StrEnum
import configparser
from urllib.parse import urlparse
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
import logging

from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock

import pipmaster as pm

if not pm.is_installed("pyTigerGraph"):
    pm.install("pyTigerGraph")

from pyTigerGraph import TigerGraphConnection  # type: ignore
from dotenv import load_dotenv

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

# Set pyTigerGraph logger level to ERROR to suppress warning logs
logging.getLogger("pyTigerGraph").setLevel(logging.ERROR)

# Compiled once at import time instead of per call; matches any run of
# characters in the CJK Unified Ideographs block.
_CHINESE_CHAR_RE = re.compile(r"[\u4e00-\u9fff]+")


class VertexType(StrEnum):
    """Vertex types used in TigerGraph schema."""

    ENTITY = "LR_Entity"


class EdgeType(StrEnum):
    """Edge types used in TigerGraph schema."""

    RELATES = "LR_Relates"


@dataclass
class EdgeDefinition:
    """Definition of an edge type and the vertex types it connects."""

    edge_type: EdgeType
    from_vertex: VertexType
    to_vertex: VertexType


@final
@dataclass
class TigerGraphStorage(BaseGraphStorage):
    # Schema definition: list of edges with their vertex connections.
    # This allows for extensibility - add more edge/vertex types here.
    _EDGE_DEFINITIONS: list[EdgeDefinition] | None = None

    def __init__(self, namespace, global_config, embedding_func, workspace=None):
        # Lazily initialize the class-level edge definitions exactly once.
        if TigerGraphStorage._EDGE_DEFINITIONS is None:
            TigerGraphStorage._EDGE_DEFINITIONS = [
                EdgeDefinition(
                    edge_type=EdgeType.RELATES,
                    from_vertex=VertexType.ENTITY,
                    to_vertex=VertexType.ENTITY,
                )
            ]

        # The TIGERGRAPH_WORKSPACE env var overrides the constructor argument
        tigergraph_workspace = os.environ.get("TIGERGRAPH_WORKSPACE")
        if tigergraph_workspace and tigergraph_workspace.strip():
            workspace = tigergraph_workspace

        # Default to 'base' when both arg and env are empty
        if not workspace or not str(workspace).strip():
            workspace = "base"

        super().__init__(
            namespace=namespace,
            workspace=workspace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        self._conn = None  # TigerGraphConnection, created in initialize()
        self._graph_name = None  # resolved in initialize()

    def _get_workspace_label(self) -> str:
        """Return workspace label (guaranteed non-empty after __init__)."""
        return self.workspace

    def _is_chinese_text(self, text: str) -> bool:
        """Check if text contains Chinese characters."""
        return bool(_CHINESE_CHAR_RE.search(text))

    def _parse_uri(self, uri: str) -> tuple[str, int]:
        """Parse a TigerGraph URI into ``("hostname:port", port)``.

        Fixes two defects of the previous implementation:

        * Scheme-less URIs such as ``"myhost:9000"`` are now handled: without
          a scheme, ``urlparse`` treats ``myhost`` as the *scheme* and loses
          the hostname entirely, so ``http://`` is prepended before parsing.
        * A non-numeric port component no longer raises an uncaught
          ``ValueError`` (``urlparse(...).port`` raises for invalid ports);
          the default is used instead.

        When no (valid) port is present, the default REST++ port is used:
        443 for https, 9000 otherwise.

        Returns:
            tuple: ``(host_string, port)`` where ``host_string`` is
            ``"hostname:port"`` — a single string so pyTigerGraph gets exactly
            what we parsed and no double-port issues arise.
        """
        # urlparse() mis-parses "host:port" (host becomes the scheme); make
        # sure a scheme is always present before parsing.
        if "://" not in uri:
            uri = "http://" + uri
        parsed = urlparse(uri)
        hostname = parsed.hostname or "localhost"
        default_port = 443 if parsed.scheme == "https" else 9000
        try:
            # .port raises ValueError for a non-numeric port component
            port = parsed.port or default_port
        except ValueError:
            port = default_port
        return f"{hostname}:{port}", port
    def _escape_filter_value(self, value: str) -> str:
        """Escape special characters in filter values for TigerGraph."""
        # Escape single quotes by doubling them
        return value.replace("'", "''")

    async def initialize(self) -> None:
        """
        Initialize TigerGraph connection and ensure graph and schema exist.

        Note on GRAPH_NAME vs WORKSPACE:
        - GRAPH_NAME: The TigerGraph graph instance (like a database in Neo4j).
          Defaults to sanitized namespace, can be overridden via
          TIGERGRAPH_GRAPH_NAME env var. Multiple workspaces can share the
          same graph.
        - WORKSPACE: Used as a label in the labels SET attribute for data
          isolation. All vertices use the VertexType.ENTITY.value vertex type,
          with workspace and entity_type stored as labels in the labels SET
          attribute (similar to multi-label support in Neo4j/Memgraph).
        """
        async with get_data_init_lock():
            # Environment variables take precedence over config.ini values.
            URI = os.environ.get(
                "TIGERGRAPH_URI", config.get("tigergraph", "uri", fallback=None)
            )
            USERNAME = os.environ.get(
                "TIGERGRAPH_USERNAME",
                config.get("tigergraph", "username", fallback="tigergraph"),
            )
            PASSWORD = os.environ.get(
                "TIGERGRAPH_PASSWORD",
                config.get("tigergraph", "password", fallback=""),
            )
            GRAPH_NAME = os.environ.get(
                "TIGERGRAPH_GRAPH_NAME",
                config.get(
                    "tigergraph",
                    "graph_name",
                    # Sanitize namespace into a valid graph name
                    fallback=re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace),
                ),
            )
            # GSQL port (default 14240)
            GS_PORT = os.environ.get(
                "TIGERGRAPH_GS_PORT",
                config.get("tigergraph", "gs_port", fallback="14240"),
            )
            # Convert to int if it's a string
            try:
                GS_PORT = int(GS_PORT) if isinstance(GS_PORT, str) else GS_PORT
            except ValueError:
                logger.warning(
                    f"[{self.workspace}] Invalid TIGERGRAPH_GS_PORT value '{GS_PORT}', using default 14240"
                )
                GS_PORT = 14240

            if not URI:
                raise ValueError(
                    "TIGERGRAPH_URI is required. Please set it in environment variables or config.ini"
                )

            # Parse URI to get host and port
            host, port = self._parse_uri(URI)
            self._graph_name = GRAPH_NAME

            # Initialize TigerGraph connection (synchronous)
            # First connect without graphname to check/create graph
            def _init_connection():
                """Blocking connection setup; executed via asyncio.to_thread below."""
                # Parse the original URI to get scheme and hostname.
                # NOTE(review): urlparse treats scheme-less "host:port" as
                # scheme:path — confirm TIGERGRAPH_URI always carries an
                # explicit scheme here (see _parse_uri for the same issue).
                parsed = urlparse(URI)
                scheme = parsed.scheme or "http"
                hostname = parsed.hostname or "localhost"
                # pyTigerGraph expects host as "http://hostname" (with protocol,
                # without port) and restppPort as a separate parameter
                host_url = f"{scheme}://{hostname}"

                # Connect without graphname first to check/create graph
                conn = TigerGraphConnection(
                    host=host_url,  # Full URL with protocol, no port
                    restppPort=port,  # REST++ port
                    gsPort=GS_PORT,  # GSQL port
                    username=USERNAME,
                    password=PASSWORD,
                )

                # Check if graph exists using GSQL
                try:
                    # Use GSQL to list graphs
                    result = conn.gsql("LS")
                    # Parse the result - GSQL LS command returns a string listing graphs
                    graph_exists = GRAPH_NAME in str(result)
                    if not graph_exists:
                        # Note: In TigerGraph, the graph is created implicitly when the first
                        # schema element (vertex or edge type) is created in _ensure_graph_and_schema
                        logger.info(
                            f"[{self.workspace}] Graph '{GRAPH_NAME}' does not exist. "
                            "It will be created when schema is defined."
                        )
                    else:
                        logger.debug(
                            f"[{self.workspace}] Graph '{GRAPH_NAME}' already exists."
                        )
                except Exception as e:
                    # If GSQL LS fails, try to continue - graph might be created during schema definition
                    logger.debug(
                        f"[{self.workspace}] Could not check graph existence: {str(e)}. "
                        "Will attempt to create during schema definition."
                    )

                # Now connect with graphname for operations
                # Use the same host_url and port as above
                conn = TigerGraphConnection(
                    host=host_url,  # Full URL with protocol, no port
                    restppPort=port,  # REST++ port
                    gsPort=GS_PORT,  # GSQL port
                    username=USERNAME,
                    password=PASSWORD,
                    graphname=GRAPH_NAME,
                )
                return conn

            # Run in thread pool to avoid blocking the event loop
            self._conn = await asyncio.to_thread(_init_connection)
            logger.info(
                f"[{self.workspace}] Connected to TigerGraph at {host} (graph: {GRAPH_NAME}, workspace: {self.workspace})"
            )

            # Ensure graph and schema exist
            await self._ensure_graph_and_schema()

    async def _ensure_graph_and_schema(self) -> None:
        """
        Ensure the graph exists and schema is defined with required vertex and
        edge types.

        In TigerGraph, the graph is created implicitly when the first schema
        element (vertex or edge type) is created. We'll create the schema
        which will create the graph if it doesn't exist.

        Uses VertexType.ENTITY.value as the single vertex type with
        multi-label support via the labels SET.
        """
        workspace_label = self._get_workspace_label()

        def _graph_exists(graph_name: str) -> bool:
            """Check if a graph exists using USE GRAPH command."""
            try:
                result = self._conn.gsql(f"USE GRAPH {graph_name}")
                result_str = str(result).lower()
                # If the graph doesn't exist, USE GRAPH returns an error message
                error_patterns = [
                    "does not exist",
                    "doesn't exist",
                    "doesn't exist!",
                    f"graph '{graph_name.lower()}' does not exist",
                ]
                for pattern in error_patterns:
                    if pattern in result_str:
                        return False
                return True
            except Exception as e:
                error_str = str(e).lower()
                if "does not exist" in error_str or "doesn't exist" in error_str:
                    return False
                # If exception doesn't indicate "doesn't exist", assume it exists
                return True

        def _create_graph_and_schema():
            """
            Create vertex and edge types globally, then create graph with
            those types.

            According to TigerGraph docs:
            CREATE GRAPH Graph_Name (Vertex_Or_Edge_Type, Vertex_Or_Edge_Type...)
            This creates the graph with the specified types in one command.
            Uses the generic edge definitions list to create all vertex and
            edge types.
            """
            # Collect all unique vertex types from edge definitions
            vertex_types_to_create = set()
            for edge_def in TigerGraphStorage._EDGE_DEFINITIONS:
                vertex_types_to_create.add(edge_def.from_vertex)
                vertex_types_to_create.add(edge_def.to_vertex)

            # Step 1: Create all vertex types globally (must exist before CREATE GRAPH)
            for vertex_type in vertex_types_to_create:
                # For now, all vertex types use the same schema.
                # In the future, this could be made configurable per vertex type.
                # NOTE(review): GSQL SET attributes normally require an element
                # type (e.g. SET<STRING>); "labels SET" here looks like the
                # type parameter was lost in this copy — confirm against the
                # original file / a live server before relying on it.
                gsql_create_vertex = f"""CREATE VERTEX {vertex_type.value} (
                    PRIMARY_ID entity_id STRING,
                    labels SET,
                    entity_type STRING,
                    description STRING,
                    keywords STRING,
                    source_id STRING,
                    file_path STRING,
                    created_at INT,
                    truncate STRING
                ) WITH primary_id_as_attribute="true" """
                try:
                    self._conn.gsql(gsql_create_vertex)
                    logger.info(
                        f"[{self.workspace}] Created vertex type '{vertex_type.value}'"
                    )
                except Exception as e:
                    error_str = str(e).lower()
                    if (
                        "used by another object" in error_str
                        or "already exists" in error_str
                    ):
                        logger.debug(
                            f"[{self.workspace}] Vertex type '{vertex_type.value}' already exists"
                        )
                    else:
                        logger.error(
                            f"[{self.workspace}] Failed to create vertex type: {e}"
                        )
                        raise

            # Step 2: Create all edge types globally (must exist before CREATE GRAPH)
            # Each edge explicitly references the vertex types it connects
            for edge_def in TigerGraphStorage._EDGE_DEFINITIONS:
                gsql_create_edge = f"""CREATE UNDIRECTED EDGE {edge_def.edge_type.value} (
                    FROM {edge_def.from_vertex.value},
                    TO {edge_def.to_vertex.value},
                    weight FLOAT DEFAULT 1.0,
                    description STRING,
                    keywords STRING,
                    source_id STRING,
                    file_path STRING,
                    created_at INT,
                    truncate STRING
                )
                """
                try:
                    self._conn.gsql(gsql_create_edge)
                    logger.info(
                        f"[{self.workspace}] Created edge type '{edge_def.edge_type.value}' "
                        f"(FROM {edge_def.from_vertex.value} TO {edge_def.to_vertex.value})"
                    )
                except Exception as e:
                    error_str = str(e).lower()
                    if (
                        "used by another object" in error_str
                        or "already exists" in error_str
                    ):
                        logger.debug(
                            f"[{self.workspace}] Edge type '{edge_def.edge_type.value}' already exists"
                        )
                    else:
                        logger.error(
                            f"[{self.workspace}] Failed to create edge type: {e}"
                        )
                        raise

            # Step 3: Create graph with all types (or ensure types are in existing graph)
            graph_exists = _graph_exists(self._graph_name)

            # Build list of all types for CREATE GRAPH command
            all_types = [vt.value for vt in vertex_types_to_create]
            all_types.extend(
                [ed.edge_type.value for ed in TigerGraphStorage._EDGE_DEFINITIONS]
            )
            types_str = ", ".join(all_types)

            if not graph_exists:
                # Create graph with all types in one command
                logger.info(
                    f"[{self.workspace}] Creating graph '{self._graph_name}' with types: {types_str}"
                )
                gsql_create_graph = f"CREATE GRAPH {self._graph_name} ({types_str})"
                try:
                    self._conn.gsql(gsql_create_graph)
                    logger.info(
                        f"[{self.workspace}] Created graph '{self._graph_name}' with types"
                    )
                except Exception as e:
                    logger.error(f"[{self.workspace}] Failed to create graph: {e}")
                    raise
            else:
                # Graph exists - check if types are in it, add if missing using schema change job
                logger.info(
                    f"[{self.workspace}] Graph '{self._graph_name}' exists. Checking if types are associated..."
                )
                # Check current schema
                try:
                    schema = self._conn.getSchema(force=True)
                    vertex_types = [vt["Name"] for vt in schema.get("VertexTypes", [])]
                    edge_types = [et["Name"] for et in schema.get("EdgeTypes", [])]

                    # Build list of (kind, name) pairs still missing from the graph
                    types_to_add = []
                    for vertex_type in vertex_types_to_create:
                        if vertex_type.value not in vertex_types:
                            types_to_add.append(("VERTEX", vertex_type.value))
                    for edge_def in TigerGraphStorage._EDGE_DEFINITIONS:
                        if edge_def.edge_type.value not in edge_types:
                            types_to_add.append(("EDGE", edge_def.edge_type.value))

                    if types_to_add:
                        # Use schema change job to add types to existing graph
                        job_name = f"add_types_to_{self._graph_name}"
                        # Build ADD statements with correct syntax: ADD VERTEX/EDGE ... to graph ...
                        add_statements = []
                        for type_kind, type_name in types_to_add:
                            add_statements.append(
                                f"  ADD {type_kind} {type_name} to graph {self._graph_name};"
                            )
                        gsql_schema_change = f"""CREATE GLOBAL SCHEMA_CHANGE JOB {job_name} {{{chr(10).join(add_statements)}}}
RUN GLOBAL SCHEMA_CHANGE JOB {job_name}"""
                        try:
                            # Drop job if it exists (cleanup from previous runs)
                            try:
                                self._conn.gsql(f"DROP JOB {job_name}")
                            except Exception:
                                pass  # Job doesn't exist, which is fine
                            # Create and run the schema change job
                            self._conn.gsql(gsql_schema_change)
                            logger.info(
                                f"[{self.workspace}] Added {len(types_to_add)} type(s) to graph '{self._graph_name}' "
                                f"using schema change job: {[t[1] for t in types_to_add]}"
                            )
                        except Exception as e:
                            error_str = str(e).lower()
                            if (
                                "already" in error_str
                                or "exist" in error_str
                                or "added" in error_str
                            ):
                                logger.debug(
                                    f"[{self.workspace}] Types may already be in graph: {[t[1] for t in types_to_add]}"
                                )
                            else:
                                logger.warning(
                                    f"[{self.workspace}] Could not add types to graph using schema change job: {e}"
                                )
                                # Fallback: try ALTER GRAPH for each type
                                for type_kind, type_name in types_to_add:
                                    try:
                                        if type_kind == "VERTEX":
                                            gsql_alter = (
                                                f"USE GRAPH {self._graph_name}\n"
                                                f"ALTER GRAPH {self._graph_name} ADD VERTEX {type_name}"
                                            )
                                        else:  # EDGE
                                            gsql_alter = (
                                                f"USE GRAPH {self._graph_name}\n"
                                                f"ALTER GRAPH {self._graph_name} ADD UNDIRECTED EDGE {type_name}"
                                            )
                                        self._conn.gsql(gsql_alter)
                                        logger.info(
                                            f"[{self.workspace}] Added {type_kind} '{type_name}' to graph (fallback method)"
                                        )
                                    except Exception as fallback_error:
                                        logger.warning(
                                            f"[{self.workspace}] Could not add {type_kind} '{type_name}' to graph: {fallback_error}"
                                        )
                    else:
                        logger.debug(
                            f"[{self.workspace}] All types already in graph: "
                            f"vertices={[vt.value for vt in vertex_types_to_create]}, "
                            f"edges={[ed.edge_type.value for ed in TigerGraphStorage._EDGE_DEFINITIONS]}"
                        )
                except Exception as e:
                    logger.warning(
                        f"[{self.workspace}] Could not check/add types to graph: {e}"
                    )

            # Install GSQL queries for efficient operations
            self._install_queries(workspace_label)

        await asyncio.to_thread(_create_graph_and_schema)

    def _install_queries(self, workspace_label: str) -> None:
        """Install GSQL queries for efficient graph operations.

        Best-effort: on any failure the callers fall back to traversal-based
        methods, so errors are only logged as warnings.
        """
        try:
            vertex_type = VertexType.ENTITY.value
            edge_type = EdgeType.RELATES.value

            # Query to get popular labels by degree.
            # This query counts edges per vertex and returns sorted by degree,
            # filtered by workspace label in the labels SET.
            # NOTE(review): the accumulator declarations below (MapAccum,
            # HeapAccum>, ListAccum) appear to have lost their generic type
            # parameters (e.g. MapAccum<STRING, INT>) in this copy of the
            # file — restore them from the original before installing.
            popular_labels_query = f"""
            CREATE QUERY get_popular_labels_{workspace_label}(INT limit) FOR GRAPH {self._graph_name} {{
                MapAccum @@degree_map;
                HeapAccum>(limit, f0 DESC, f1 ASC) @@top_labels;
                # Initialize all vertices with degree 0, filtered by workspace label
                Start = {{{vertex_type}.*}};
                Start = SELECT v FROM Start:v
                        WHERE v.entity_id != "" AND "{workspace_label}" IN v.labels
                        ACCUM @@degree_map += (v.entity_id -> 0);
                # Count edges (both directions for undirected graph)
                Start = SELECT v FROM Start:v - ({edge_type}:e) - {vertex_type}:t
                        WHERE v.entity_id != "" AND t.entity_id != ""
                              AND "{workspace_label}" IN v.labels
                              AND "{workspace_label}" IN t.labels
                        ACCUM @@degree_map += (v.entity_id -> 1);
                # Build heap with degree and label, sorted by degree DESC, label ASC
                Start = SELECT v FROM Start:v
                        WHERE v.entity_id != "" AND "{workspace_label}" IN v.labels
                        POST-ACCUM
                            INT degree = @@degree_map.get(v.entity_id),
                            @@top_labels += Tuple2(degree, v.entity_id);
                # Extract labels from heap (already sorted)
                ListAccum @@result;
                FOREACH item IN @@top_labels DO
                    @@result += item.f1;
                END;
                PRINT @@result;
            }}
            """

            # Query to search labels with fuzzy matching.
            # This query filters vertices by entity_id containing the search
            # query, filtered by workspace label in the labels SET.
            search_labels_query = f"""
            CREATE QUERY search_labels_{workspace_label}(STRING search_query, INT limit) FOR GRAPH {self._graph_name} {{
                ListAccum @@matches;
                STRING query_lower = lower(search_query);
                Start = {{{vertex_type}.*}};
                Start = SELECT v FROM Start:v
                        WHERE v.entity_id != ""
                              AND str_contains(lower(v.entity_id), query_lower)
                              AND "{workspace_label}" IN v.labels
                        ACCUM @@matches += v.entity_id;
                PRINT @@matches;
            }}
            """

            # Try to install queries (drop first if they exist)
            try:
                # Drop existing queries if they exist
                try:
                    self._conn.gsql(f"DROP QUERY get_popular_labels_{workspace_label}")
                except Exception:
                    pass  # Query doesn't exist, which is fine
                try:
                    self._conn.gsql(f"DROP QUERY search_labels_{workspace_label}")
                except Exception:
                    pass  # Query doesn't exist, which is fine
                # Install new queries
                self._conn.gsql(popular_labels_query)
                self._conn.gsql(search_labels_query)
                logger.info(
                    f"[{self.workspace}] Installed GSQL queries for workspace '{workspace_label}'"
                )
            except Exception as e:
                logger.warning(
                    f"[{self.workspace}] Could not install GSQL queries: {str(e)}. "
                    "Will fall back to traversal-based methods."
                )
        except Exception as e:
            logger.warning(
                f"[{self.workspace}] Error installing GSQL queries: {str(e)}. "
                "Will fall back to traversal-based methods."
            )

    async def finalize(self) -> None:
        """Close the TigerGraph connection and release all resources"""
        async with get_graph_db_lock():
            if self._conn:
                # TigerGraph connection doesn't have explicit close, but we can clear reference
                self._conn = None

    async def __aexit__(self, exc_type, exc, tb):
        """Ensure connection is closed when context manager exits"""
        await self.finalize()

    async def index_done_callback(self) -> None:
        # TigerGraph handles persistence automatically
        pass

    async def has_node(self, node_id: str) -> bool:
        """Check if a node exists in the graph.

        A node "exists" only if a vertex with this primary id carries the
        current workspace label in its labels SET.
        """
        workspace_label = self._get_workspace_label()

        def _check_node():
            try:
                # Use getVerticesById since entity_id is the PRIMARY_ID
                try:
                    result = self._conn.getVerticesById(
                        VertexType.ENTITY.value, node_id
                    )
                    if isinstance(result, dict) and node_id in result:
                        attrs = result[node_id].get("attributes", {})
                        labels = attrs.get("labels", set())
                        if isinstance(labels, set) and workspace_label in labels:
                            return True
                    return False
                except Exception:
                    # Fallback: try with filter using double quotes
                    escaped_node_id = self._escape_filter_value(node_id)
                    result = self._conn.getVertices(
                        VertexType.ENTITY.value,
                        where=f'entity_id=="{escaped_node_id}"',
                        limit=10,
                    )
                    # Filter by workspace label in labels SET
                    for vertex in result:
                        labels = vertex.get("attributes", {}).get("labels", set())
                        if isinstance(labels, set) and workspace_label in labels:
                            return True
                    return False
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
                )
                raise

        return await asyncio.to_thread(_check_node)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Check if an edge exists between two nodes.

        Both directions are probed because the edge type is undirected;
        lookup errors are treated as "no edge" (best-effort, never raises).
        """
        # workspace_label = self._get_workspace_label()

        def _check_edge():
            try:
                # Check both directions for undirected graph
                try:
                    result1 = self._conn.getEdges(
                        VertexType.ENTITY.value,
                        source_node_id,
                        EdgeType.RELATES.value,
                        VertexType.ENTITY.value,
                        target_node_id,
                        limit=1,
                    )
                    if result1 and len(result1) > 0:
                        return True
                except Exception as e1:
                    # Error code 602 means edge doesn't exist, which is fine
                    error_str = str(e1).lower()
                    if (
                        "602" not in str(e1)
                        and "does not have an edge" not in error_str
                    ):
                        logger.debug(
                            f"[{self.workspace}] Error checking edge from {source_node_id} to {target_node_id}: {str(e1)}"
                        )
                try:
                    result2 = self._conn.getEdges(
                        VertexType.ENTITY.value,
                        target_node_id,
                        EdgeType.RELATES.value,
                        VertexType.ENTITY.value,
                        source_node_id,
                        limit=1,
                    )
                    if result2 and len(result2) > 0:
                        return True
                except Exception as e2:
                    # Error code 602 means edge doesn't exist, which is fine
                    error_str = str(e2).lower()
                    if (
                        "602" not in str(e2)
                        and "does not have an edge" not in error_str
                    ):
                        logger.debug(
                            f"[{self.workspace}] Error checking edge from {target_node_id} to {source_node_id}: {str(e2)}"
                        )
                # No edge found in either direction
                return False
            except Exception as e:
                # For any other unexpected error, log and return False
                logger.debug(
                    f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                )
                return False

        return await asyncio.to_thread(_check_edge)

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get node by its entity_id, return only node properties.

        The workspace label is stripped from the returned ``labels`` list
        (mirroring the Memgraph backend); returns None when no vertex carries
        the current workspace label.
        """
        workspace_label = self._get_workspace_label()

        def _get_node():
            try:
                # Use getVerticesById since entity_id is the PRIMARY_ID
                # This avoids filter syntax issues and is more efficient
                try:
                    result = self._conn.getVerticesById(
                        VertexType.ENTITY.value, node_id
                    )
                except Exception:
                    # If getVerticesById fails, try with filter using double quotes
                    escaped_node_id = self._escape_filter_value(node_id)
                    result = self._conn.getVertices(
                        VertexType.ENTITY.value,
                        where=f'entity_id=="{escaped_node_id}"',
                        limit=10,
                    )

                # Filter by workspace label in labels SET
                # Note: TigerGraph returns labels as a list in JSON, not a set
                matching_vertices = []
                if isinstance(result, dict):
                    # getVerticesById returns a dict {vertex_id: {attributes: {...}}}
                    for vertex_id, vertex_data in result.items():
                        if vertex_id == node_id:
                            attrs = vertex_data.get("attributes", {})
                            labels = attrs.get("labels", [])
                            # Handle both list and set (list from JSON, set from Python)
                            if (
                                isinstance(labels, (list, set, tuple))
                                and workspace_label in labels
                            ):
                                matching_vertices.append({"attributes": attrs})
                elif isinstance(result, list):
                    # getVertices returns a list of vertex dicts
                    for vertex in result:
                        labels = vertex.get("attributes", {}).get("labels", [])
                        # Handle both list and set (list from JSON, set from Python)
                        if (
                            isinstance(labels, (list, set, tuple))
                            and workspace_label in labels
                        ):
                            matching_vertices.append(vertex)

                if len(matching_vertices) > 1:
                    logger.warning(
                        f"[{self.workspace}] Multiple nodes found with entity_id '{node_id}'. Using first node."
                    )
                if matching_vertices:
                    node_data = matching_vertices[0]["attributes"].copy()
                    # Convert labels to list if needed, and filter out workspace label
                    if "labels" in node_data:
                        labels = node_data["labels"]
                        if isinstance(labels, (set, tuple)):
                            labels_list = list(labels)
                        else:
                            labels_list = (
                                labels.copy() if isinstance(labels, list) else []
                            )
                        # Remove workspace label from labels list (similar to Memgraph)
                        labels_list = [
                            label for label in labels_list if label != workspace_label
                        ]
                        node_data["labels"] = labels_list
                    # Keep entity_id in the dict
                    if "entity_id" not in node_data:
                        node_data["entity_id"] = node_id
                    return node_data
                return None
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
                )
                raise

        return await asyncio.to_thread(_get_node)

    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
        """Retrieve multiple nodes in batch.

        Returns a mapping of node_id -> node attributes; ids that are not
        found (or not in this workspace) are simply omitted.
        """
        workspace_label = self._get_workspace_label()

        def _get_nodes_batch():
            nodes = {}
            try:
                # TigerGraph doesn't have native batch query, so we query in parallel
                # For now, iterate through node_ids
                for node_id in node_ids:
                    try:
                        # Use getVerticesById for primary key lookup
                        result = []
                        try:
                            vertex_result = self._conn.getVerticesById(
                                VertexType.ENTITY.value, node_id
                            )
                            # getVerticesById returns {vertex_id: {attributes: {...}}}
                            # The key might be node_id or might be formatted differently
                            if isinstance(vertex_result, dict) and vertex_result:
                                # Try to find the vertex by checking all keys
                                for vid, vdata in vertex_result.items():
                                    attrs = vdata.get("attributes", {})
                                    # Verify this is the node we're looking for by checking entity_id
                                    if (
                                        attrs.get("entity_id") == node_id
                                        or vid == node_id
                                    ):
                                        result.append({"attributes": attrs})
                                        break
                            # If getVerticesById returned empty dict or no match found, try filter
                            if not result:
                                escaped_node_id = self._escape_filter_value(node_id)
                                result = self._conn.getVertices(
                                    VertexType.ENTITY.value,
                                    where=f'entity_id=="{escaped_node_id}"',
                                    limit=10,
                                )
                        except Exception as e:
                            # Fallback to filter with double quotes if getVerticesById raises exception
                            logger.debug(
                                f"[{self.workspace}] getVerticesById failed for {node_id}, trying filter: {e}"
                            )
                            try:
                                escaped_node_id = self._escape_filter_value(node_id)
                                result = self._conn.getVertices(
                                    VertexType.ENTITY.value,
                                    where=f'entity_id=="{escaped_node_id}"',
                                    limit=10,
                                )
                            except Exception as e2:
                                logger.debug(
                                    f"[{self.workspace}] Filter also failed for {node_id}: {e2}"
                                )
                                result = []

                        # Filter by workspace label in labels SET
                        # Note: TigerGraph returns labels as a list in JSON, not a set
                        if not result:
                            logger.debug(
                                f"[{self.workspace}] No vertex found for node_id '{node_id}'"
                            )
                        else:
                            for vertex in result:
                                attrs = vertex.get("attributes", {})
                                labels = attrs.get("labels", [])
                                # Handle both list and set (list from JSON, set from Python)
                                if isinstance(labels, (list, set, tuple)):
                                    # Check if workspace label is in labels
                                    if workspace_label in labels:
                                        node_data = attrs.copy()
                                        # Convert labels to list and filter out workspace label
                                        if isinstance(labels, (set, tuple)):
                                            labels_list = list(labels)
                                        else:
                                            labels_list = labels.copy()
                                        labels_list = [
                                            label
                                            for label in labels_list
                                            if label != workspace_label
                                        ]
                                        node_data["labels"] = labels_list
                                        # Ensure entity_id is in the dict
                                        if "entity_id" not in node_data:
                                            node_data["entity_id"] = node_id
                                        nodes[node_id] = node_data
                                        break  # Found matching node, move to next
                                    else:
                                        # Debug: log when workspace label doesn't match
                                        logger.debug(
                                            f"[{self.workspace}] Node '{node_id}' found but workspace label '{workspace_label}' not in labels: {labels}"
                                        )
                                else:
                                    logger.debug(
                                        f"[{self.workspace}] Node '{node_id}' has invalid labels format: {type(labels)}, value: {labels}"
                                    )
                    except Exception as e:
                        logger.warning(
                            f"[{self.workspace}] Error getting node {node_id}: {str(e)}"
                        )
                return nodes
            except Exception as e:
                logger.error(f"[{self.workspace}] Error in batch get nodes: {str(e)}")
                raise

        return await asyncio.to_thread(_get_nodes_batch)

    async def node_degree(self, node_id: str) -> int:
        """Get the degree (number of relationships) of a node."""
        workspace_label = self._get_workspace_label()

        def _get_degree():
            try:
                # TigerGraph's getEdges doesn't support '*' as wildcard, so we
                # can't directly ask for "all edges touching this vertex".
                try:
                    # Use _get_node_edges_sync which already handles both directions
                    edges = self._get_node_edges_sync(node_id, workspace_label)
                    if edges:
                        # Count unique edges (avoid double counting for undirected graph)
                        edge_pairs = set()
                        for source, target in edges:
                            # Normalize edge direction for undirected graph
                            if source < target:
                                edge_pairs.add((source, target))
                            else:
                                edge_pairs.add((target, source))
                        return len(edge_pairs)
                    return 0
                except Exception as e:
                    # Fallback: try to get edges using a different method
                    logger.debug(
                        f"[{self.workspace}] Error getting node degree via get_node_edges for {node_id}: {str(e)}"
                    )
                    # Try direct getEdges call - but we need to know all possible targets
                    # This is inefficient, so we'll return 0 as fallback
                    return 0
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
                )
                raise

        return await asyncio.to_thread(_get_degree)

    def _get_node_edges_sync(
        self, source_node_id: str, workspace_label: str
    ) -> list[tuple[str, str]]:
        """Synchronous helper to get node edges (used by node_degree).

        Returns (source, target) pairs, de-duplicated for the undirected
        graph; best-effort — returns [] on failure rather than raising.
        """
        try:
            # TigerGraph's getEdges doesn't support '*' as wildcard.
            # Use getEdges with only source vertex (omitting target parameters)
            # to get all outgoing edges.
            edges = []
            edge_pairs = set()  # To avoid duplicates for undirected graph

            # Get outgoing edges: call getEdges with only source vertex type and ID.
            # When target parameters are omitted, it should return all edges from this vertex.
            try:
                # Try calling getEdges with minimal parameters (source only)
                result1 = self._conn.getEdges(
                    VertexType.ENTITY.value,
                    source_node_id,
                    EdgeType.RELATES.value,
                )
                if isinstance(result1, list):
                    for edge in result1:
                        # Extract target ID from edge
                        target_id = None
                        if isinstance(edge, dict):
                            target_id = edge.get("to_id") or edge.get("to")
                            if isinstance(target_id, dict):
                                target_id = target_id.get("v_id") or target_id.get("id")
                        if target_id:
                            # Normalize for undirected graph
                            pair = (
                                (source_node_id, target_id)
                                if source_node_id < target_id
                                else (target_id, source_node_id)
                            )
                            if pair not in edge_pairs:
                                edge_pairs.add(pair)
                                edges.append((source_node_id, target_id))
            except Exception as e1:
                logger.debug(
                    f"[{self.workspace}] Error getting outgoing edges for {source_node_id}: {str(e1)}"
                )

            # For incoming edges, we need a different approach.
            # Since we can't query with '*' as source, we'll use getNeighbors
            # if available, or we'll need to query all vertices (inefficient
            # but necessary).
            try:
                # Try getNeighbors method if available
                if hasattr(self._conn, "getNeighbors"):
                    neighbors = self._conn.getNeighbors(
                        VertexType.ENTITY.value, source_node_id, EdgeType.RELATES.value
                    )
                    if isinstance(neighbors, list):
                        for neighbor in neighbors:
                            target_id = None
                            if isinstance(neighbor, dict):
                                target_id = neighbor.get("v_id") or neighbor.get("id")
                            if target_id and target_id != source_node_id:
                                # Normalize for undirected graph
                                pair = (
                                    (source_node_id, target_id)
                                    if source_node_id < target_id
                                    else (target_id, source_node_id)
                                )
                                if pair not in edge_pairs:
                                    edge_pairs.add(pair)
                                    edges.append((source_node_id, target_id))
            except Exception as e2:
                logger.debug(
                    f"[{self.workspace}] Error getting neighbors for {source_node_id}: {str(e2)}"
                )

            return edges
        except Exception as e:
            logger.debug(
                f"[{self.workspace}] Error in _get_node_edges_sync for {source_node_id}: {str(e)}"
            )
            return []

    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
        """Retrieve the degree for multiple nodes in batch."""
        degrees = {}
        for node_id in node_ids:
            degree = await self.node_degree(node_id)
            degrees[node_id] = degree
        return degrees

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Get the total degree (sum of relationships) of two nodes."""
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)
        return int(src_degree) + int(trg_degree)

    async def edge_degrees_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> dict[tuple[str, str], int]:
        """Calculate the combined degree for each edge in batch."""
        # Collect unique node IDs so each degree is computed only once
        unique_node_ids = {src for src, _ in edge_pairs}
        unique_node_ids.update({tgt for _, tgt in edge_pairs})
        # Get degrees for all nodes
        degrees = await self.node_degrees_batch(list(unique_node_ids))
        # Sum up degrees for each edge pair
        edge_degrees = {}
        for src, tgt in edge_pairs:
            edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
        return edge_degrees

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        """Get edge properties between two nodes.

        Both directions are queried (undirected edge type); missing attribute
        keys are filled with defaults so callers always see a full record.
        """
        # workspace_label = self._get_workspace_label()

        def _get_edge():
            try:
                # Check both directions for undirected graph
                result1 = self._conn.getEdges(
                    VertexType.ENTITY.value,
                    source_node_id,
                    EdgeType.RELATES.value,
                    VertexType.ENTITY.value,
                    target_node_id,
                    limit=2,
                )
                result2 = self._conn.getEdges(
                    VertexType.ENTITY.value,
                    target_node_id,
                    EdgeType.RELATES.value,
                    VertexType.ENTITY.value,
                    source_node_id,
                    limit=2,
                )
                if len(result1) > 1 or len(result2) > 1:
                    logger.warning(
                        f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
                    )
                if result1:
                    edge_attrs = result1[0].get("attributes", {})
                    # Ensure required keys exist with defaults
                    required_keys = {
                        "weight": 1.0,
                        "source_id": None,
                        "description": None,
                        "keywords": None,
                        "file_path": None,
                        "created_at": None,
                        "truncate": None,
                    }
                    for key, default_value in required_keys.items():
                        if key not in edge_attrs:
                            edge_attrs[key] = default_value
                    return edge_attrs
                elif result2:
                    edge_attrs = result2[0].get("attributes", {})
                    # Ensure required keys exist with defaults
                    required_keys = {
                        "weight": 1.0,
                        "source_id": None,
                        "description": None,
                        "keywords": None,
                        "file_path": None,
                        "created_at": None,
                        "truncate": None,
                    }
                    for key, default_value in required_keys.items():
                        if key not in edge_attrs:
                            edge_attrs[key] = default_value
                    return edge_attrs
                return None
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
                )
                raise

        return await asyncio.to_thread(_get_edge)

    async def get_edges_batch(
        self, pairs: list[dict[str, str]]
    ) -> dict[tuple[str, str], dict]:
        """Retrieve edge properties for multiple (src, tgt) pairs.

        Pairs with no stored edge get a default attribute record rather than
        being omitted.
        """
        edges_dict = {}
        for pair in pairs:
            src = pair["src"]
            tgt = pair["tgt"]
            edge = await self.get_edge(src, tgt)
            if edge is not None:
                edges_dict[(src, tgt)] = edge
            else:
                # Set default edge properties
                edges_dict[(src, tgt)] = {
                    "weight": 1.0,
                    "source_id": None,
                    "description": None,
                    "keywords": None,
                    "file_path": None,
                    "created_at": None,
                    "truncate": None,
                }
        return edges_dict

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """Retrieves all edges (relationships) for a particular node.

        Returns None (not []) when the node has no edges.
        """
        workspace_label = self._get_workspace_label()

        def _get_node_edges():
            try:
                # Use the same helper method as node_degree to avoid '*' wildcard issue
                edges = self._get_node_edges_sync(source_node_id, workspace_label)
                return edges if edges else None
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
                )
                raise

        return await asyncio.to_thread(_get_node_edges)

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> dict[str, list[tuple[str, str]]]:
        """Batch retrieve edges for multiple nodes."""
        edges_dict = {}
        for node_id in node_ids:
            edges = await self.get_node_edges(node_id)
            edges_dict[node_id] = edges if edges is not None else []
        return edges_dict

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((ConnectionError, OSError, Exception)),
    )
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """Upsert a node in the TigerGraph database.

        The workspace label and the node's entity_type are merged into the
        labels SET before writing, so workspace filtering keeps working.
        """
        workspace_label = self._get_workspace_label()

        def _upsert_node():
            try:
                # Make a copy to avoid modifying the original
                node_data_copy = node_data.copy()
                # Ensure entity_id is in node_data
                if "entity_id" not in node_data_copy:
                    node_data_copy["entity_id"] = node_id
                # Ensure labels SET includes workspace and entity_type
                entity_type = node_data_copy.get("entity_type", "UNKNOWN")
                if "labels" not in node_data_copy:
                    # Create labels set with workspace and entity_type
                    labels_set = {workspace_label, entity_type}
                else:
                    # Ensure labels is a set and includes workspace and entity_type
                    if isinstance(node_data_copy["labels"], (list, tuple)):
                        labels_set = set(node_data_copy["labels"])
                    elif isinstance(node_data_copy["labels"], set):
                        labels_set = node_data_copy["labels"].copy()
                    else:
                        labels_set = {str(node_data_copy["labels"])}
                    # Add workspace and entity_type to labels
                    labels_set.add(workspace_label)
                    labels_set.add(entity_type)
                # Convert set to list for JSON serialization (TigerGraph REST API expects list for SET)
                node_data_copy["labels"] = list(labels_set)
                # Upsert vertex
                self._conn.upsertVertex(
                    VertexType.ENTITY.value, node_id, node_data_copy
                )
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error during node upsert for {node_id}: {str(e)}"
                )
                raise

        # NOTE(review): the source chunk was truncated mid-statement here
        # (it ended at "await "); reconstructed to match the pattern used by
        # every sibling method — confirm against the original file.
        await asyncio.to_thread(_upsert_node)
asyncio.to_thread(_upsert_node) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((ConnectionError, OSError, Exception)), ) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: """Upsert an edge and its properties between two nodes.""" workspace_label = self._get_workspace_label() def _upsert_edge(): try: # Ensure both nodes exist first # Check if source node exists using getVerticesById source_exists = False try: source_result = self._conn.getVerticesById( VertexType.ENTITY.value, source_node_id ) if ( isinstance(source_result, dict) and source_node_id in source_result ): attrs = source_result[source_node_id].get("attributes", {}) labels = attrs.get("labels", set()) if isinstance(labels, set) and workspace_label in labels: source_exists = True except Exception: pass # Node doesn't exist if not source_exists: # Create source node with minimal data and labels self._conn.upsertVertex( VertexType.ENTITY.value, source_node_id, { "entity_id": source_node_id, "labels": list( {workspace_label, "UNKNOWN"} ), # Convert to list for JSON "entity_type": "UNKNOWN", }, ) # Check if target node exists using getVerticesById target_exists = False try: target_result = self._conn.getVerticesById( VertexType.ENTITY.value, target_node_id ) if ( isinstance(target_result, dict) and target_node_id in target_result ): attrs = target_result[target_node_id].get("attributes", {}) labels = attrs.get("labels", set()) if isinstance(labels, set) and workspace_label in labels: target_exists = True except Exception: pass # Node doesn't exist if not target_exists: # Create target node with minimal data and labels self._conn.upsertVertex( VertexType.ENTITY.value, target_node_id, { "entity_id": target_node_id, "labels": list( {workspace_label, "UNKNOWN"} ), # Convert to list for JSON "entity_type": "UNKNOWN", }, ) # Upsert edge (undirected, so direction doesn't matter) 
self._conn.upsertEdge( VertexType.ENTITY.value, source_node_id, EdgeType.RELATES.value, VertexType.ENTITY.value, target_node_id, edge_data, ) except Exception as e: logger.error( f"[{self.workspace}] Error during edge upsert between {source_node_id} and {target_node_id}: {str(e)}" ) raise await asyncio.to_thread(_upsert_edge) async def get_knowledge_graph( self, node_label: str, max_depth: int = 3, max_nodes: int = None, ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. """ # Get max_nodes from global_config if not provided if max_nodes is None: max_nodes = self.global_config.get("max_graph_nodes", 1000) else: max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) workspace_label = self._get_workspace_label() result = KnowledgeGraph() def _get_knowledge_graph(): try: if node_label == "*": # Get all nodes sorted by degree, filtered by workspace label # TigerGraph REST API doesn't support IN operator for SET attributes # So we fetch all and filter in Python all_vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, limit=max_nodes * 2, # Fetch more to account for filtering ) # Filter by workspace label in labels SET all_vertices = [] for vertex in all_vertices_raw: labels = vertex.get("attributes", {}).get("labels", set()) if isinstance(labels, set) and workspace_label in labels: all_vertices.append(vertex) if len(all_vertices) >= max_nodes: break # For simplicity, take first max_nodes vertices # In a real implementation, you'd want to sort by degree vertices = all_vertices[:max_nodes] if len(all_vertices) > max_nodes: result.is_truncated = True # Build node and edge sets node_ids = [v["attributes"].get("entity_id") for v in vertices] node_ids = [nid for nid in node_ids if nid] # Get all edges between these nodes edges_data = [] for node_id in node_ids: try: node_edges = self._conn.getEdges( VertexType.ENTITY.value, node_id, EdgeType.RELATES.value, 
VertexType.ENTITY.value, "*", limit=10000, ) for edge in node_edges: target_id = edge.get("to_id") if target_id in node_ids: edges_data.append(edge) except Exception: continue # Build result for vertex in vertices: attrs = vertex.get("attributes", {}) entity_id = attrs.get("entity_id") if entity_id: result.nodes.append( KnowledgeGraphNode( id=entity_id, labels=[entity_id], properties=attrs, ) ) edge_ids_seen = set() for edge in edges_data: source_id = edge.get("from_id") target_id = edge.get("to_id") if source_id and target_id: edge_tuple = tuple(sorted([source_id, target_id])) if edge_tuple not in edge_ids_seen: edge_attrs = edge.get("attributes", {}) result.edges.append( KnowledgeGraphEdge( id=f"{source_id}-{target_id}", type=EdgeType.RELATES.value, source=source_id, target=target_id, properties=edge_attrs, ) ) edge_ids_seen.add(edge_tuple) else: # BFS traversal starting from node_label from collections import deque visited_nodes = set() visited_edges = set() queue = deque([(node_label, 0)]) while queue and len(visited_nodes) < max_nodes: current_id, depth = queue.popleft() if current_id in visited_nodes or depth > max_depth: continue # Get node (filter by workspace label in Python) try: # Use getVerticesById for primary key lookup try: vertex_result = self._conn.getVerticesById( VertexType.ENTITY.value, current_id ) vertices_raw = [] if ( isinstance(vertex_result, dict) and current_id in vertex_result ): vertices_raw.append( { "attributes": vertex_result[current_id].get( "attributes", {} ) } ) except Exception: # Fallback to filter with double quotes escaped_current_id = self._escape_filter_value( current_id ) vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, where=f'entity_id=="{escaped_current_id}"', limit=10, ) # Filter by workspace label in labels SET vertices = [] for v in vertices_raw: labels = v.get("attributes", {}).get("labels", set()) if ( isinstance(labels, set) and workspace_label in labels ): vertices.append(v) if not vertices: 
continue vertex = vertices[0] attrs = vertex.get("attributes", {}) result.nodes.append( KnowledgeGraphNode( id=current_id, labels=[current_id], properties=attrs, ) ) visited_nodes.add(current_id) if depth < max_depth: # Get neighbors edges = self._conn.getEdges( VertexType.ENTITY.value, current_id, EdgeType.RELATES.value, VertexType.ENTITY.value, "*", limit=10000, ) for edge in edges: target_id = edge.get("to_id") if target_id and target_id not in visited_nodes: edge_tuple = tuple( sorted([current_id, target_id]) ) if edge_tuple not in visited_edges: edge_attrs = edge.get("attributes", {}) result.edges.append( KnowledgeGraphEdge( id=f"{current_id}-{target_id}", type=EdgeType.RELATES.value, source=current_id, target=target_id, properties=edge_attrs, ) ) visited_edges.add(edge_tuple) queue.append((target_id, depth + 1)) except Exception as e: logger.warning( f"[{self.workspace}] Error in BFS traversal for {current_id}: {str(e)}" ) continue if len(visited_nodes) >= max_nodes: result.is_truncated = True return result except Exception as e: logger.error( f"[{self.workspace}] Error in get_knowledge_graph: {str(e)}" ) raise result = await asyncio.to_thread(_get_knowledge_graph) logger.info( f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) return result async def get_all_labels(self) -> list[str]: """Get all existing node labels in the database.""" workspace_label = self._get_workspace_label() def _get_all_labels(): try: # TigerGraph REST API doesn't support IN operator for SET attributes # So we fetch all and filter in Python vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, limit=100000 ) # Filter by workspace label in labels SET vertices = [] for vertex in vertices_raw: labels = vertex.get("attributes", {}).get("labels", set()) if isinstance(labels, set) and workspace_label in labels: vertices.append(vertex) labels = set() for vertex in vertices: entity_id = vertex.get("attributes", 
{}).get("entity_id") if entity_id: labels.add(entity_id) return sorted(list(labels)) except Exception as e: logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}") raise return await asyncio.to_thread(_get_all_labels) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((ConnectionError, OSError, Exception)), ) async def delete_node(self, node_id: str) -> None: """Delete a node with the specified entity_id.""" workspace_label = self._get_workspace_label() def _delete_node(): try: # Check if node exists with workspace label first, then delete try: result = self._conn.getVerticesById( VertexType.ENTITY.value, node_id ) if isinstance(result, dict) and node_id in result: attrs = result[node_id].get("attributes", {}) labels = attrs.get("labels", set()) if isinstance(labels, set) and workspace_label in labels: # Delete this specific vertex escaped_node_id = self._escape_filter_value(node_id) self._conn.delVertices( VertexType.ENTITY.value, where=f'entity_id=="{escaped_node_id}"', ) except Exception: # Node doesn't exist or error occurred pass logger.debug( f"[{self.workspace}] Deleted node with entity_id '{node_id}'" ) except Exception as e: logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}") raise await asyncio.to_thread(_delete_node) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((ConnectionError, OSError, Exception)), ) async def remove_nodes(self, nodes: list[str]): """Delete multiple nodes.""" for node in nodes: await self.delete_node(node) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((ConnectionError, OSError, Exception)), ) async def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges.""" # workspace_label = self._get_workspace_label() def _delete_edge(source, target): try: # Delete edge in both 
directions self._conn.delEdges( VertexType.ENTITY.value, source, EdgeType.RELATES.value, VertexType.ENTITY.value, target, ) except Exception as e: logger.warning( f"[{self.workspace}] Error deleting edge from '{source}' to '{target}': {str(e)}" ) for source, target in edges: await asyncio.to_thread(_delete_edge, source, target) async def get_all_nodes(self) -> list[dict]: """Get all nodes in the graph.""" workspace_label = self._get_workspace_label() def _get_all_nodes(): try: # TigerGraph REST API doesn't support IN operator for SET attributes # So we fetch all and filter in Python vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, limit=100000 ) # Filter by workspace label in labels SET vertices = [] for vertex in vertices_raw: labels = vertex.get("attributes", {}).get("labels", set()) if isinstance(labels, set) and workspace_label in labels: vertices.append(vertex) nodes = [] for vertex in vertices: attrs = vertex.get("attributes", {}) attrs["id"] = attrs.get("entity_id") # Convert labels SET to list and filter out workspace label if "labels" in attrs and isinstance(attrs["labels"], set): attrs["labels"] = [ label for label in attrs["labels"] if label != workspace_label ] nodes.append(attrs) return nodes except Exception as e: logger.error(f"[{self.workspace}] Error getting all nodes: {str(e)}") raise return await asyncio.to_thread(_get_all_nodes) async def get_all_edges(self) -> list[dict]: """Get all edges in the graph.""" workspace_label = self._get_workspace_label() def _get_all_edges(): try: # TigerGraph REST API doesn't support IN operator for SET attributes # So we fetch all and filter in Python vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, limit=100000 ) # Filter by workspace label in labels SET vertices = [] for vertex in vertices_raw: labels = vertex.get("attributes", {}).get("labels", set()) if isinstance(labels, set) and workspace_label in labels: vertices.append(vertex) edges = [] processed_edges = set() for vertex in 
vertices: source_id = vertex.get("attributes", {}).get("entity_id") if not source_id: continue try: vertex_edges = self._conn.getEdges( VertexType.ENTITY.value, source_id, EdgeType.RELATES.value, VertexType.ENTITY.value, "*", limit=10000, ) for edge in vertex_edges: target_id = edge.get("to_id") edge_tuple = tuple(sorted([source_id, target_id])) if edge_tuple not in processed_edges: edge_attrs = edge.get("attributes", {}) edge_attrs["source"] = source_id edge_attrs["target"] = target_id edges.append(edge_attrs) processed_edges.add(edge_tuple) except Exception: continue return edges except Exception as e: logger.error(f"[{self.workspace}] Error getting all edges: {str(e)}") raise return await asyncio.to_thread(_get_all_edges) async def get_popular_labels(self, limit: int = 300) -> list[str]: """Get popular labels by node degree (most connected entities).""" workspace_label = self._get_workspace_label() def _get_popular_labels(): try: # Try to use installed GSQL query first query_name = f"get_popular_labels_{workspace_label}" try: result = self._conn.runInstalledQuery( query_name, params={"limit": limit} ) if result and len(result) > 0: # Extract labels from query result # Result format: [{"@@result": ["label1", "label2", ...]}] labels = [] for record in result: if "@@result" in record: labels.extend(record["@@result"]) # GSQL query already returns sorted labels (by degree DESC, label ASC) # Just return the limited results if labels: return labels[:limit] except Exception as query_error: logger.debug( f"[{self.workspace}] GSQL query '{query_name}' not available or failed: {str(query_error)}. " "Falling back to traversal method." 
) # Fallback to traversal method if GSQL query fails # TigerGraph REST API doesn't support IN operator for SET attributes # So we fetch all and filter in Python vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, limit=100000 ) # Filter by workspace label in labels SET vertices = [] for vertex in vertices_raw: labels = vertex.get("attributes", {}).get("labels", set()) if isinstance(labels, set) and workspace_label in labels: vertices.append(vertex) node_degrees = {} for vertex in vertices: entity_id = vertex.get("attributes", {}).get("entity_id") if not entity_id: continue # Calculate degree try: edges = self._conn.getEdges( VertexType.ENTITY.value, entity_id, EdgeType.RELATES.value, VertexType.ENTITY.value, "*", limit=10000, ) # Count unique neighbors neighbors = set() for edge in edges: target_id = edge.get("to_id") if target_id: neighbors.add(target_id) node_degrees[entity_id] = len(neighbors) except Exception: node_degrees[entity_id] = 0 # Sort by degree descending, then by label ascending sorted_labels = sorted( node_degrees.items(), key=lambda x: (-x[1], x[0]), )[:limit] return [label for label, _ in sorted_labels] except Exception as e: logger.error( f"[{self.workspace}] Error getting popular labels: {str(e)}" ) raise return await asyncio.to_thread(_get_popular_labels) async def search_labels(self, query: str, limit: int = 50) -> list[str]: """Search labels with fuzzy matching.""" workspace_label = self._get_workspace_label() query_strip = query.strip() if not query_strip: return [] query_lower = query_strip.lower() is_chinese = self._is_chinese_text(query_strip) def _search_labels(): try: # Try to use installed GSQL query first query_name = f"search_labels_{workspace_label}" try: result = self._conn.runInstalledQuery( query_name, params={"search_query": query_strip, "limit": limit} ) if result and len(result) > 0: # Extract labels from query result labels = [] for record in result: if "@@matches" in record: labels.extend(record["@@matches"]) if 
labels: # GSQL query does basic filtering, we still need to score and sort # Score the results (exact match, prefix match, contains match) matches = [] for entity_id_str in labels: if is_chinese: # For Chinese, use direct contains if query_strip not in entity_id_str: continue # Calculate relevance score if entity_id_str == query_strip: score = 1000 elif entity_id_str.startswith(query_strip): score = 500 else: score = 100 - len(entity_id_str) else: # For non-Chinese, use case-insensitive contains entity_id_lower = entity_id_str.lower() if query_lower not in entity_id_lower: continue # Calculate relevance score if entity_id_lower == query_lower: score = 1000 elif entity_id_lower.startswith(query_lower): score = 500 else: score = 100 - len(entity_id_str) # Bonus for word boundary matches if ( f" {query_lower}" in entity_id_lower or f"_{query_lower}" in entity_id_lower ): score += 50 matches.append((entity_id_str, score)) # Sort by relevance score (desc) then alphabetically matches.sort(key=lambda x: (-x[1], x[0])) # Return top matches return [match[0] for match in matches[:limit]] except Exception as query_error: logger.debug( f"[{self.workspace}] GSQL query '{query_name}' not available or failed: {str(query_error)}. " "Falling back to traversal method." 
) # Fallback to traversal method if GSQL query fails # TigerGraph REST API doesn't support IN operator for SET attributes # So we fetch all and filter in Python vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, limit=100000 ) # Filter by workspace label in labels SET vertices = [] for vertex in vertices_raw: labels = vertex.get("attributes", {}).get("labels", set()) if isinstance(labels, set) and workspace_label in labels: vertices.append(vertex) matches = [] for vertex in vertices: entity_id = vertex.get("attributes", {}).get("entity_id") if not entity_id: continue entity_id_str = str(entity_id) if is_chinese: # For Chinese, use direct contains if query_strip not in entity_id_str: continue # Calculate relevance score if entity_id_str == query_strip: score = 1000 elif entity_id_str.startswith(query_strip): score = 500 else: score = 100 - len(entity_id_str) else: # For non-Chinese, use case-insensitive contains entity_id_lower = entity_id_str.lower() if query_lower not in entity_id_lower: continue # Calculate relevance score if entity_id_lower == query_lower: score = 1000 elif entity_id_lower.startswith(query_lower): score = 500 else: score = 100 - len(entity_id_str) # Bonus for word boundary matches if ( f" {query_lower}" in entity_id_lower or f"_{query_lower}" in entity_id_lower ): score += 50 matches.append((entity_id_str, score)) # Sort by relevance score (desc) then alphabetically matches.sort(key=lambda x: (-x[1], x[0])) # Return top matches return [match[0] for match in matches[:limit]] except Exception as e: logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") raise return await asyncio.to_thread(_search_labels) async def drop(self) -> dict[str, str]: """Drop all data from current workspace storage and clean up resources.""" async with get_graph_db_lock(): workspace_label = self._get_workspace_label() try: def _drop(): # TigerGraph REST API doesn't support IN operator for SET attributes # So we fetch all vertices, filter by 
workspace label, and delete them vertices_raw = self._conn.getVertices( VertexType.ENTITY.value, limit=100000 ) # Filter by workspace label and collect entity_ids to delete entity_ids_to_delete = [] for vertex in vertices_raw: labels = vertex.get("attributes", {}).get("labels", set()) if isinstance(labels, set) and workspace_label in labels: entity_id = vertex.get("attributes", {}).get("entity_id") if entity_id: entity_ids_to_delete.append(entity_id) # Delete vertices by entity_id for entity_id in entity_ids_to_delete: try: escaped_entity_id = self._escape_filter_value(entity_id) self._conn.delVertices( VertexType.ENTITY.value, where=f'entity_id=="{escaped_entity_id}"', ) except Exception as e: logger.warning(f"Could not delete vertex {entity_id}: {e}") await asyncio.to_thread(_drop) return { "status": "success", "message": f"workspace '{workspace_label}' data dropped", } except Exception as e: logger.error( f"[{self.workspace}] Error dropping TigerGraph workspace '{workspace_label}': {e}" ) return {"status": "error", "message": str(e)}