# LightRAG/lightrag/kg/tigergraph_impl.py
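"""TigerGraph implementation of LightRAG's BaseGraphStorage.

Entities are stored as vertices of a per-workspace vertex type (the workspace
name doubles as the vertex type), and relationships as a single undirected
edge type named "DIRECTED" for parity with the Neo4j implementation.
"""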
import os
import re
import asyncio
from collections import deque
from dataclasses import dataclass
from typing import final
import configparser
from urllib.parse import urlparse
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm
if not pm.is_installed("pyTigerGraph"):
pm.install("pyTigerGraph")
from pyTigerGraph import TigerGraphConnection # type: ignore
from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Set pyTigerGraph logger level to ERROR to suppress warning logs
logging.getLogger("pyTigerGraph").setLevel(logging.ERROR)
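# Typical connection settings read in initialize() (values are illustrative):
#   TIGERGRAPH_URI=http://localhost:14240
#   TIGERGRAPH_USERNAME=tigergraph
#   TIGERGRAPH_PASSWORD=<password>
#   TIGERGRAPH_GRAPH_NAME=lightrag
#   TIGERGRAPH_WORKSPACE=base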
@final
@dataclass
class TigerGraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func, workspace=None):
# Read env and override the arg if present
tigergraph_workspace = os.environ.get("TIGERGRAPH_WORKSPACE")
if tigergraph_workspace and tigergraph_workspace.strip():
workspace = tigergraph_workspace
# Default to 'base' when both arg and env are empty
if not workspace or not str(workspace).strip():
workspace = "base"
super().__init__(
namespace=namespace,
workspace=workspace,
global_config=global_config,
embedding_func=embedding_func,
)
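        # Connection and graph name are populated in initialize()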
self._conn = None
self._graph_name = None
def _get_workspace_label(self) -> str:
"""Return workspace label (guaranteed non-empty during initialization)"""
return self.workspace
def _is_chinese_text(self, text: str) -> bool:
"""Check if text contains Chinese characters."""
chinese_pattern = re.compile(r"[\u4e00-\u9fff]+")
return bool(chinese_pattern.search(text))
    def _parse_uri(self, uri: str) -> tuple[str, int]:
        """Parse a URI into a full host URL (scheme://host:port) and the port."""
        parsed = urlparse(uri)
        host = parsed.hostname or "localhost"
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
        scheme = parsed.scheme or "http"
        full_host = f"{scheme}://{host}:{port}"
        return full_host, port
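    # _parse_uri examples (illustrative):
    #   "https://tg.example.com"  -> ("https://tg.example.com:443", 443)
    #   "http://localhost:14240"  -> ("http://localhost:14240", 14240)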
async def initialize(self):
async with get_data_init_lock():
URI = os.environ.get(
"TIGERGRAPH_URI", config.get("tigergraph", "uri", fallback=None)
)
USERNAME = os.environ.get(
"TIGERGRAPH_USERNAME",
config.get("tigergraph", "username", fallback="tigergraph"),
)
PASSWORD = os.environ.get(
"TIGERGRAPH_PASSWORD",
config.get("tigergraph", "password", fallback=""),
)
            GRAPH_NAME = os.environ.get(
                "TIGERGRAPH_GRAPH_NAME",
                config.get(
                    "tigergraph",
                    "graph_name",
                    # Graph names must be valid identifiers, so replace other characters with "_"
                    fallback=re.sub(r"[^a-zA-Z0-9_]", "_", self.namespace),
                ),
            )
if not URI:
raise ValueError(
"TIGERGRAPH_URI is required. Please set it in environment variables or config.ini"
)
# Parse URI to get host and port
host, port = self._parse_uri(URI)
self._graph_name = GRAPH_NAME
# Initialize TigerGraph connection (synchronous)
def _init_connection():
conn = TigerGraphConnection(
host=host,
username=USERNAME,
password=PASSWORD,
graphname=GRAPH_NAME,
)
            # Probe the workspace vertex type to verify connectivity
            try:
                conn.getVertices(self._get_workspace_label(), limit=1)
            except Exception as e:
                # If the graph or schema doesn't exist yet, it is created in _ensure_schema
                logger.debug(
                    f"[{self.workspace}] Graph may not exist yet: {str(e)}"
                )
return conn
# Run in thread pool to avoid blocking
self._conn = await asyncio.to_thread(_init_connection)
logger.info(
f"[{self.workspace}] Connected to TigerGraph at {host} (graph: {GRAPH_NAME})"
)
# Ensure schema exists
await self._ensure_schema()
async def _ensure_schema(self):
"""Ensure the graph schema exists with required vertex and edge types."""
workspace_label = self._get_workspace_label()
def _create_schema():
# Create vertex type for entities (similar to Neo4j workspace label)
# Use workspace label as vertex type name
vertex_type = workspace_label
# Check if vertex type exists
try:
schema = self._conn.getSchema(force=True)
vertex_types = [vt["Name"] for vt in schema["VertexTypes"]]
if vertex_type not in vertex_types:
# Create vertex type with entity_id as primary key
# All properties will be stored as attributes
gsql = f"""
CREATE VERTEX {vertex_type} (
PRIMARY_ID entity_id STRING,
entity_type STRING,
description STRING,
keywords STRING,
source_id STRING
) WITH primary_id_as_attribute="true"
"""
self._conn.gsql(gsql)
logger.info(
f"[{self.workspace}] Created vertex type '{vertex_type}'"
)
except Exception as e:
# If vertex type creation fails, try to continue
logger.warning(
f"[{self.workspace}] Could not create vertex type '{vertex_type}': {str(e)}"
)
            # Create the relationship edge type. It is undirected, but named
            # "DIRECTED" for parity with the Neo4j implementation's relationship label.
            edge_type = "DIRECTED"
try:
schema = self._conn.getSchema(force=True)
edge_types = [et["Name"] for et in schema["EdgeTypes"]]
if edge_type not in edge_types:
# Create undirected edge type
gsql = f"""
CREATE UNDIRECTED EDGE {edge_type} (
FROM {vertex_type},
TO {vertex_type},
weight FLOAT DEFAULT 1.0,
description STRING,
keywords STRING,
source_id STRING
)
"""
self._conn.gsql(gsql)
logger.info(f"[{self.workspace}] Created edge type '{edge_type}'")
except Exception as e:
logger.warning(
f"[{self.workspace}] Could not create edge type '{edge_type}': {str(e)}"
)
# Install GSQL queries for efficient operations
self._install_queries(workspace_label)
await asyncio.to_thread(_create_schema)
def _install_queries(self, workspace_label: str):
"""Install GSQL queries for efficient graph operations."""
try:
            # Query to get popular labels by degree.
            # Notes on GSQL constraints: TYPEDEFs and accumulators must be declared
            # before any SELECT statement, HeapAccum element types must be TYPEDEF'd
            # tuples, seed sets use the "{Type.*}" form, and "limit" is a reserved
            # word (hence the parameter is named top_k).
            popular_labels_query = f"""
            CREATE QUERY get_popular_labels_{workspace_label}(INT top_k) FOR GRAPH {self._graph_name} {{
                TYPEDEF TUPLE<INT degree, STRING label> Degree_Label;
                MapAccum<STRING, INT> @@degree_map;
                HeapAccum<Degree_Label>(top_k, degree DESC, label ASC) @@top_labels;
                ListAccum<STRING> @@result;
                # Initialize all vertices with degree 0
                Start = {{{workspace_label}.*}};
                Start = SELECT v FROM Start:v
                        WHERE v.entity_id != ""
                        ACCUM @@degree_map += (v.entity_id -> 0);
                # Count incident edges (the edge type is undirected); use a separate
                # vertex set so isolated vertices stay in Start
                Tmp = SELECT v FROM Start:v - (DIRECTED:e) - {workspace_label}:t
                      WHERE v.entity_id != "" AND t.entity_id != ""
                      ACCUM @@degree_map += (v.entity_id -> 1);
                # Build heap with degree and label, sorted by degree DESC, label ASC
                Start = SELECT v FROM Start:v
                        POST-ACCUM @@top_labels += Degree_Label(@@degree_map.get(v.entity_id), v.entity_id);
                # Extract labels from heap (already sorted)
                FOREACH item IN @@top_labels DO
                    @@result += item.label;
                END;
                PRINT @@result;
            }}
            """
            # Query to search labels with substring matching; relevance scoring and
            # final truncation happen on the Python side.
            search_labels_query = f"""
            CREATE QUERY search_labels_{workspace_label}(STRING search_query, INT top_k) FOR GRAPH {self._graph_name} {{
                ListAccum<STRING> @@matches;
                STRING query_lower = lower(search_query);
                Start = {{{workspace_label}.*}};
                # instr() returns -1 when the substring is not found
                Start = SELECT v FROM Start:v
                        WHERE v.entity_id != "" AND instr(lower(v.entity_id), query_lower) >= 0
                        ACCUM @@matches += v.entity_id;
                PRINT @@matches;
            }}
            """
# Try to install queries (drop first if they exist)
try:
# Drop existing queries if they exist
try:
self._conn.gsql(f"DROP QUERY get_popular_labels_{workspace_label}")
except Exception:
pass # Query doesn't exist, which is fine
try:
self._conn.gsql(f"DROP QUERY search_labels_{workspace_label}")
except Exception:
pass # Query doesn't exist, which is fine
# Install new queries
self._conn.gsql(popular_labels_query)
self._conn.gsql(search_labels_query)
logger.info(
f"[{self.workspace}] Installed GSQL queries for workspace '{workspace_label}'"
)
except Exception as e:
logger.warning(
f"[{self.workspace}] Could not install GSQL queries: {str(e)}. "
"Will fall back to traversal-based methods."
)
except Exception as e:
logger.warning(
f"[{self.workspace}] Error installing GSQL queries: {str(e)}. "
"Will fall back to traversal-based methods."
)
async def finalize(self):
"""Close the TigerGraph connection and release all resources"""
async with get_graph_db_lock():
if self._conn:
# TigerGraph connection doesn't have explicit close, but we can clear reference
self._conn = None
async def __aexit__(self, exc_type, exc, tb):
"""Ensure connection is closed when context manager exits"""
await self.finalize()
async def index_done_callback(self) -> None:
# TigerGraph handles persistence automatically
pass
async def has_node(self, node_id: str) -> bool:
"""Check if a node exists in the graph."""
workspace_label = self._get_workspace_label()
def _check_node():
try:
result = self._conn.getVertices(
workspace_label, where=f'entity_id=="{node_id}"', limit=1
)
return len(result) > 0
except Exception as e:
logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
raise
return await asyncio.to_thread(_check_node)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""Check if an edge exists between two nodes."""
workspace_label = self._get_workspace_label()
def _check_edge():
try:
# Check both directions for undirected graph
result1 = self._conn.getEdges(
workspace_label,
source_node_id,
"DIRECTED",
workspace_label,
target_node_id,
limit=1,
)
result2 = self._conn.getEdges(
workspace_label,
target_node_id,
"DIRECTED",
workspace_label,
source_node_id,
limit=1,
)
return len(result1) > 0 or len(result2) > 0
except Exception as e:
logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
raise
return await asyncio.to_thread(_check_edge)
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its entity_id, return only node properties."""
workspace_label = self._get_workspace_label()
def _get_node():
try:
result = self._conn.getVertices(
workspace_label, where=f'entity_id=="{node_id}"', limit=2
)
if len(result) > 1:
logger.warning(
f"[{self.workspace}] Multiple nodes found with entity_id '{node_id}'. Using first node."
)
                if result:
                    # entity_id is the primary key and is intentionally kept in the
                    # returned attribute dict
                    return result[0]["attributes"]
                return None
except Exception as e:
logger.error(
f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
)
raise
return await asyncio.to_thread(_get_node)
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Retrieve multiple nodes in batch."""
workspace_label = self._get_workspace_label()
def _get_nodes_batch():
nodes = {}
try:
                # pyTigerGraph has no batch vertex lookup, so fetch one node per call
                for node_id in node_ids:
try:
result = self._conn.getVertices(
workspace_label,
where=f'entity_id=="{node_id}"',
limit=1,
)
if result:
node_data = result[0]["attributes"]
nodes[node_id] = node_data
except Exception as e:
logger.warning(
f"[{self.workspace}] Error getting node {node_id}: {str(e)}"
)
return nodes
except Exception as e:
logger.error(f"[{self.workspace}] Error in batch get nodes: {str(e)}")
raise
return await asyncio.to_thread(_get_nodes_batch)
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node."""
workspace_label = self._get_workspace_label()
        def _get_degree():
            try:
                # The edge type is undirected, so every incident edge is returned
                # when querying from this node; a single call suffices
                edges = self._conn.getEdges(
                    workspace_label,
                    node_id,
                    "DIRECTED",
                    limit=10000,
                )
                # Count unique neighbors to avoid double counting
                neighbors = set()
                for edge in edges:
                    target_id = edge.get("to_id")
                    if target_id:
                        neighbors.add(target_id)
                return len(neighbors)
except Exception as e:
logger.error(
f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
)
raise
return await asyncio.to_thread(_get_degree)
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""Retrieve the degree for multiple nodes in batch."""
degrees = {}
for node_id in node_ids:
degree = await self.node_degree(node_id)
degrees[node_id] = degree
return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes."""
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
return int(src_degree) + int(trg_degree)
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""Calculate the combined degree for each edge in batch."""
# Collect unique node IDs
unique_node_ids = {src for src, _ in edge_pairs}
unique_node_ids.update({tgt for _, tgt in edge_pairs})
# Get degrees for all nodes
degrees = await self.node_degrees_batch(list(unique_node_ids))
# Sum up degrees for each edge pair
edge_degrees = {}
for src, tgt in edge_pairs:
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes."""
workspace_label = self._get_workspace_label()
def _get_edge():
try:
# Check both directions for undirected graph
result1 = self._conn.getEdges(
workspace_label,
source_node_id,
"DIRECTED",
workspace_label,
target_node_id,
limit=2,
)
result2 = self._conn.getEdges(
workspace_label,
target_node_id,
"DIRECTED",
workspace_label,
source_node_id,
limit=2,
)
                if len(result1) > 1 or len(result2) > 1:
                    logger.warning(
                        f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
                    )
                edge_result = result1 or result2
                if edge_result:
                    edge_attrs = edge_result[0].get("attributes", {})
                    # Ensure required keys exist with defaults
                    required_keys = {
                        "weight": 1.0,
                        "source_id": None,
                        "description": None,
                        "keywords": None,
                    }
                    for key, default_value in required_keys.items():
                        edge_attrs.setdefault(key, default_value)
                    return edge_attrs
                return None
except Exception as e:
logger.error(
f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
)
raise
return await asyncio.to_thread(_get_edge)
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""Retrieve edge properties for multiple (src, tgt) pairs."""
edges_dict = {}
for pair in pairs:
src = pair["src"]
tgt = pair["tgt"]
edge = await self.get_edge(src, tgt)
if edge is not None:
edges_dict[(src, tgt)] = edge
else:
# Set default edge properties
edges_dict[(src, tgt)] = {
"weight": 1.0,
"source_id": None,
"description": None,
"keywords": None,
}
return edges_dict
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node."""
workspace_label = self._get_workspace_label()
def _get_node_edges():
            try:
                # The edge type is undirected, so a single query from this node
                # returns all incident edges
                result = self._conn.getEdges(
                    workspace_label,
                    source_node_id,
                    "DIRECTED",
                    limit=10000,
                )
                edges = []
                edge_pairs = set()  # To avoid duplicates
                for edge in result:
                    target_id = edge.get("to_id")
                    if target_id:
                        pair = tuple(sorted([source_node_id, target_id]))
                        if pair not in edge_pairs:
                            edges.append((source_node_id, target_id))
                            edge_pairs.add(pair)
                return edges if edges else None
except Exception as e:
logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
)
raise
return await asyncio.to_thread(_get_node_edges)
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""Batch retrieve edges for multiple nodes."""
edges_dict = {}
for node_id in node_ids:
edges = await self.get_node_edges(node_id)
edges_dict[node_id] = edges if edges is not None else []
return edges_dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((ConnectionError, OSError, Exception)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Upsert a node in the TigerGraph database."""
workspace_label = self._get_workspace_label()
def _upsert_node():
try:
# Ensure entity_id is in node_data
if "entity_id" not in node_data:
node_data["entity_id"] = node_id
                # upsertVertex inserts the vertex or updates it if the primary id exists
                self._conn.upsertVertex(workspace_label, node_id, node_data)
except Exception as e:
logger.error(
f"[{self.workspace}] Error during node upsert for {node_id}: {str(e)}"
)
raise
await asyncio.to_thread(_upsert_node)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((ConnectionError, OSError, Exception)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""Upsert an edge and its properties between two nodes."""
workspace_label = self._get_workspace_label()
def _upsert_edge():
try:
# Ensure both nodes exist first
# Check if source node exists
source_exists = self._conn.getVertices(
workspace_label, where=f'entity_id=="{source_node_id}"', limit=1
)
if not source_exists:
# Create source node with minimal data
self._conn.upsertVertex(
workspace_label, source_node_id, {"entity_id": source_node_id}
)
# Check if target node exists
target_exists = self._conn.getVertices(
workspace_label, where=f'entity_id=="{target_node_id}"', limit=1
)
if not target_exists:
# Create target node with minimal data
self._conn.upsertVertex(
workspace_label, target_node_id, {"entity_id": target_node_id}
)
# Upsert edge (undirected, so direction doesn't matter)
self._conn.upsertEdge(
workspace_label,
source_node_id,
"DIRECTED",
workspace_label,
target_node_id,
edge_data,
)
except Exception as e:
logger.error(
f"[{self.workspace}] Error during edge upsert between {source_node_id} and {target_node_id}: {str(e)}"
)
raise
await asyncio.to_thread(_upsert_edge)
    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        max_nodes: int | None = None,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph around `node_label`.

        When `node_label` is "*", up to `max_nodes` vertices are returned; otherwise
        a breadth-first traversal runs from `node_label` up to `max_depth` hops.
        """
        # Get max_nodes from global_config if not provided; cap it either way
        if max_nodes is None:
            max_nodes = self.global_config.get("max_graph_nodes", 1000)
        else:
            max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
workspace_label = self._get_workspace_label()
result = KnowledgeGraph()
def _get_knowledge_graph():
try:
                if node_label == "*":
                    # Fetch one extra vertex so truncation can be detected;
                    # vertices come back in arbitrary order (not sorted by degree)
                    all_vertices = self._conn.getVertices(
                        workspace_label, limit=max_nodes + 1
                    )
                    vertices = all_vertices[:max_nodes]
                    if len(all_vertices) > max_nodes:
                        result.is_truncated = True
                    # Build node and edge sets
                    node_ids = [v["attributes"].get("entity_id") for v in vertices]
                    node_id_set = {nid for nid in node_ids if nid}
                    # Get all edges between these nodes
                    edges_data = []
                    for node_id in node_id_set:
                        try:
                            node_edges = self._conn.getEdges(
                                workspace_label,
                                node_id,
                                "DIRECTED",
                                limit=10000,
                            )
                            for edge in node_edges:
                                target_id = edge.get("to_id")
                                if target_id in node_id_set:
                                    edges_data.append(edge)
                        except Exception:
                            continue
# Build result
for vertex in vertices:
attrs = vertex.get("attributes", {})
entity_id = attrs.get("entity_id")
if entity_id:
result.nodes.append(
KnowledgeGraphNode(
id=entity_id,
labels=[entity_id],
properties=attrs,
)
)
edge_ids_seen = set()
for edge in edges_data:
source_id = edge.get("from_id")
target_id = edge.get("to_id")
if source_id and target_id:
edge_tuple = tuple(sorted([source_id, target_id]))
if edge_tuple not in edge_ids_seen:
edge_attrs = edge.get("attributes", {})
result.edges.append(
KnowledgeGraphEdge(
id=f"{source_id}-{target_id}",
type="DIRECTED",
source=source_id,
target=target_id,
properties=edge_attrs,
)
)
edge_ids_seen.add(edge_tuple)
                else:
                    # BFS traversal starting from node_label (deque imported at module top)
visited_nodes = set()
visited_edges = set()
queue = deque([(node_label, 0)])
while queue and len(visited_nodes) < max_nodes:
current_id, depth = queue.popleft()
if current_id in visited_nodes or depth > max_depth:
continue
# Get node
try:
vertices = self._conn.getVertices(
workspace_label,
where=f'entity_id=="{current_id}"',
limit=1,
)
if not vertices:
continue
vertex = vertices[0]
attrs = vertex.get("attributes", {})
result.nodes.append(
KnowledgeGraphNode(
id=current_id,
labels=[current_id],
properties=attrs,
)
)
visited_nodes.add(current_id)
if depth < max_depth:
                                # Get all incident edges of the current node
                                edges = self._conn.getEdges(
                                    workspace_label,
                                    current_id,
                                    "DIRECTED",
                                    limit=10000,
                                )
for edge in edges:
target_id = edge.get("to_id")
if target_id and target_id not in visited_nodes:
edge_tuple = tuple(
sorted([current_id, target_id])
)
if edge_tuple not in visited_edges:
edge_attrs = edge.get("attributes", {})
result.edges.append(
KnowledgeGraphEdge(
id=f"{current_id}-{target_id}",
type="DIRECTED",
source=current_id,
target=target_id,
properties=edge_attrs,
)
)
visited_edges.add(edge_tuple)
queue.append((target_id, depth + 1))
except Exception as e:
logger.warning(
f"[{self.workspace}] Error in BFS traversal for {current_id}: {str(e)}"
)
continue
if len(visited_nodes) >= max_nodes:
result.is_truncated = True
return result
except Exception as e:
logger.error(
f"[{self.workspace}] Error in get_knowledge_graph: {str(e)}"
)
raise
result = await asyncio.to_thread(_get_knowledge_graph)
logger.info(
f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def get_all_labels(self) -> list[str]:
"""Get all existing node labels in the database."""
workspace_label = self._get_workspace_label()
def _get_all_labels():
try:
vertices = self._conn.getVertices(workspace_label, limit=100000)
labels = set()
for vertex in vertices:
entity_id = vertex.get("attributes", {}).get("entity_id")
if entity_id:
labels.add(entity_id)
return sorted(list(labels))
except Exception as e:
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
raise
return await asyncio.to_thread(_get_all_labels)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((ConnectionError, OSError, Exception)),
)
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified entity_id."""
workspace_label = self._get_workspace_label()
def _delete_node():
try:
self._conn.delVertices(workspace_label, where=f'entity_id=="{node_id}"')
logger.debug(
f"[{self.workspace}] Deleted node with entity_id '{node_id}'"
)
except Exception as e:
logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}")
raise
await asyncio.to_thread(_delete_node)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((ConnectionError, OSError, Exception)),
)
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes."""
for node in nodes:
await self.delete_node(node)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((ConnectionError, OSError, Exception)),
)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges."""
workspace_label = self._get_workspace_label()
        def _delete_edge(source, target):
            try:
                # The edge type is undirected, so a single delete removes the edge
                self._conn.delEdges(
                    workspace_label,
                    source,
                    "DIRECTED",
                    workspace_label,
                    target,
                )
except Exception as e:
logger.warning(
f"[{self.workspace}] Error deleting edge from '{source}' to '{target}': {str(e)}"
)
for source, target in edges:
await asyncio.to_thread(_delete_edge, source, target)
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph."""
workspace_label = self._get_workspace_label()
def _get_all_nodes():
try:
vertices = self._conn.getVertices(workspace_label, limit=100000)
nodes = []
for vertex in vertices:
attrs = vertex.get("attributes", {})
attrs["id"] = attrs.get("entity_id")
nodes.append(attrs)
return nodes
except Exception as e:
logger.error(f"[{self.workspace}] Error getting all nodes: {str(e)}")
raise
return await asyncio.to_thread(_get_all_nodes)
async def get_all_edges(self) -> list[dict]:
"""Get all edges in the graph."""
workspace_label = self._get_workspace_label()
def _get_all_edges():
try:
# Get all vertices first
vertices = self._conn.getVertices(workspace_label, limit=100000)
edges = []
processed_edges = set()
for vertex in vertices:
source_id = vertex.get("attributes", {}).get("entity_id")
if not source_id:
continue
                    try:
                        vertex_edges = self._conn.getEdges(
                            workspace_label,
                            source_id,
                            "DIRECTED",
                            limit=10000,
                        )
                        for edge in vertex_edges:
                            target_id = edge.get("to_id")
                            if not target_id:
                                continue
                            edge_tuple = tuple(sorted([source_id, target_id]))
                            if edge_tuple not in processed_edges:
                                edge_attrs = edge.get("attributes", {})
                                edge_attrs["source"] = source_id
                                edge_attrs["target"] = target_id
                                edges.append(edge_attrs)
                                processed_edges.add(edge_tuple)
                    except Exception:
                        continue
return edges
except Exception as e:
logger.error(f"[{self.workspace}] Error getting all edges: {str(e)}")
raise
return await asyncio.to_thread(_get_all_edges)
async def get_popular_labels(self, limit: int = 300) -> list[str]:
"""Get popular labels by node degree (most connected entities)."""
workspace_label = self._get_workspace_label()
def _get_popular_labels():
try:
                # Try the installed GSQL query first
                query_name = f"get_popular_labels_{workspace_label}"
                try:
                    result = self._conn.runInstalledQuery(
                        # The GSQL parameter is named top_k ("limit" is reserved in GSQL)
                        query_name, params={"top_k": limit}
                    )
if result and len(result) > 0:
# Extract labels from query result
# Result format: [{"@@result": ["label1", "label2", ...]}]
labels = []
for record in result:
if "@@result" in record:
labels.extend(record["@@result"])
# GSQL query already returns sorted labels (by degree DESC, label ASC)
# Just return the limited results
if labels:
return labels[:limit]
except Exception as query_error:
logger.debug(
f"[{self.workspace}] GSQL query '{query_name}' not available or failed: {str(query_error)}. "
"Falling back to traversal method."
)
# Fallback to traversal method if GSQL query fails
# Get all vertices and calculate degrees
vertices = self._conn.getVertices(workspace_label, limit=100000)
node_degrees = {}
for vertex in vertices:
entity_id = vertex.get("attributes", {}).get("entity_id")
if not entity_id:
continue
                    # Degree = number of incident edges (the edge type is undirected)
                    try:
                        edges = self._conn.getEdges(
                            workspace_label,
                            entity_id,
                            "DIRECTED",
                            limit=10000,
                        )
# Count unique neighbors
neighbors = set()
for edge in edges:
target_id = edge.get("to_id")
if target_id:
neighbors.add(target_id)
node_degrees[entity_id] = len(neighbors)
except Exception:
node_degrees[entity_id] = 0
# Sort by degree descending, then by label ascending
sorted_labels = sorted(
node_degrees.items(),
key=lambda x: (-x[1], x[0]),
)[:limit]
return [label for label, _ in sorted_labels]
except Exception as e:
logger.error(
f"[{self.workspace}] Error getting popular labels: {str(e)}"
)
raise
return await asyncio.to_thread(_get_popular_labels)
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
"""Search labels with fuzzy matching."""
workspace_label = self._get_workspace_label()
query_strip = query.strip()
if not query_strip:
return []
query_lower = query_strip.lower()
is_chinese = self._is_chinese_text(query_strip)
        def _score_and_rank(candidates: list[str]) -> list[str]:
            """Score matches (exact > prefix > contains) and return the top `limit`."""
            matches = []
            for entity_id_str in candidates:
                entity_id_str = str(entity_id_str)
                if is_chinese:
                    # For Chinese, use direct (case-sensitive) contains
                    if query_strip not in entity_id_str:
                        continue
                    if entity_id_str == query_strip:
                        score = 1000
                    elif entity_id_str.startswith(query_strip):
                        score = 500
                    else:
                        score = 100 - len(entity_id_str)
                else:
                    # For non-Chinese, use case-insensitive contains
                    entity_id_lower = entity_id_str.lower()
                    if query_lower not in entity_id_lower:
                        continue
                    if entity_id_lower == query_lower:
                        score = 1000
                    elif entity_id_lower.startswith(query_lower):
                        score = 500
                    else:
                        score = 100 - len(entity_id_str)
                    # Bonus for word-boundary matches
                    if (
                        f" {query_lower}" in entity_id_lower
                        or f"_{query_lower}" in entity_id_lower
                    ):
                        score += 50
                matches.append((entity_id_str, score))
            # Sort by relevance score (desc) then alphabetically
            matches.sort(key=lambda x: (-x[1], x[0]))
            return [match[0] for match in matches[:limit]]

        def _search_labels():
            try:
                # Try the installed GSQL query first
                query_name = f"search_labels_{workspace_label}"
                try:
                    result = self._conn.runInstalledQuery(
                        # The GSQL parameter is named top_k ("limit" is reserved in GSQL)
                        query_name,
                        params={"search_query": query_strip, "top_k": limit},
                    )
                    if result and len(result) > 0:
                        # Extract labels from the query result
                        labels = []
                        for record in result:
                            if "@@matches" in record:
                                labels.extend(record["@@matches"])
                        if labels:
                            # The GSQL query only filters; score and sort here
                            return _score_and_rank(labels)
                except Exception as query_error:
                    logger.debug(
                        f"[{self.workspace}] GSQL query '{query_name}' not available or failed: {str(query_error)}. "
                        "Falling back to traversal method."
                    )
                # Fallback: fetch all vertices and filter client-side
                vertices = self._conn.getVertices(workspace_label, limit=100000)
                candidates = [
                    vertex.get("attributes", {}).get("entity_id")
                    for vertex in vertices
                ]
                return _score_and_rank([c for c in candidates if c])
            except Exception as e:
                logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
                raise
        return await asyncio.to_thread(_search_labels)
async def drop(self) -> dict[str, str]:
"""Drop all data from current workspace storage and clean up resources."""
async with get_graph_db_lock():
workspace_label = self._get_workspace_label()
try:
                def _drop():
                    # An empty where clause deletes all vertices of this type;
                    # their incident edges are deleted along with them
                    self._conn.delVertices(workspace_label, where="")
await asyncio.to_thread(_drop)
return {
"status": "success",
"message": f"workspace '{workspace_label}' data dropped",
}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping TigerGraph workspace '{workspace_label}': {e}"
)
return {"status": "error", "message": str(e)}