diff --git a/config.ini.example b/config.ini.example index d3fff063..8a0a8bc8 100644 --- a/config.ini.example +++ b/config.ini.example @@ -35,3 +35,9 @@ ivfflat_lists = 100 [memgraph] uri = bolt://localhost:7687 + +[tigergraph] +uri = http://localhost:9000 +username = tigergraph +password = your_password +graph_name = lightrag diff --git a/env.example b/env.example index 43bc759b..b68280d2 100644 --- a/env.example +++ b/env.example @@ -395,6 +395,13 @@ MEMGRAPH_PASSWORD= MEMGRAPH_DATABASE=memgraph # MEMGRAPH_WORKSPACE=forced_workspace_name +### TigerGraph Configuration +TIGERGRAPH_URI=http://localhost:9000 +TIGERGRAPH_USERNAME=tigergraph +TIGERGRAPH_PASSWORD='your_password' +TIGERGRAPH_GRAPH_NAME=lightrag +# TIGERGRAPH_WORKSPACE=forced_workspace_name + ############################ ### Evaluation Configuration ############################ diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index 8d42441a..42a32042 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = { "PGGraphStorage", "MongoGraphStorage", "MemgraphStorage", + "TigerGraphStorage", ], "required_methods": ["upsert_node", "upsert_edge"], }, @@ -53,6 +54,11 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "MongoGraphStorage": [], "MemgraphStorage": ["MEMGRAPH_URI"], + "TigerGraphStorage": [ + "TIGERGRAPH_URI", + "TIGERGRAPH_USERNAME", + "TIGERGRAPH_PASSWORD", + ], "AGEStorage": [ "AGE_POSTGRES_DB", "AGE_POSTGRES_USER", @@ -101,6 +107,7 @@ STORAGES = { "FaissVectorDBStorage": ".kg.faiss_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl", "MemgraphStorage": ".kg.memgraph_impl", + "TigerGraphStorage": ".kg.tigergraph_impl", } diff --git a/lightrag/kg/tigergraph_impl.py b/lightrag/kg/tigergraph_impl.py new file mode 100644 index 00000000..aef515e5 --- /dev/null +++ b/lightrag/kg/tigergraph_impl.py @@ -0,0 +1,1159 @@ +import os +import re +import asyncio +from dataclasses import dataclass +from typing import final +import configparser +from urllib.parse import urlparse + +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) + +import logging +from ..utils import logger +from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP +from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock +import pipmaster as pm + +if not pm.is_installed("pyTigerGraph"): + pm.install("pyTigerGraph") + +from pyTigerGraph import TigerGraphConnection # type: ignore + +from dotenv import load_dotenv + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) + +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + + +# Set pyTigerGraph logger level to ERROR to suppress warning logs +logging.getLogger("pyTigerGraph").setLevel(logging.ERROR) + + +@final +@dataclass +class TigerGraphStorage(BaseGraphStorage): + def __init__(self, namespace, global_config, embedding_func, workspace=None): + # Read env and override the arg if present + tigergraph_workspace = os.environ.get("TIGERGRAPH_WORKSPACE") + if tigergraph_workspace and tigergraph_workspace.strip(): + workspace = tigergraph_workspace + + # Default to 'base' when both arg and env are empty + if not workspace or not str(workspace).strip(): + workspace = "base" + + super().__init__( + namespace=namespace, + workspace=workspace, + global_config=global_config, + embedding_func=embedding_func, + ) + self._conn = None + self._graph_name = None + + def _get_workspace_label(self) -> str: + """Return workspace label (guaranteed non-empty during initialization)""" + return self.workspace + + def _is_chinese_text(self, text: str) -> bool: + """Check if text contains Chinese characters.""" + chinese_pattern = re.compile(r"[\u4e00-\u9fff]+") + return bool(chinese_pattern.search(text)) + + def _parse_uri(self, uri: str) -> tuple[str, int]: + """Parse URI to extract host and port.""" + parsed = urlparse(uri) + host = parsed.hostname or "localhost" + port = parsed.port or (443 if parsed.scheme == "https" else 80) + # Construct full URL with scheme + if not parsed.scheme: + scheme = "http" + else: + scheme = parsed.scheme + full_host = f"{scheme}://{host}:{port}" + return full_host, port + + async def initialize(self): + async with get_data_init_lock(): + URI = os.environ.get( + "TIGERGRAPH_URI", config.get("tigergraph", "uri", fallback=None) + ) + USERNAME = os.environ.get( + "TIGERGRAPH_USERNAME", + config.get("tigergraph", "username", fallback="tigergraph"), + ) + PASSWORD = os.environ.get( + "TIGERGRAPH_PASSWORD", + config.get("tigergraph", "password", fallback=""), + ) + GRAPH_NAME = os.environ.get( + "TIGERGRAPH_GRAPH_NAME", + config.get( + "tigergraph", + "graph_name", + fallback=re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace), + ), + ) + + if not URI: + raise ValueError( + "TIGERGRAPH_URI is required. Please set it in environment variables or config.ini" + ) + + # Parse URI to get host and port + host, port = self._parse_uri(URI) + self._graph_name = GRAPH_NAME + + # Initialize TigerGraph connection (synchronous) + def _init_connection(): + conn = TigerGraphConnection( + host=host, + username=USERNAME, + password=PASSWORD, + graphname=GRAPH_NAME, + ) + # Test connection + try: + conn.getVertices("Entity", limit=1) + except Exception as e: + # If graph doesn't exist, we'll create schema in _ensure_schema + logger.debug( + f"[{self.workspace}] Graph may not exist yet: {str(e)}" + ) + return conn + + # Run in thread pool to avoid blocking + self._conn = await asyncio.to_thread(_init_connection) + logger.info( + f"[{self.workspace}] Connected to TigerGraph at {host} (graph: {GRAPH_NAME})" + ) + + # Ensure schema exists + await self._ensure_schema() + + async def _ensure_schema(self): + """Ensure the graph schema exists with required vertex and edge types.""" + workspace_label = self._get_workspace_label() + + def _create_schema(): + # Create vertex type for entities (similar to Neo4j workspace label) + # Use workspace label as vertex type name + vertex_type = workspace_label + + # Check if vertex type exists + try: + schema = self._conn.getSchema(force=True) + vertex_types = [vt["Name"] for vt in schema["VertexTypes"]] + if vertex_type not in vertex_types: + # Create vertex type with entity_id as primary key + # All properties will be stored as attributes + gsql = f""" + CREATE VERTEX {vertex_type} ( + PRIMARY_ID entity_id STRING, + entity_type STRING, + description STRING, + keywords STRING, + source_id STRING + ) WITH primary_id_as_attribute="true" + """ + self._conn.gsql(gsql) + logger.info( + f"[{self.workspace}] Created vertex type '{vertex_type}'" + ) + except Exception as e: + # If vertex type creation fails, try to continue + logger.warning( + f"[{self.workspace}] Could not create vertex type '{vertex_type}': {str(e)}" + ) + + # Create edge type for relationships (undirected, similar to Neo4j) + edge_type = "DIRECTED" + try: + schema = self._conn.getSchema(force=True) + edge_types = [et["Name"] for et in schema["EdgeTypes"]] + if edge_type not in edge_types: + # Create undirected edge type + gsql = f""" + CREATE UNDIRECTED EDGE {edge_type} ( + FROM {vertex_type}, + TO {vertex_type}, + weight FLOAT DEFAULT 1.0, + description STRING, + keywords STRING, + source_id STRING + ) + """ + self._conn.gsql(gsql) + logger.info(f"[{self.workspace}] Created edge type '{edge_type}'") + except Exception as e: + logger.warning( + f"[{self.workspace}] Could not create edge type '{edge_type}': {str(e)}" + ) + + await asyncio.to_thread(_create_schema) + + async def finalize(self): + """Close the TigerGraph connection and release all resources""" + async with get_graph_db_lock(): + if self._conn: + # TigerGraph connection doesn't have explicit close, but we can clear reference + self._conn = None + + async def __aexit__(self, exc_type, exc, tb): + """Ensure connection is closed when context manager exits""" + await self.finalize() + + async def index_done_callback(self) -> None: + # TigerGraph handles persistence automatically + pass + + async def has_node(self, node_id: str) -> bool: + """Check if a node exists in the graph.""" + workspace_label = self._get_workspace_label() + + def _check_node(): + try: + result = self._conn.getVertices( + workspace_label, where=f'entity_id=="{node_id}"', limit=1 + ) + return len(result) > 0 + except Exception as e: + logger.error( + f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" + ) + raise + + return await asyncio.to_thread(_check_node) + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """Check if an edge exists between two nodes.""" + workspace_label = self._get_workspace_label() + + def _check_edge(): + try: + # Check both directions for undirected graph + result1 = self._conn.getEdges( + workspace_label, + source_node_id, + "DIRECTED", + workspace_label, + target_node_id, + limit=1, + ) + result2 = self._conn.getEdges( + workspace_label, + target_node_id, + "DIRECTED", + workspace_label, + source_node_id, + limit=1, + ) + return len(result1) > 0 or len(result2) > 0 + except Exception as e: + logger.error( + f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + ) + raise + + return await asyncio.to_thread(_check_edge) + + async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its entity_id, return only node properties.""" + workspace_label = self._get_workspace_label() + + def _get_node(): + try: + result = self._conn.getVertices( + workspace_label, where=f'entity_id=="{node_id}"', limit=2 + ) + if len(result) > 1: + logger.warning( + f"[{self.workspace}] Multiple nodes found with entity_id '{node_id}'. Using first node." + ) + if result: + node_data = result[0]["attributes"] + # Remove entity_id from attributes if it's duplicated (it's the primary key) + if "entity_id" in node_data: + # Keep entity_id in the dict + pass + return node_data + return None + except Exception as e: + logger.error( + f"[{self.workspace}] Error getting node for {node_id}: {str(e)}" + ) + raise + + return await asyncio.to_thread(_get_node) + + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + """Retrieve multiple nodes in batch.""" + workspace_label = self._get_workspace_label() + + def _get_nodes_batch(): + nodes = {} + try: + # TigerGraph doesn't have native batch query, so we query in parallel + # For now, iterate through node_ids + for node_id in node_ids: + try: + result = self._conn.getVertices( + workspace_label, + where=f'entity_id=="{node_id}"', + limit=1, + ) + if result: + node_data = result[0]["attributes"] + nodes[node_id] = node_data + except Exception as e: + logger.warning( + f"[{self.workspace}] Error getting node {node_id}: {str(e)}" + ) + return nodes + except Exception as e: + logger.error(f"[{self.workspace}] Error in batch get nodes: {str(e)}") + raise + + return await asyncio.to_thread(_get_nodes_batch) + + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node.""" + workspace_label = self._get_workspace_label() + + def _get_degree(): + try: + # Get edges from this node (both directions for undirected graph) + result1 = self._conn.getEdges( + workspace_label, + node_id, + "DIRECTED", + workspace_label, + "*", + limit=10000, + ) + result2 = self._conn.getEdges( + workspace_label, + "*", + "DIRECTED", + workspace_label, + node_id, + limit=10000, + ) + # Count unique edges (avoid double counting) + edge_ids = set() + for edge in result1: + edge_id = edge.get("to_id", "") + edge_ids.add((node_id, edge_id)) + for edge in result2: + edge_id = edge.get("from_id", "") + edge_ids.add((edge_id, node_id)) + return len(edge_ids) + except Exception as e: + logger.error( + f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}" + ) + raise + + return await asyncio.to_thread(_get_degree) + + async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + """Retrieve the degree for multiple nodes in batch.""" + degrees = {} + for node_id in node_ids: + degree = await self.node_degree(node_id) + degrees[node_id] = degree + return degrees + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree (sum of relationships) of two nodes.""" + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + return int(src_degree) + int(trg_degree) + + async def edge_degrees_batch( + self, edge_pairs: list[tuple[str, str]] + ) -> dict[tuple[str, str], int]: + """Calculate the combined degree for each edge in batch.""" + # Collect unique node IDs + unique_node_ids = {src for src, _ in edge_pairs} + unique_node_ids.update({tgt for _, tgt in edge_pairs}) + + # Get degrees for all nodes + degrees = await self.node_degrees_batch(list(unique_node_ids)) + + # Sum up degrees for each edge pair + edge_degrees = {} + for src, tgt in edge_pairs: + edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) + return edge_degrees + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + """Get edge properties between two nodes.""" + workspace_label = self._get_workspace_label() + + def _get_edge(): + try: + # Check both directions for undirected graph + result1 = self._conn.getEdges( + workspace_label, + source_node_id, + "DIRECTED", + workspace_label, + target_node_id, + limit=2, + ) + result2 = self._conn.getEdges( + workspace_label, + target_node_id, + "DIRECTED", + workspace_label, + source_node_id, + limit=2, + ) + + if len(result1) > 1 or len(result2) > 1: + logger.warning( + f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge." + ) + + if result1: + edge_attrs = result1[0].get("attributes", {}) + # Ensure required keys exist with defaults + required_keys = { + "weight": 1.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in edge_attrs: + edge_attrs[key] = default_value + return edge_attrs + elif result2: + edge_attrs = result2[0].get("attributes", {}) + # Ensure required keys exist with defaults + required_keys = { + "weight": 1.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in edge_attrs: + edge_attrs[key] = default_value + return edge_attrs + return None + except Exception as e: + logger.error( + f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + raise + + return await asyncio.to_thread(_get_edge) + + async def get_edges_batch( + self, pairs: list[dict[str, str]] + ) -> dict[tuple[str, str], dict]: + """Retrieve edge properties for multiple (src, tgt) pairs.""" + edges_dict = {} + for pair in pairs: + src = pair["src"] + tgt = pair["tgt"] + edge = await self.get_edge(src, tgt) + if edge is not None: + edges_dict[(src, tgt)] = edge + else: + # Set default edge properties + edges_dict[(src, tgt)] = { + "weight": 1.0, + "source_id": None, + "description": None, + "keywords": None, + } + return edges_dict + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """Retrieves all edges (relationships) for a particular node.""" + workspace_label = self._get_workspace_label() + + def _get_node_edges(): + try: + # Get edges from this node (both directions for undirected graph) + result1 = self._conn.getEdges( + workspace_label, + source_node_id, + "DIRECTED", + workspace_label, + "*", + limit=10000, + ) + result2 = self._conn.getEdges( + workspace_label, + "*", + "DIRECTED", + workspace_label, + source_node_id, + limit=10000, + ) + + edges = [] + edge_pairs = set() # To avoid duplicates + + # Process outgoing edges + for edge in result1: + target_id = edge.get("to_id") + if target_id: + pair = tuple(sorted([source_node_id, target_id])) + if pair not in edge_pairs: + edges.append((source_node_id, target_id)) + edge_pairs.add(pair) + + # Process incoming edges + for edge in result2: + source_id = edge.get("from_id") + if source_id: + pair = tuple(sorted([source_node_id, source_id])) + if pair not in edge_pairs: + edges.append((source_id, source_node_id)) + edge_pairs.add(pair) + + return edges if edges else None + except Exception as e: + logger.error( + f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" + ) + raise + + return await asyncio.to_thread(_get_node_edges) + + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: + """Batch retrieve edges for multiple nodes.""" + edges_dict = {} + for node_id in node_ids: + edges = await self.get_node_edges(node_id) + edges_dict[node_id] = edges if edges is not None else [] + return edges_dict + + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids.""" + workspace_label = self._get_workspace_label() + chunk_ids_set = set(chunk_ids) + + def _get_nodes_by_chunk_ids(): + try: + # Get all vertices and filter by source_id containing chunk_ids + result = self._conn.getVertices(workspace_label, limit=100000) + matching_nodes = [] + for vertex in result: + attrs = vertex.get("attributes", {}) + source_id = attrs.get("source_id") + if source_id: + node_source_ids = set(source_id.split(GRAPH_FIELD_SEP)) + if not node_source_ids.isdisjoint(chunk_ids_set): + attrs["id"] = attrs.get("entity_id") + matching_nodes.append(attrs) + return matching_nodes + except Exception as e: + logger.error( + f"[{self.workspace}] Error getting nodes by chunk_ids: {str(e)}" + ) + raise + + return await asyncio.to_thread(_get_nodes_by_chunk_ids) + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids.""" + workspace_label = self._get_workspace_label() + chunk_ids_set = set(chunk_ids) + + def _get_edges_by_chunk_ids(): + try: + # Get all edges and filter by source_id containing chunk_ids + # We need to get all vertices first, then their edges + result = self._conn.getVertices(workspace_label, limit=100000) + matching_edges = [] + processed_edges = set() + + for vertex in result: + source_id = vertex.get("attributes", {}).get("entity_id") + if not source_id: + continue + + # Get edges from this vertex + try: + edges = self._conn.getEdges( + workspace_label, + source_id, + "DIRECTED", + workspace_label, + "*", + limit=10000, + ) + for edge in edges: + edge_attrs = edge.get("attributes", {}) + edge_source_id = edge_attrs.get("source_id") + if edge_source_id: + edge_source_ids = set( + edge_source_id.split(GRAPH_FIELD_SEP) + ) + if not edge_source_ids.isdisjoint(chunk_ids_set): + target_id = edge.get("to_id") + edge_tuple = tuple(sorted([source_id, target_id])) + if edge_tuple not in processed_edges: + edge_attrs["source"] = source_id + edge_attrs["target"] = target_id + matching_edges.append(edge_attrs) + processed_edges.add(edge_tuple) + except Exception as e: + logger.warning( + f"[{self.workspace}] Error getting edges for vertex {source_id}: {str(e)}" + ) + continue + + return matching_edges + except Exception as e: + logger.error( + f"[{self.workspace}] Error getting edges by chunk_ids: {str(e)}" + ) + raise + + return await asyncio.to_thread(_get_edges_by_chunk_ids) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((ConnectionError, OSError, Exception)), + ) + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """Upsert a node in the TigerGraph database.""" + workspace_label = self._get_workspace_label() + + def _upsert_node(): + try: + # Ensure entity_id is in node_data + if "entity_id" not in node_data: + node_data["entity_id"] = node_id + + # Upsert vertex using upsertVertex + self._conn.upsertVertex(workspace_label, node_id, node_data) + except Exception as e: + logger.error( + f"[{self.workspace}] Error during node upsert for {node_id}: {str(e)}" + ) + raise + + await asyncio.to_thread(_upsert_node) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((ConnectionError, OSError, Exception)), + ) + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: + """Upsert an edge and its properties between two nodes.""" + workspace_label = self._get_workspace_label() + + def _upsert_edge(): + try: + # Ensure both nodes exist first + # Check if source node exists + source_exists = self._conn.getVertices( + workspace_label, where=f'entity_id=="{source_node_id}"', limit=1 + ) + if not source_exists: + # Create source node with minimal data + self._conn.upsertVertex( + workspace_label, source_node_id, {"entity_id": source_node_id} + ) + + # Check if target node exists + target_exists = self._conn.getVertices( + workspace_label, where=f'entity_id=="{target_node_id}"', limit=1 + ) + if not target_exists: + # Create target node with minimal data + self._conn.upsertVertex( + workspace_label, target_node_id, {"entity_id": target_node_id} + ) + + # Upsert edge (undirected, so direction doesn't matter) + self._conn.upsertEdge( + workspace_label, + source_node_id, + "DIRECTED", + workspace_label, + target_node_id, + edge_data, + ) + except Exception as e: + logger.error( + f"[{self.workspace}] Error during edge upsert between {source_node_id} and {target_node_id}: {str(e)}" + ) + raise + + await asyncio.to_thread(_upsert_edge) + + async def get_knowledge_graph( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = None, + ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + """ + # Get max_nodes from global_config if not provided + if max_nodes is None: + max_nodes = self.global_config.get("max_graph_nodes", 1000) + else: + max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) + + workspace_label = self._get_workspace_label() + result = KnowledgeGraph() + + def _get_knowledge_graph(): + try: + if node_label == "*": + # Get all nodes sorted by degree + all_vertices = self._conn.getVertices( + workspace_label, limit=max_nodes + ) + # For simplicity, take first max_nodes vertices + # In a real implementation, you'd want to sort by degree + vertices = all_vertices[:max_nodes] + if len(all_vertices) > max_nodes: + result.is_truncated = True + + # Build node and edge sets + node_ids = [v["attributes"].get("entity_id") for v in vertices] + node_ids = [nid for nid in node_ids if nid] + + # Get all edges between these nodes + edges_data = [] + for node_id in node_ids: + try: + node_edges = self._conn.getEdges( + workspace_label, + node_id, + "DIRECTED", + workspace_label, + "*", + limit=10000, + ) + for edge in node_edges: + target_id = edge.get("to_id") + if target_id in node_ids: + edges_data.append(edge) + except Exception: + continue + + # Build result + for vertex in vertices: + attrs = vertex.get("attributes", {}) + entity_id = attrs.get("entity_id") + if entity_id: + result.nodes.append( + KnowledgeGraphNode( + id=entity_id, + labels=[entity_id], + properties=attrs, + ) + ) + + edge_ids_seen = set() + for edge in edges_data: + source_id = edge.get("from_id") + target_id = edge.get("to_id") + if source_id and target_id: + edge_tuple = tuple(sorted([source_id, target_id])) + if edge_tuple not in edge_ids_seen: + edge_attrs = edge.get("attributes", {}) + result.edges.append( + KnowledgeGraphEdge( + id=f"{source_id}-{target_id}", + type="DIRECTED", + source=source_id, + target=target_id, + properties=edge_attrs, + ) + ) + edge_ids_seen.add(edge_tuple) + else: + # BFS traversal starting from node_label + from collections import deque + + visited_nodes = set() + visited_edges = set() + queue = deque([(node_label, 0)]) + + while queue and len(visited_nodes) < max_nodes: + current_id, depth = queue.popleft() + + if current_id in visited_nodes or depth > max_depth: + continue + + # Get node + try: + vertices = self._conn.getVertices( + workspace_label, + where=f'entity_id=="{current_id}"', + limit=1, + ) + if not vertices: + continue + + vertex = vertices[0] + attrs = vertex.get("attributes", {}) + result.nodes.append( + KnowledgeGraphNode( + id=current_id, + labels=[current_id], + properties=attrs, + ) + ) + visited_nodes.add(current_id) + + if depth < max_depth: + # Get neighbors + edges = self._conn.getEdges( + workspace_label, + current_id, + "DIRECTED", + workspace_label, + "*", + limit=10000, + ) + for edge in edges: + target_id = edge.get("to_id") + if target_id and target_id not in visited_nodes: + edge_tuple = tuple( + sorted([current_id, target_id]) + ) + if edge_tuple not in visited_edges: + edge_attrs = edge.get("attributes", {}) + result.edges.append( + KnowledgeGraphEdge( + id=f"{current_id}-{target_id}", + type="DIRECTED", + source=current_id, + target=target_id, + properties=edge_attrs, + ) + ) + visited_edges.add(edge_tuple) + queue.append((target_id, depth + 1)) + except Exception as e: + logger.warning( + f"[{self.workspace}] Error in BFS traversal for {current_id}: {str(e)}" + ) + continue + + if len(visited_nodes) >= max_nodes: + result.is_truncated = True + + return result + except Exception as e: + logger.error( + f"[{self.workspace}] Error in get_knowledge_graph: {str(e)}" + ) + raise + + result = await asyncio.to_thread(_get_knowledge_graph) + logger.info( + f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result + + async def get_all_labels(self) -> list[str]: + """Get all existing node labels in the database.""" + workspace_label = self._get_workspace_label() + + def _get_all_labels(): + try: + vertices = self._conn.getVertices(workspace_label, limit=100000) + labels = set() + for vertex in vertices: + entity_id = vertex.get("attributes", {}).get("entity_id") + if entity_id: + labels.add(entity_id) + return sorted(list(labels)) + except Exception as e: + logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}") + raise + + return await asyncio.to_thread(_get_all_labels) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((ConnectionError, OSError, Exception)), + ) + async def delete_node(self, node_id: str) -> None: + """Delete a node with the specified entity_id.""" + workspace_label = self._get_workspace_label() + + def _delete_node(): + try: + self._conn.delVertices(workspace_label, where=f'entity_id=="{node_id}"') + logger.debug( + f"[{self.workspace}] Deleted node with entity_id '{node_id}'" + ) + except Exception as e: + logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}") + raise + + await asyncio.to_thread(_delete_node) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((ConnectionError, OSError, Exception)), + ) + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes.""" + for node in nodes: + await self.delete_node(node) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((ConnectionError, OSError, Exception)), + ) + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges.""" + workspace_label = self._get_workspace_label() + + def _delete_edge(source, target): + try: + # Delete edge in both directions + self._conn.delEdges( + workspace_label, + source, + "DIRECTED", + workspace_label, + target, + ) + except Exception as e: + logger.warning( + f"[{self.workspace}] Error deleting edge from '{source}' to '{target}': {str(e)}" + ) + + for source, target in edges: + await asyncio.to_thread(_delete_edge, source, target) + + async def get_all_nodes(self) -> list[dict]: + """Get all nodes in the graph.""" + workspace_label = self._get_workspace_label() + + def _get_all_nodes(): + try: + vertices = self._conn.getVertices(workspace_label, limit=100000) + nodes = [] + for vertex in vertices: + attrs = vertex.get("attributes", {}) + attrs["id"] = attrs.get("entity_id") + nodes.append(attrs) + return nodes + except Exception as e: + logger.error(f"[{self.workspace}] Error getting all nodes: {str(e)}") + raise + + return await asyncio.to_thread(_get_all_nodes) + + async def get_all_edges(self) -> list[dict]: + """Get all edges in the graph.""" + workspace_label = self._get_workspace_label() + + def _get_all_edges(): + try: + # Get all vertices first + vertices = self._conn.getVertices(workspace_label, limit=100000) + edges = [] + processed_edges = set() + + for vertex in vertices: + source_id = vertex.get("attributes", {}).get("entity_id") + if not source_id: + continue + + try: + vertex_edges = self._conn.getEdges( + workspace_label, + source_id, + "DIRECTED", + workspace_label, + "*", + limit=10000, + ) + for edge in vertex_edges: + target_id = edge.get("to_id") + edge_tuple = tuple(sorted([source_id, target_id])) + if edge_tuple not in processed_edges: + edge_attrs = edge.get("attributes", {}) + edge_attrs["source"] = source_id + edge_attrs["target"] = target_id + edges.append(edge_attrs) + processed_edges.add(edge_tuple) + except Exception: + continue + + return edges + except Exception as e: + logger.error(f"[{self.workspace}] Error getting all edges: {str(e)}") + raise + + return await asyncio.to_thread(_get_all_edges) + + async def get_popular_labels(self, limit: int = 300) -> list[str]: + """Get popular labels by node degree (most connected entities).""" + workspace_label = self._get_workspace_label() + + def _get_popular_labels(): + try: + # Get all vertices and calculate degrees + vertices = self._conn.getVertices(workspace_label, limit=100000) + node_degrees = {} + + for vertex in vertices: + entity_id = vertex.get("attributes", {}).get("entity_id") + if not entity_id: + continue + + # Calculate degree + try: + edges = self._conn.getEdges( + workspace_label, + entity_id, + "DIRECTED", + workspace_label, + "*", + limit=10000, + ) + # Count unique neighbors + neighbors = set() + for edge in edges: + target_id = edge.get("to_id") + if target_id: + neighbors.add(target_id) + node_degrees[entity_id] = len(neighbors) + except Exception: + node_degrees[entity_id] = 0 + + # Sort by degree descending, then by label ascending + sorted_labels = sorted( + node_degrees.items(), + key=lambda x: (-x[1], x[0]), + )[:limit] + + return [label for label, _ in sorted_labels] + except Exception as e: + logger.error( + f"[{self.workspace}] Error getting popular labels: {str(e)}" + ) + raise + + return await asyncio.to_thread(_get_popular_labels) + + async def search_labels(self, query: str, limit: int = 50) -> list[str]: + """Search labels with fuzzy matching.""" + workspace_label = self._get_workspace_label() + query_strip = query.strip() + if not query_strip: + return [] + + query_lower = query_strip.lower() + is_chinese = self._is_chinese_text(query_strip) + + def _search_labels(): + try: + # Get all vertices and filter + vertices = self._conn.getVertices(workspace_label, limit=100000) + matches = [] + + for vertex in vertices: + entity_id = vertex.get("attributes", {}).get("entity_id") + if not entity_id: + continue + + entity_id_str = str(entity_id) + if is_chinese: + # For Chinese, use direct contains + if query_strip not in entity_id_str: + continue + # Calculate relevance score + if entity_id_str == query_strip: + score = 1000 + elif entity_id_str.startswith(query_strip): + score = 500 + else: + score = 100 - len(entity_id_str) + else: + # For non-Chinese, use case-insensitive contains + entity_id_lower = entity_id_str.lower() + if query_lower not in entity_id_lower: + continue + # Calculate relevance score + if entity_id_lower == query_lower: + score = 1000 + elif entity_id_lower.startswith(query_lower): + score = 500 + else: + score = 100 - len(entity_id_str) + # Bonus for word boundary matches + if ( + f" {query_lower}" in entity_id_lower + or f"_{query_lower}" in entity_id_lower + ): + score += 50 + + matches.append((entity_id_str, score)) + + # Sort by relevance score (desc) then alphabetically + matches.sort(key=lambda x: (-x[1], x[0])) + + # Return top matches + return [match[0] for match in matches[:limit]] + except Exception as e: + logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") + raise + + return await asyncio.to_thread(_search_labels) + + async def drop(self) -> dict[str, str]: + """Drop all data from current workspace storage and clean up resources.""" + async with get_graph_db_lock(): + workspace_label = self._get_workspace_label() + try: + + def _drop(): + # Delete all vertices with this workspace label + self._conn.delVertices(workspace_label, where="") + + await asyncio.to_thread(_drop) + return { + "status": "success", + "message": f"workspace '{workspace_label}' data dropped", + } + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping TigerGraph workspace '{workspace_label}': {e}" + ) + return {"status": "error", "message": str(e)}