From a04c11a59828e4655d669d20194851e7ed34f136 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Wed, 6 Aug 2025 00:02:50 +0800
Subject: [PATCH] Remove deprecated storage

---
 lightrag/kg/__init__.py                |   13 -
 lightrag/kg/deprecated/age_impl.py     |  867 -----------------
 lightrag/kg/deprecated/gremlin_impl.py |  686 -------------
 lightrag/kg/deprecated/tidb_impl.py    | 1230 ------------------------
 4 files changed, 2796 deletions(-)
 delete mode 100644 lightrag/kg/deprecated/age_impl.py
 delete mode 100644 lightrag/kg/deprecated/gremlin_impl.py
 delete mode 100644 lightrag/kg/deprecated/tidb_impl.py

diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py
index b2a93e82..8d42441a 100644
--- a/lightrag/kg/__init__.py
+++ b/lightrag/kg/__init__.py
@@ -5,7 +5,6 @@ STORAGE_IMPLEMENTATIONS = {
             "RedisKVStorage",
             "PGKVStorage",
             "MongoKVStorage",
-            # "TiDBKVStorage",
         ],
         "required_methods": ["get_by_id", "upsert"],
     },
@@ -16,9 +15,6 @@ STORAGE_IMPLEMENTATIONS = {
             "PGGraphStorage",
             "MongoGraphStorage",
             "MemgraphStorage",
-            # "AGEStorage",
-            # "TiDBGraphStorage",
-            # "GremlinStorage",
         ],
         "required_methods": ["upsert_node", "upsert_edge"],
     },
@@ -31,7 +27,6 @@ STORAGE_IMPLEMENTATIONS = {
             "QdrantVectorDBStorage",
             "MongoVectorDBStorage",
             # "ChromaVectorDBStorage",
-            # "TiDBVectorDBStorage",
         ],
         "required_methods": ["query", "upsert"],
     },
@@ -52,20 +47,17 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "JsonKVStorage": [],
     "MongoKVStorage": [],
     "RedisKVStorage": ["REDIS_URI"],
-    # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
     # Graph Storage Implementations
     "NetworkXStorage": [],
     "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
     "MongoGraphStorage": [],
     "MemgraphStorage": ["MEMGRAPH_URI"],
-    # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "AGEStorage": [
         "AGE_POSTGRES_DB",
         "AGE_POSTGRES_USER",
         "AGE_POSTGRES_PASSWORD",
     ],
-    # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
     "PGGraphStorage": [
         "POSTGRES_USER",
         "POSTGRES_PASSWORD",
@@ -75,7 +67,6 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "NanoVectorDBStorage": [],
     "MilvusVectorDBStorage": [],
     "ChromaVectorDBStorage": [],
-    # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
     "FaissVectorDBStorage": [],
     "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
@@ -102,14 +93,10 @@ STORAGES = {
     "RedisKVStorage": ".kg.redis_impl",
     "RedisDocStatusStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
-    # "TiDBKVStorage": ".kg.tidb_impl",
-    # "TiDBVectorDBStorage": ".kg.tidb_impl",
-    # "TiDBGraphStorage": ".kg.tidb_impl",
     "PGKVStorage": ".kg.postgres_impl",
     "PGVectorStorage": ".kg.postgres_impl",
     "AGEStorage": ".kg.age_impl",
     "PGGraphStorage": ".kg.postgres_impl",
-    # "GremlinStorage": ".kg.gremlin_impl",
     "PGDocStatusStorage": ".kg.postgres_impl",
     "FaissVectorDBStorage": ".kg.faiss_impl",
     "QdrantVectorDBStorage": ".kg.qdrant_impl",
diff --git a/lightrag/kg/deprecated/age_impl.py b/lightrag/kg/deprecated/age_impl.py
deleted file mode 100644
index 097b7b0b..00000000
--- a/lightrag/kg/deprecated/age_impl.py
+++ /dev/null
@@ -1,867 +0,0 @@
-import asyncio
-import inspect
-import json
-import os
-import sys
-from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, Dict, List, NamedTuple, Optional, Union, final
-import
pipmaster as pm -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge - -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from lightrag.utils import logger - -from ..base import BaseGraphStorage - -if sys.platform.startswith("win"): - import asyncio.windows_events - - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - -if not pm.is_installed("psycopg-pool"): - pm.install("psycopg-pool") - pm.install("psycopg[binary,pool]") - -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - -import psycopg # type: ignore -from psycopg.rows import namedtuple_row # type: ignore -from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore - - -class AGEQueryException(Exception): - """Exception for the AGE queries.""" - - def __init__(self, exception: Union[str, Dict]) -> None: - if isinstance(exception, dict): - self.message = exception["message"] if "message" in exception else "unknown" - self.details = exception["details"] if "details" in exception else "unknown" - else: - self.message = exception - self.details = "unknown" - - def get_message(self) -> str: - return self.message - - def get_details(self) -> Any: - return self.details - - -@final -@dataclass -class AGEStorage(BaseGraphStorage): - @staticmethod - def load_nx_graph(file_name): - print("no preloading of graph with AGE in production") - - def __init__(self, namespace, global_config, embedding_func): - super().__init__( - namespace=namespace, - global_config=global_config, - embedding_func=embedding_func, - ) - self._driver = None - self._driver_lock = asyncio.Lock() - DB = os.environ["AGE_POSTGRES_DB"].replace("\\", "\\\\").replace("'", "\\'") - USER = os.environ["AGE_POSTGRES_USER"].replace("\\", "\\\\").replace("'", "\\'") - PASSWORD = ( - os.environ["AGE_POSTGRES_PASSWORD"] - .replace("\\", "\\\\") - .replace("'", "\\'") - ) - HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'") - PORT = os.environ.get("AGE_POSTGRES_PORT", "8529") - self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") - - connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" - - self._driver = AsyncConnectionPool(connection_string, open=False) - - return None - - async def close(self): - if self._driver: - await self._driver.close() - self._driver = None - - async def __aexit__(self, exc_type, exc, tb): - if self._driver: - await self._driver.close() - - @staticmethod - def _record_to_dict(record: NamedTuple) -> Dict[str, Any]: - """ - Convert a record returned from an age query to a dictionary - - Args: - record (): a record from an age query result - - Returns: - Dict[str, Any]: a dictionary representation of the record where - the dictionary key is the field name and the value is the - value converted to a python type - """ - # result holder - d = {} - - # prebuild a mapping of vertex_id to vertex mappings to be used - # later to build edges - vertices = {} - for k in record._fields: - v = getattr(record, k) - # agtype comes back '{key: value}::type' which must be parsed - if isinstance(v, str) and "::" in v: - dtype = v.split("::")[-1] - v = v.split("::")[0] - if dtype == "vertex": - vertex = json.loads(v) - vertices[vertex["id"]] = vertex.get("properties") - - # iterate returned fields and parse appropriately - for k in record._fields: - v = getattr(record, k) - if isinstance(v, str) and "::" in v: - dtype = v.split("::")[-1] - v = 
v.split("::")[0] - else: - dtype = "" - - if dtype == "vertex": - vertex = json.loads(v) - field = json.loads(v).get("properties") - if not field: - field = {} - field["label"] = AGEStorage._decode_graph_label(vertex["label"]) - d[k] = field - # convert edge from id-label->id by replacing id with node information - # we only do this if the vertex was also returned in the query - # this is an attempt to be consistent with neo4j implementation - elif dtype == "edge": - edge = json.loads(v) - d[k] = ( - vertices.get(edge["start_id"], {}), - edge[ - "label" - ], # we don't use decode_graph_label(), since edge label is always "DIRECTED" - vertices.get(edge["end_id"], {}), - ) - else: - d[k] = json.loads(v) if isinstance(v, str) else v - - return d - - @staticmethod - def _format_properties( - properties: Dict[str, Any], _id: Union[str, None] = None - ) -> str: - """ - Convert a dictionary of properties to a string representation that - can be used in a cypher query insert/merge statement. - - Args: - properties (Dict[str,str]): a dictionary containing node/edge properties - id (Union[str, None]): the id of the node or None if none exists - - Returns: - str: the properties dictionary as a properly formatted string - """ - props = [] - # wrap property key in backticks to escape - for k, v in properties.items(): - prop = f"`{k}`: {json.dumps(v)}" - props.append(prop) - if _id is not None and "id" not in properties: - props.append( - f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" - ) - return "{" + ", ".join(props) + "}" - - @staticmethod - def _encode_graph_label(label: str) -> str: - """ - Since AGE suports only alphanumerical labels, we will encode generic label as HEX string - - Args: - label (str): the original label - - Returns: - str: the encoded label - """ - return "x" + label.encode().hex() - - @staticmethod - def _decode_graph_label(encoded_label: str) -> str: - """ - Since AGE suports only alphanumerical labels, we will encode generic label as HEX string - - Args: - encoded_label (str): the encoded label - - Returns: - str: the decoded label - """ - return bytes.fromhex(encoded_label.removeprefix("x")).decode() - - @staticmethod - def _get_col_name(field: str, idx: int) -> str: - """ - Convert a cypher return field to a pgsql select field - If possible keep the cypher column name, but create a generic name if necessary - - Args: - field (str): a return field from a cypher query to be formatted for pgsql - idx (int): the position of the field in the return statement - - Returns: - str: the field to be used in the pgsql select statement - """ - # remove white space - field = field.strip() - # if an alias is provided for the field, use it - if " as " in field: - return field.split(" as ")[-1].strip() - # if the return value is an unnamed primitive, give it a generic name - if field.isnumeric() or field in ("true", "false", "null"): - return f"column_{idx}" - # otherwise return the value stripping out some common special chars - return field.replace("(", "_").replace(")", "") - - @staticmethod - def _wrap_query(query: str, graph_name: str, **params: str) -> str: - """ - Convert a cypher query to an Apache Age compatible - sql query by wrapping the cypher query in ag_catalog.cypher, - casting results to agtype and building a select statement - - Args: - query (str): a valid cypher query - graph_name (str): the name of the graph to query - params (dict): parameters for the query - - Returns: - str: an equivalent pgsql query - """ - - # pgsql template - template = """SELECT 
{projection} FROM ag_catalog.cypher('{graph_name}', $$ - {query} - $$) AS ({fields});""" - - # if there are any returned fields they must be added to the pgsql query - if "return" in query.lower(): - # parse return statement to identify returned fields - fields = ( - query.lower() - .split("return")[-1] - .split("distinct")[-1] - .split("order by")[0] - .split("skip")[0] - .split("limit")[0] - .split(",") - ) - - # raise exception if RETURN * is found as we can't resolve the fields - if "*" in [x.strip() for x in fields]: - raise ValueError( - "AGE graph does not support 'RETURN *'" - + " statements in Cypher queries" - ) - - # get pgsql formatted field names - fields = [ - AGEStorage._get_col_name(field, idx) for idx, field in enumerate(fields) - ] - - # build resulting pgsql relation - fields_str = ", ".join( - [field.split(".")[-1] + " agtype" for field in fields] - ) - - # if no return statement we still need to return a single field of type agtype - else: - fields_str = "a agtype" - - select_str = "*" - - return template.format( - graph_name=graph_name, - query=query.format(**params), - fields=fields_str, - projection=select_str, - ) - - async def _query(self, query: str, **params: str) -> List[Dict[str, Any]]: - """ - Query the graph by taking a cypher query, converting it to an - age compatible query, executing it and converting the result - - Args: - query (str): a cypher query to be executed - params (dict): parameters for the query - - Returns: - List[Dict[str, Any]]: a list of dictionaries containing the result set - """ - # convert cypher query to pgsql/age query - wrapped_query = self._wrap_query(query, self.graph_name, **params) - - await self._driver.open() - - # create graph if it doesn't exist - async with self._get_pool_connection() as conn: - async with conn.cursor() as curs: - try: - await curs.execute('SET search_path = ag_catalog, "$user", public') - await curs.execute(f"SELECT create_graph('{self.graph_name}')") - await conn.commit() - except ( - psycopg.errors.InvalidSchemaName, - psycopg.errors.UniqueViolation, - ): - await conn.rollback() - - # execute the query, rolling back on an error - async with self._get_pool_connection() as conn: - async with conn.cursor(row_factory=namedtuple_row) as curs: - try: - await curs.execute('SET search_path = ag_catalog, "$user", public') - await curs.execute(wrapped_query) - await conn.commit() - except psycopg.Error as e: - await conn.rollback() - raise AGEQueryException( - { - "message": f"Error executing graph query: {query.format(**params)}", - "detail": str(e), - } - ) from e - - data = await curs.fetchall() - if data is None: - result = [] - # decode records - else: - result = [AGEStorage._record_to_dict(d) for d in data] - - return result - - async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') - - query = """ - MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - single_result = (await self._query(query, **params))[0] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - single_result["node_exists"], - ) - - return single_result["node_exists"] - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') - - query = """ - MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) - RETURN COUNT(r) > 0 AS edge_exists - """ - 
params = { - "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), - } - single_result = (await self._query(query, **params))[0] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - single_result["edge_exists"], - ) - return single_result["edge_exists"] - - async def get_node(self, node_id: str) -> dict[str, str] | None: - entity_name_label = node_id.strip('"') - query = """ - MATCH (n:`{label}`) RETURN n - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - record = await self._query(query, **params) - if record: - node = record[0] - node_dict = node["n"] - logger.debug( - "{%s}: query: {%s}, result: {%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - node_dict, - ) - return node_dict - return None - - async def node_degree(self, node_id: str) -> int: - entity_name_label = node_id.strip('"') - - query = """ - MATCH (n:`{label}`)-[]->(x) - RETURN count(x) AS total_edge_count - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - record = (await self._query(query, **params))[0] - if record: - edge_count = int(record["total_edge_count"]) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - edge_count, - ) - return edge_count - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - entity_name_label_source = src_id.strip('"') - entity_name_label_target = tgt_id.strip('"') - src_degree = await self.node_degree(entity_name_label_source) - trg_degree = await self.node_degree(entity_name_label_target) - - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - logger.debug( - "{%s}:query:src_Degree+trg_degree:result:{%s}", - inspect.currentframe().f_code.co_name, - degrees, - ) - return degrees - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') - - query = """ - MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) - RETURN properties(r) as edge_properties - LIMIT 1 - """ - params = { - "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), - } - record = await self._query(query, **params) - if record and record[0] and record[0]["edge_properties"]: - result = record[0]["edge_properties"] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query.format(**params), - result, - ) - return result - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - """ - Retrieves all edges (relationships) for a particular node identified by its label. 
- :return: List of dictionaries containing edge information - """ - node_label = source_node_id.strip('"') - - query = """ - MATCH (n:`{label}`) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected - """ - params = {"label": AGEStorage._encode_graph_label(node_label)} - results = await self._query(query, **params) - edges = [] - for record in results: - source_node = record["n"] if record["n"] else None - connected_node = record["connected"] if record["connected"] else None - - source_label = ( - source_node["label"] if source_node and source_node["label"] else None - ) - target_label = ( - connected_node["label"] - if connected_node and connected_node["label"] - else None - ) - - if source_label and target_label: - edges.append((source_label, target_label)) - - return edges - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((AGEQueryException,)), - ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - """ - Upsert a node in the AGE database. - - Args: - node_id: The unique identifier for the node (used as label) - node_data: Dictionary of node properties - """ - label = node_id.strip('"') - properties = node_data - - query = """ - MERGE (n:`{label}`) - SET n += {properties} - """ - params = { - "label": AGEStorage._encode_graph_label(label), - "properties": AGEStorage._format_properties(properties), - } - try: - await self._query(query, **params) - logger.debug( - "Upserted node with label '{%s}' and properties: {%s}", - label, - properties, - ) - except Exception as e: - logger.error("Error during upsert: {%s}", e) - raise - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((AGEQueryException,)), - ) - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - """ - Upsert an edge and its properties between two nodes identified by their labels. - - Args: - source_node_id (str): Label of the source node (used as identifier) - target_node_id (str): Label of the target node (used as identifier) - edge_data (dict): Dictionary of properties to set on the edge - """ - source_node_label = source_node_id.strip('"') - target_node_label = target_node_id.strip('"') - edge_properties = edge_data - - query = """ - MATCH (source:`{src_label}`) - WITH source - MATCH (target:`{tgt_label}`) - MERGE (source)-[r:DIRECTED]->(target) - SET r += {properties} - RETURN r - """ - params = { - "src_label": AGEStorage._encode_graph_label(source_node_label), - "tgt_label": AGEStorage._encode_graph_label(target_node_label), - "properties": AGEStorage._format_properties(edge_properties), - } - try: - await self._query(query, **params) - logger.debug( - "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", - source_node_label, - target_node_label, - edge_properties, - ) - except Exception as e: - logger.error("Error during edge upsert: {%s}", e) - raise - - @asynccontextmanager - async def _get_pool_connection(self, timeout: Optional[float] = None): - """Workaround for a psycopg_pool bug""" - - try: - connection = await self._driver.getconn(timeout=timeout) - except PoolTimeout: - await self._driver._add_connection(None) # workaround... 
- connection = await self._driver.getconn(timeout=timeout) - - try: - async with connection: - yield connection - finally: - await self._driver.putconn(connection) - - async def delete_node(self, node_id: str) -> None: - """Delete a node with the specified label - - Args: - node_id: The label of the node to delete - """ - entity_name_label = node_id.strip('"') - - query = """ - MATCH (n:`{label}`) - DETACH DELETE n - """ - params = {"label": AGEStorage._encode_graph_label(entity_name_label)} - try: - await self._query(query, **params) - logger.debug(f"Deleted node with label '{entity_name_label}'") - except Exception as e: - logger.error(f"Error during node deletion: {str(e)}") - raise - - async def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node labels to be deleted - """ - for node in nodes: - await self.delete_node(node) - - async def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to be deleted, each edge is a (source, target) tuple - """ - for source, target in edges: - entity_name_label_source = source.strip('"') - entity_name_label_target = target.strip('"') - - query = """ - MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`) - DELETE r - """ - params = { - "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), - } - try: - await self._query(query, **params) - logger.debug( - f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'" - ) - except Exception as e: - logger.error(f"Error during edge deletion: {str(e)}") - raise - - async def get_all_labels(self) -> list[str]: - """Get all node labels in the database - - Returns: - ["label1", "label2", ...] # Alphabetically sorted label list - """ - query = """ - MATCH (n) - RETURN DISTINCT labels(n) AS node_labels - """ - results = await self._query(query) - - all_labels = [] - for record in results: - if record and "node_labels" in record: - for label in record["node_labels"]: - if label: - # Decode label - decoded_label = AGEStorage._decode_graph_label(label) - all_labels.append(decoded_label) - - # Remove duplicates and sort - return sorted(list(set(all_labels))) - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: - """ - Retrieve a connected subgraph of nodes where the label includes the specified 'node_label'. - Maximum number of nodes is constrained by the environment variable 'MAX_GRAPH_NODES' (default: 1000). - When reducing the number of nodes, the prioritization criteria are as follows: - 1. Label matching nodes take precedence (nodes containing the specified label string) - 2. Followed by nodes directly connected to the matching nodes - 3. Finally, the degree of the nodes - - Args: - node_label: String to match in node labels (will match any node containing this string in its label) - max_depth: Maximum depth of the graph. Defaults to 5. 
- - Returns: - KnowledgeGraph: Complete connected subgraph for specified node - """ - max_graph_nodes = int(os.getenv("MAX_GRAPH_NODES", 1000)) - result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - - # Handle special case for "*" label - if node_label == "*": - # Query all nodes and sort by degree - query = """ - MATCH (n) - OPTIONAL MATCH (n)-[r]-() - WITH n, count(r) AS degree - ORDER BY degree DESC - LIMIT {max_nodes} - RETURN n, degree - """ - params = {"max_nodes": max_graph_nodes} - nodes_result = await self._query(query, **params) - - # Add nodes to result - node_ids = [] - for record in nodes_result: - if "n" in record: - node = record["n"] - node_id = str(node.get("id", "")) - if node_id not in seen_nodes: - node_properties = {k: v for k, v in node.items()} - node_label = node.get("label", "") - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_label], - properties=node_properties, - ) - ) - seen_nodes.add(node_id) - node_ids.append(node_id) - - # Query edges between these nodes - if node_ids: - edges_query = """ - MATCH (a)-[r]->(b) - WHERE a.id IN {node_ids} AND b.id IN {node_ids} - RETURN a, r, b - """ - edges_params = {"node_ids": node_ids} - edges_result = await self._query(edges_query, **edges_params) - - # Add edges to result - for record in edges_result: - if "r" in record and "a" in record and "b" in record: - source = record["a"].get("id", "") - target = record["b"].get("id", "") - edge_id = f"{source}-{target}" - if edge_id not in seen_edges: - edge_properties = {k: v for k, v in record["r"].items()} - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=source, - target=target, - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - else: - # For specific label, use partial matching - entity_name_label = node_label.strip('"') - encoded_label = AGEStorage._encode_graph_label(entity_name_label) - - # Find matching start nodes - start_query = """ - MATCH (n:`{label}`) - RETURN n - """ - start_params = {"label": encoded_label} - start_nodes = await self._query(start_query, **start_params) - - if not start_nodes: - logger.warning(f"No nodes found with label '{entity_name_label}'!") - return result - - # Traverse graph from each start node - for start_node_record in start_nodes: - if "n" in start_node_record: - # Use BFS to traverse graph - query = """ - MATCH (start:`{label}`) - CALL { - MATCH path = (start)-[*0..{max_depth}]->(n) - RETURN nodes(path) AS path_nodes, relationships(path) AS path_rels - } - RETURN DISTINCT path_nodes, path_rels - """ - params = {"label": encoded_label, "max_depth": max_depth} - results = await self._query(query, **params) - - # Extract nodes and edges from results - for record in results: - if "path_nodes" in record: - # Process nodes - for node in record["path_nodes"]: - node_id = str(node.get("id", "")) - if ( - node_id not in seen_nodes - and len(seen_nodes) < max_graph_nodes - ): - node_properties = {k: v for k, v in node.items()} - node_label = node.get("label", "") - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_label], - properties=node_properties, - ) - ) - seen_nodes.add(node_id) - - if "path_rels" in record: - # Process edges - for rel in record["path_rels"]: - source = str(rel.get("start_id", "")) - target = str(rel.get("end_id", "")) - edge_id = f"{source}-{target}" - if edge_id not in seen_edges: - edge_properties = {k: v for k, v in rel.items()} - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - 
type=rel.get("label", "DIRECTED"), - source=source, - target=target, - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - return result - - async def index_done_callback(self) -> None: - # AGES handles persistence automatically - pass - - async def drop(self) -> dict[str, str]: - """Drop the storage by removing all nodes and relationships in the graph. - - Returns: - dict[str, str]: Status of the operation with keys 'status' and 'message' - """ - try: - query = """ - MATCH (n) - DETACH DELETE n - """ - await self._query(query) - logger.info(f"Successfully dropped all data from graph {self.graph_name}") - return {"status": "success", "message": "graph data dropped"} - except Exception as e: - logger.error(f"Error dropping graph {self.graph_name}: {e}") - return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/deprecated/gremlin_impl.py b/lightrag/kg/deprecated/gremlin_impl.py deleted file mode 100644 index 32dbcc4e..00000000 --- a/lightrag/kg/deprecated/gremlin_impl.py +++ /dev/null @@ -1,686 +0,0 @@ -import asyncio -import inspect -import json -import os -import pipmaster as pm -from dataclasses import dataclass -from typing import Any, Dict, List, final - -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from lightrag.utils import logger - -from ..base import BaseGraphStorage - -if not pm.is_installed("gremlinpython"): - pm.install("gremlinpython") - -from gremlin_python.driver import client, serializer # type: ignore -from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore -from gremlin_python.driver.protocol import GremlinServerError # type: ignore - - -@final -@dataclass -class GremlinStorage(BaseGraphStorage): - @staticmethod - def load_nx_graph(file_name): - print("no preloading of graph with Gremlin in production") - - def __init__(self, namespace, global_config, embedding_func): - super().__init__( - namespace=namespace, - global_config=global_config, - embedding_func=embedding_func, - ) - - self._driver = None - self._driver_lock = asyncio.Lock() - - USER = os.environ.get("GREMLIN_USER", "") - PASSWORD = os.environ.get("GREMLIN_PASSWORD", "") - HOST = os.environ["GREMLIN_HOST"] - PORT = int(os.environ["GREMLIN_PORT"]) - - # TraversalSource, a custom one has to be created manually, - # default it "g" - SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g") - - # All vertices will have graph={GRAPH} property, so that we can - # have several logical graphs for one source - GRAPH = GremlinStorage._to_value_map( - os.environ.get("GREMLIN_GRAPH", "LightRAG") - ) - - self.graph_name = GRAPH - - self._driver = client.Client( - f"ws://{HOST}:{PORT}/gremlin", - SOURCE, - username=USER, - password=PASSWORD, - message_serializer=serializer.GraphSONSerializersV3d0(), - transport_factory=lambda: AiohttpTransport(call_from_event_loop=True), - ) - - async def close(self): - if self._driver: - self._driver.close() - self._driver = None - - async def __aexit__(self, exc_type, exc, tb): - if self._driver: - self._driver.close() - - async def index_done_callback(self) -> None: - # Gremlin handles persistence automatically - pass - - @staticmethod - def _to_value_map(value: Any) -> str: - """Dump supported Python object as Gremlin valueMap""" - json_str = json.dumps(value, 
ensure_ascii=False, sort_keys=False) - parsed_str = json_str.replace("'", r"\'") - - # walk over the string and replace curly brackets with square brackets - # outside of strings, as well as replace double quotes with single quotes - # and "deescape" double quotes inside of strings - outside_str = True - escaped = False - remove_indices = [] - for i, c in enumerate(parsed_str): - if escaped: - # previous character was an "odd" backslash - escaped = False - if c == '"': - # we want to "deescape" double quotes: store indices to delete - remove_indices.insert(0, i - 1) - elif c == "\\": - escaped = True - elif c == '"': - outside_str = not outside_str - parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :] - elif c == "{" and outside_str: - parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :] - elif c == "}" and outside_str: - parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :] - for idx in remove_indices: - parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :] - return parsed_str - - @staticmethod - def _convert_properties(properties: Dict[str, Any]) -> str: - """Create chained .property() commands from properties dict""" - props = [] - for k, v in properties.items(): - prop_name = GremlinStorage._to_value_map(k) - props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})") - return "".join(props) - - @staticmethod - def _fix_name(name: str) -> str: - """Strip double quotes and format as a proper field name""" - name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'")) - - return name - - async def _query(self, query: str) -> List[Dict[str, Any]]: - """ - Query the Gremlin graph - - Args: - query (str): a query to be executed - - Returns: - List[Dict[str, Any]]: a list of dictionaries containing the result set - """ - - result = list(await asyncio.wrap_future(self._driver.submit_async(query))) - if result: - result = result[0] - - return result - - async def has_node(self, node_id: str) -> bool: - entity_name = GremlinStorage._fix_name(node_id) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .limit(1) - .count() - .project('has_node') - .by(__.choose(__.is(gt(0)), constant(true), constant(false))) - """ - result = await self._query(query) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - result[0]["has_node"], - ) - - return result[0]["has_node"] - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_source = GremlinStorage._fix_name(source_node_id) - entity_name_target = GremlinStorage._fix_name(target_node_id) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_source}) - .outE() - .inV().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_target}) - .limit(1) - .count() - .project('has_edge') - .by(__.choose(__.is(gt(0)), constant(true), constant(false))) - """ - result = await self._query(query) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - result[0]["has_edge"], - ) - - return result[0]["has_edge"] - - async def get_node(self, node_id: str) -> dict[str, str] | None: - entity_name = GremlinStorage._fix_name(node_id) - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .limit(1) - .project('properties') - .by(elementMap()) - """ - result = await self._query(query) - if result: - node = result[0] - node_dict = node["properties"] - logger.debug( - "{%s}: query: {%s}, 
result: {%s}", - inspect.currentframe().f_code.co_name, - query.format, - node_dict, - ) - return node_dict - - async def node_degree(self, node_id: str) -> int: - entity_name = GremlinStorage._fix_name(node_id) - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .outE() - .inV().has('graph', {self.graph_name}) - .count() - .project('total_edge_count') - .by() - """ - result = await self._query(query) - edge_count = result[0]["total_edge_count"] - - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - edge_count, - ) - - return edge_count - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - src_degree = await self.node_degree(src_id) - trg_degree = await self.node_degree(tgt_id) - - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - logger.debug( - "{%s}:query:src_Degree+trg_degree:result:{%s}", - inspect.currentframe().f_code.co_name, - degrees, - ) - return degrees - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - entity_name_source = GremlinStorage._fix_name(source_node_id) - entity_name_target = GremlinStorage._fix_name(target_node_id) - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_source}) - .outE() - .inV().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_target}) - .limit(1) - .project('edge_properties') - .by(__.bothE().elementMap()) - """ - result = await self._query(query) - if result: - edge_properties = result[0]["edge_properties"] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - edge_properties, - ) - return edge_properties - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - node_name = GremlinStorage._fix_name(source_node_id) - query = f"""g - .E() - .filter( - __.or( - __.outV().has('graph', {self.graph_name}) - .has('entity_name', {node_name}), - __.inV().has('graph', {self.graph_name}) - .has('entity_name', {node_name}) - ) - ) - .project('source_name', 'target_name') - .by(__.outV().values('entity_name')) - .by(__.inV().values('entity_name')) - """ - result = await self._query(query) - edges = [(res["source_name"], res["target_name"]) for res in result] - - return edges - - @retry( - stop=stop_after_attempt(10), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((GremlinServerError,)), - ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - """ - Upsert a node in the Gremlin graph. 
- - Args: - node_id: The unique identifier for the node (used as name) - node_data: Dictionary of node properties - """ - name = GremlinStorage._fix_name(node_id) - properties = GremlinStorage._convert_properties(node_data) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {name}) - .fold() - .coalesce( - __.unfold(), - __.addV('ENTITY') - .property('graph', {self.graph_name}) - .property('entity_name', {name}) - ) - {properties} - """ - - try: - await self._query(query) - logger.debug( - "Upserted node with name {%s} and properties: {%s}", - name, - properties, - ) - except Exception as e: - logger.error("Error during upsert: {%s}", e) - raise - - @retry( - stop=stop_after_attempt(10), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((GremlinServerError,)), - ) - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - """ - Upsert an edge and its properties between two nodes identified by their names. - - Args: - source_node_id (str): Name of the source node (used as identifier) - target_node_id (str): Name of the target node (used as identifier) - edge_data (dict): Dictionary of properties to set on the edge - """ - source_node_name = GremlinStorage._fix_name(source_node_id) - target_node_name = GremlinStorage._fix_name(target_node_id) - edge_properties = GremlinStorage._convert_properties(edge_data) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {source_node_name}).as('source') - .V().has('graph', {self.graph_name}) - .has('entity_name', {target_node_name}).as('target') - .coalesce( - __.select('source').outE('DIRECTED').where(__.inV().as('target')), - __.select('source').addE('DIRECTED').to(__.select('target')) - ) - .property('graph', {self.graph_name}) - {edge_properties} - """ - try: - await self._query(query) - logger.debug( - "Upserted edge from {%s} to {%s} with properties: {%s}", - source_node_name, - target_node_name, - edge_properties, - ) - except Exception as e: - logger.error("Error during edge upsert: {%s}", e) - raise - - async def delete_node(self, node_id: str) -> None: - """Delete a node with the specified entity_name - - Args: - node_id: The entity_name of the node to delete - """ - entity_name = GremlinStorage._fix_name(node_id) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .drop() - """ - try: - await self._query(query) - logger.debug( - "{%s}: Deleted node with entity_name '%s'", - inspect.currentframe().f_code.co_name, - entity_name, - ) - except Exception as e: - logger.error(f"Error during node deletion: {str(e)}") - raise - - async def get_all_labels(self) -> list[str]: - """ - Get all node entity_names in the graph - Returns: - [entity_name1, entity_name2, ...] # Alphabetically sorted entity_name list - """ - query = f"""g - .V().has('graph', {self.graph_name}) - .values('entity_name') - .dedup() - .order() - """ - try: - result = await self._query(query) - labels = result if result else [] - logger.debug( - "{%s}: Retrieved %d labels", - inspect.currentframe().f_code.co_name, - len(labels), - ) - return labels - except Exception as e: - logger.error(f"Error retrieving labels: {str(e)}") - return [] - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: - """ - Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`. 
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). - - Args: - node_label: Entity name of the starting node - max_depth: Maximum depth of the subgraph - - Returns: - KnowledgeGraph object containing nodes and edges - """ - result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - - # Get maximum number of graph nodes from environment variable, default is 1000 - MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - - entity_name = GremlinStorage._fix_name(node_label) - - # Handle special case for "*" label - if node_label == "*": - # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES) - query = f"""g - .V().has('graph', {self.graph_name}) - .limit({MAX_GRAPH_NODES}) - .elementMap() - """ - nodes_result = await self._query(query) - - # Add nodes to result - for node_data in nodes_result: - node_id = node_data.get("entity_name", str(node_data.get("id", ""))) - if str(node_id) in seen_nodes: - continue - - # Create node with properties - node_properties = { - k: v for k, v in node_data.items() if k not in ["id", "label"] - } - - result.nodes.append( - KnowledgeGraphNode( - id=str(node_id), - labels=[str(node_id)], - properties=node_properties, - ) - ) - seen_nodes.add(str(node_id)) - - # Get and add edges - if nodes_result: - query = f"""g - .V().has('graph', {self.graph_name}) - .limit({MAX_GRAPH_NODES}) - .outE() - .inV().has('graph', {self.graph_name}) - .limit({MAX_GRAPH_NODES}) - .path() - .by(elementMap()) - .by(elementMap()) - .by(elementMap()) - """ - edges_result = await self._query(query) - - for path in edges_result: - if len(path) >= 3: # source -> edge -> target - source = path[0] - edge_data = path[1] - target = path[2] - - source_id = source.get("entity_name", str(source.get("id", ""))) - target_id = target.get("entity_name", str(target.get("id", ""))) - - edge_id = f"{source_id}-{target_id}" - if edge_id in seen_edges: - continue - - # Create edge with properties - edge_properties = { - k: v - for k, v in edge_data.items() - if k not in ["id", "label"] - } - - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source_id), - target=str(target_id), - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - else: - # Search for specific node and get its neighborhood - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name}) - .repeat(__.both().simplePath().dedup()) - .times({max_depth}) - .emit() - .dedup() - .limit({MAX_GRAPH_NODES}) - .elementMap() - """ - nodes_result = await self._query(query) - - # Add nodes to result - for node_data in nodes_result: - node_id = node_data.get("entity_name", str(node_data.get("id", ""))) - if str(node_id) in seen_nodes: - continue - - # Create node with properties - node_properties = { - k: v for k, v in node_data.items() if k not in ["id", "label"] - } - - result.nodes.append( - KnowledgeGraphNode( - id=str(node_id), - labels=[str(node_id)], - properties=node_properties, - ) - ) - seen_nodes.add(str(node_id)) - - # Get edges between the nodes in the result - if nodes_result: - node_ids = [ - n.get("entity_name", str(n.get("id", ""))) for n in nodes_result - ] - node_ids_query = ", ".join( - [GremlinStorage._to_value_map(nid) for nid in node_ids] - ) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', within({node_ids_query})) - .outE() - .where(inV().has('graph', {self.graph_name}) - .has('entity_name', within({node_ids_query}))) - .path() - 
.by(elementMap()) - .by(elementMap()) - .by(elementMap()) - """ - edges_result = await self._query(query) - - for path in edges_result: - if len(path) >= 3: # source -> edge -> target - source = path[0] - edge_data = path[1] - target = path[2] - - source_id = source.get("entity_name", str(source.get("id", ""))) - target_id = target.get("entity_name", str(target.get("id", ""))) - - edge_id = f"{source_id}-{target_id}" - if edge_id in seen_edges: - continue - - # Create edge with properties - edge_properties = { - k: v - for k, v in edge_data.items() - if k not in ["id", "label"] - } - - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source_id), - target=str(target_id), - properties=edge_properties, - ) - ) - seen_edges.add(edge_id) - - logger.info( - "Subgraph query successful | Node count: %d | Edge count: %d", - len(result.nodes), - len(result.edges), - ) - return result - - async def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node entity_names to be deleted - """ - for node in nodes: - await self.delete_node(node) - - async def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to be deleted, each edge is a (source, target) tuple - """ - for source, target in edges: - entity_name_source = GremlinStorage._fix_name(source) - entity_name_target = GremlinStorage._fix_name(target) - - query = f"""g - .V().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_source}) - .outE() - .where(inV().has('graph', {self.graph_name}) - .has('entity_name', {entity_name_target})) - .drop() - """ - try: - await self._query(query) - logger.debug( - "{%s}: Deleted edge from '%s' to '%s'", - inspect.currentframe().f_code.co_name, - entity_name_source, - entity_name_target, - ) - except Exception as e: - logger.error(f"Error during edge deletion: {str(e)}") - raise - - async def drop(self) -> dict[str, str]: - """Drop the storage by removing all nodes and relationships in the graph. - - This function deletes all nodes with the specified graph name property, - which automatically removes all associated edges. 
- - Returns: - dict[str, str]: Status of the operation with keys 'status' and 'message' - """ - try: - query = f"""g - .V().has('graph', {self.graph_name}) - .drop() - """ - await self._query(query) - logger.info(f"Successfully dropped all data from graph {self.graph_name}") - return {"status": "success", "message": "graph data dropped"} - except Exception as e: - logger.error(f"Error dropping graph {self.graph_name}: {e}") - return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/deprecated/tidb_impl.py b/lightrag/kg/deprecated/tidb_impl.py deleted file mode 100644 index 0d5dfca3..00000000 --- a/lightrag/kg/deprecated/tidb_impl.py +++ /dev/null @@ -1,1230 +0,0 @@ -import asyncio -import os -from dataclasses import dataclass, field -from typing import Any, Union, final -import time -import numpy as np - -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge - - -from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage -from ..namespace import NameSpace, is_namespace -from ..utils import logger - -import pipmaster as pm -import configparser - -if not pm.is_installed("pymysql"): - pm.install("pymysql") -if not pm.is_installed("sqlalchemy"): - pm.install("sqlalchemy") - -from sqlalchemy import create_engine, text # type: ignore - - -def sanitize_sensitive_info(data: dict) -> dict: - sanitized_data = data.copy() - sensitive_fields = [ - "password", - "user", - "host", - "database", - "port", - "ssl_verify_cert", - "ssl_verify_identity", - ] - for field_name in sensitive_fields: - if field_name in sanitized_data: - sanitized_data[field_name] = "***" - return sanitized_data - - -class TiDB: - def __init__(self, config, **kwargs): - self.host = config.get("host", None) - self.port = config.get("port", None) - self.user = config.get("user", None) - self.password = config.get("password", None) - self.database = config.get("database", None) - self.workspace = config.get("workspace", None) - connection_string = ( - f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" - f"?ssl_verify_cert=true&ssl_verify_identity=true" - ) - - try: - self.engine = create_engine(connection_string) - logger.info("Connected to TiDB database") - except Exception as e: - logger.error("Failed to connect to TiDB database") - logger.error(f"TiDB database error: {e}") - raise - - async def _migrate_timestamp_columns(self): - """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC""" - # Not implemented yet - pass - - async def check_tables(self): - # First create all tables - for k, v in TABLES.items(): - try: - await self.query(f"SELECT 1 FROM {k}".format(k=k)) - except Exception as e: - logger.error("Failed to check table in TiDB database") - logger.error(f"TiDB database error: {e}") - try: - await self.execute(v["ddl"]) - logger.info("Created table in TiDB database") - except Exception as e: - logger.error("Failed to create table in TiDB database") - logger.error(f"TiDB database error: {e}") - - # After all tables are created, try to migrate timestamp fields - try: - await self._migrate_timestamp_columns() - except Exception as e: - logger.error(f"TiDB, Failed to migrate timestamp columns: {e}") - # Don't raise exceptions, allow initialization process to continue - - async def query( - self, sql: str, params: dict = None, multirows: bool = False - ) -> Union[dict, None]: - if params is None: - params = {"workspace": self.workspace} - else: - params.update({"workspace": self.workspace}) - with 
self.engine.connect() as conn, conn.begin(): - try: - result = conn.execute(text(sql), params) - except Exception as e: - sanitized_params = sanitize_sensitive_info(params) - sanitized_error = sanitize_sensitive_info({"error": str(e)}) - logger.error( - f"Tidb database,\nsql:{sql},\nparams:{sanitized_params},\nerror:{sanitized_error}" - ) - raise - if multirows: - rows = result.all() - if rows: - data = [dict(zip(result.keys(), row)) for row in rows] - else: - data = [] - else: - row = result.first() - if row: - data = dict(zip(result.keys(), row)) - else: - data = None - return data - - async def execute(self, sql: str, data: list | dict = None): - # logger.info("go into TiDBDB execute method") - try: - with self.engine.connect() as conn, conn.begin(): - if data is None: - conn.execute(text(sql)) - else: - conn.execute(text(sql), parameters=data) - except Exception as e: - sanitized_data = sanitize_sensitive_info(data) if data else None - sanitized_error = sanitize_sensitive_info({"error": str(e)}) - logger.error( - f"Tidb database,\nsql:{sql},\ndata:{sanitized_data},\nerror:{sanitized_error}" - ) - raise - - -class ClientManager: - _instances: dict[str, Any] = {"db": None, "ref_count": 0} - _lock = asyncio.Lock() - - @staticmethod - def get_config() -> dict[str, Any]: - config = configparser.ConfigParser() - config.read("config.ini", "utf-8") - - return { - "host": os.environ.get( - "TIDB_HOST", - config.get("tidb", "host", fallback="localhost"), - ), - "port": os.environ.get( - "TIDB_PORT", config.get("tidb", "port", fallback=4000) - ), - "user": os.environ.get( - "TIDB_USER", - config.get("tidb", "user", fallback=None), - ), - "password": os.environ.get( - "TIDB_PASSWORD", - config.get("tidb", "password", fallback=None), - ), - "database": os.environ.get( - "TIDB_DATABASE", - config.get("tidb", "database", fallback=None), - ), - "workspace": os.environ.get( - "TIDB_WORKSPACE", - config.get("tidb", "workspace", fallback="default"), - ), - } - - @classmethod - async def get_client(cls) -> TiDB: - async with cls._lock: - if cls._instances["db"] is None: - config = ClientManager.get_config() - db = TiDB(config) - await db.check_tables() - cls._instances["db"] = db - cls._instances["ref_count"] = 0 - cls._instances["ref_count"] += 1 - return cls._instances["db"] - - @classmethod - async def release_client(cls, db: TiDB): - async with cls._lock: - if db is not None: - if db is cls._instances["db"]: - cls._instances["ref_count"] -= 1 - if cls._instances["ref_count"] == 0: - cls._instances["db"] = None - - -@final -@dataclass -class TiDBKVStorage(BaseKVStorage): - db: TiDB = field(default=None) - - def __post_init__(self): - self._data = {} - self._max_batch_size = self.global_config["embedding_batch_num"] - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - ################ QUERY METHODS ################ - async def get_all(self) -> dict[str, Any]: - """Get all data from storage - - Returns: - Dictionary containing all stored data - """ - async with self._storage_lock: - return dict(self._data) - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Fetch doc_full data by id.""" - SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] - params = {"id": id} - response = await self.db.query(SQL, params) - return response if response else None - - # Query by id - async def get_by_ids(self, ids: list[str]) -> 
list[dict[str, Any]]: - """Fetch doc_chunks data by id""" - SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - return await self.db.query(SQL, multirows=True) - - async def filter_keys(self, keys: set[str]) -> set[str]: - SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=namespace_to_table_name(self.namespace), - id_field=namespace_to_id(self.namespace), - ids=",".join([f"'{id}'" for id in keys]), - ) - try: - await self.db.query(SQL) - except Exception as e: - logger.error(f"Tidb database,\nsql:{SQL},\nkeys:{keys},\nerror:{e}") - res = await self.db.query(SQL, multirows=True) - if res: - exist_keys = [key["id"] for key in res] - data = set([s for s in keys if s not in exist_keys]) - else: - exist_keys = [] - data = set([s for s in keys if s not in exist_keys]) - return data - - ################ INSERT full_doc AND chunks ################ - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") - if not data: - return - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - list_data = [ - { - "__id__": k, - **{k1: v1 for k1, v1 in v.items()}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - - # Get current time as UNIX timestamp - current_time = int(time.time()) - - merge_sql = SQL_TEMPLATES["upsert_chunk"] - data = [] - for item in list_data: - data.append( - { - "id": item["__id__"], - "content": item["content"], - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content_vector": f"{item['__vector__'].tolist()}", - "workspace": self.db.workspace, - "timestamp": current_time, - } - ) - await self.db.execute(merge_sql, data) - - if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): - merge_sql = SQL_TEMPLATES["upsert_doc_full"] - data = [] - for k, v in self._data.items(): - data.append( - { - "id": k, - "content": v["content"], - "workspace": self.db.workspace, - } - ) - await self.db.execute(merge_sql, data) - return left_data - - async def index_done_callback(self) -> None: - # Ti handles persistence automatically - pass - - async def delete(self, ids: list[str]) -> None: - """Delete records with specified IDs from the storage. 
- - Args: - ids: List of record IDs to be deleted - """ - if not ids: - return - - try: - table_name = namespace_to_table_name(self.namespace) - id_field = namespace_to_id(self.namespace) - - if not table_name or not id_field: - logger.error(f"Unknown namespace for deletion: {self.namespace}") - return - - ids_list = ",".join([f"'{id}'" for id in ids]) - delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})" - - await self.db.execute(delete_sql, {"workspace": self.db.workspace}) - logger.info( - f"Successfully deleted {len(ids)} records from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error deleting records from {self.namespace}: {e}") - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - -@final -@dataclass -class TiDBVectorDBStorage(BaseVectorStorage): - db: TiDB | None = field(default=None) - - def __post_init__(self): - self._client_file_name = os.path.join( - self.global_config["working_dir"], f"vdb_{self.namespace}.json" - ) - self._max_batch_size = self.global_config["embedding_batch_num"] - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - cosine_threshold = config.get("cosine_better_than_threshold") - if cosine_threshold is None: - raise ValueError( - "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" - ) - self.cosine_better_than_threshold = cosine_threshold - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - async def query( - self, query: str, top_k: int, ids: list[str] | None = None - ) -> list[dict[str, Any]]: - """Search from tidb vector""" - embeddings = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - embedding = embeddings[0] - - embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]" - - params = { - "embedding_string": embedding_string, - "top_k": top_k, - "better_than_threshold": self.cosine_better_than_threshold, - } - - results = await self.db.query( - SQL_TEMPLATES[self.namespace], params=params, multirows=True - ) - if not results: - return [] - return results - - ###### INSERT entities And relationships ###### - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - if not data: - return - logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") - - # Get current time as UNIX timestamp - import time - - current_time = int(time.time()) - - list_data = [ - { - "id": k, - "timestamp": current_time, - **{k1: v1 for k1, v1 in v.items()}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embedding_tasks = [self.embedding_func(batch) for batch in batches] - embeddings_list = await asyncio.gather(*embedding_tasks) - - embeddings = np.concatenate(embeddings_list) - for i, d in 
enumerate(list_data): - d["content_vector"] = embeddings[i] - - if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): - for item in list_data: - param = { - "id": item["id"], - "content": item["content"], - "tokens": item.get("tokens", 0), - "chunk_order_index": item.get("chunk_order_index", 0), - "full_doc_id": item.get("full_doc_id", ""), - "content_vector": f"{item['content_vector'].tolist()}", - "workspace": self.db.workspace, - "timestamp": item["timestamp"], - } - await self.db.execute(SQL_TEMPLATES["upsert_chunk"], param) - elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): - for item in list_data: - param = { - "id": item["id"], - "name": item["entity_name"], - "content": item["content"], - "content_vector": f"{item['content_vector'].tolist()}", - "workspace": self.db.workspace, - "timestamp": item["timestamp"], - } - await self.db.execute(SQL_TEMPLATES["upsert_entity"], param) - elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): - for item in list_data: - param = { - "id": item["id"], - "source_name": item["src_id"], - "target_name": item["tgt_id"], - "content": item["content"], - "content_vector": f"{item['content_vector'].tolist()}", - "workspace": self.db.workspace, - "timestamp": item["timestamp"], - } - await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param) - - async def delete(self, ids: list[str]) -> None: - """Delete vectors with specified IDs from the storage. - - Args: - ids: List of vector IDs to be deleted - """ - if not ids: - return - - table_name = namespace_to_table_name(self.namespace) - id_field = namespace_to_id(self.namespace) - - if not table_name or not id_field: - logger.error(f"Unknown namespace for vector deletion: {self.namespace}") - return - - ids_list = ",".join([f"'{id}'" for id in ids]) - delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})" - - try: - await self.db.execute(delete_sql, {"workspace": self.db.workspace}) - logger.debug( - f"Successfully deleted {len(ids)} vectors from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - - async def delete_entity(self, entity_name: str) -> None: - """Delete an entity by its name from the vector storage. - - Args: - entity_name: The name of the entity to delete - """ - try: - # Construct SQL to delete the entity - delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace AND name = :entity_name""" - - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} - ) - logger.debug(f"Successfully deleted entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") - - async def delete_entity_relation(self, entity_name: str) -> None: - """Delete all relations associated with an entity. 
- - Args: - entity_name: The name of the entity whose relations should be deleted - """ - try: - # Delete relations where the entity is either the source or target - delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)""" - - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} - ) - logger.debug(f"Successfully deleted relations for entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting relations for entity {entity_name}: {e}") - - async def index_done_callback(self) -> None: - # Ti handles persistence automatically - pass - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get vector data by its ID - - Args: - id: The unique identifier of the vector - - Returns: - The vector data if found, or None if not found - """ - try: - # Determine which table to query based on namespace - if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: - sql_template = """ - SELECT entity_id as id, name as entity_name, entity_type, description, content, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_NODES - WHERE entity_id = :entity_id AND workspace = :workspace - """ - params = {"entity_id": id, "workspace": self.db.workspace} - elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: - sql_template = """ - SELECT relation_id as id, source_name as src_id, target_name as tgt_id, - keywords, description, content, UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_EDGES - WHERE relation_id = :relation_id AND workspace = :workspace - """ - params = {"relation_id": id, "workspace": self.db.workspace} - elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: - sql_template = """ - SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_DOC_CHUNKS - WHERE chunk_id = :chunk_id AND workspace = :workspace - """ - params = {"chunk_id": id, "workspace": self.db.workspace} - else: - logger.warning( - f"Namespace {self.namespace} not supported for get_by_id" - ) - return None - - result = await self.db.query(sql_template, params=params) - return result - except Exception as e: - logger.error(f"Error retrieving vector data for ID {id}: {e}") - return None - - async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get multiple vector data by their IDs - - Args: - ids: List of unique identifiers - - Returns: - List of vector data objects that were found - """ - if not ids: - return [] - - try: - # Format IDs for SQL IN clause - ids_str = ", ".join([f"'{id}'" for id in ids]) - - # Determine which table to query based on namespace - if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: - sql_template = f""" - SELECT entity_id as id, name as entity_name, entity_type, description, content, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_NODES - WHERE entity_id IN ({ids_str}) AND workspace 
= :workspace - """ - elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: - sql_template = f""" - SELECT relation_id as id, source_name as src_id, target_name as tgt_id, - keywords, description, content, UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_GRAPH_EDGES - WHERE relation_id IN ({ids_str}) AND workspace = :workspace - """ - elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: - sql_template = f""" - SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id, - UNIX_TIMESTAMP(createtime) as created_at - FROM LIGHTRAG_DOC_CHUNKS - WHERE chunk_id IN ({ids_str}) AND workspace = :workspace - """ - else: - logger.warning( - f"Namespace {self.namespace} not supported for get_by_ids" - ) - return [] - - params = {"workspace": self.db.workspace} - results = await self.db.query(sql_template, params=params, multirows=True) - return results if results else [] - except Exception as e: - logger.error(f"Error retrieving vector data for IDs {ids}: {e}") - return [] - - -@final -@dataclass -class TiDBGraphStorage(BaseGraphStorage): - db: TiDB = field(default=None) - - def __post_init__(self): - self._max_batch_size = self.global_config["embedding_batch_num"] - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - #################### upsert method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - entity_name = node_id - entity_type = node_data["entity_type"] - description = node_data["description"] - source_id = node_data["source_id"] - logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") - content = entity_name + description - contents = [content] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - sql = SQL_TEMPLATES["upsert_node"] - data = { - "workspace": self.db.workspace, - "name": entity_name, - "entity_type": entity_type, - "description": description, - "source_chunk_id": source_id, - "content": content, - "content_vector": f"{content_vector.tolist()}", - } - await self.db.execute(sql, data) - - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - source_name = source_node_id - target_name = target_node_id - weight = edge_data["weight"] - keywords = edge_data["keywords"] - description = edge_data["description"] - source_chunk_id = edge_data["source_id"] - logger.debug( - f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" - ) - - content = keywords + source_name + target_name + description - contents = [content] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["upsert_edge"] - data = { - "workspace": self.db.workspace, - "source_name": source_name, - "target_name": target_name, - "weight": weight, - "keywords": keywords, - "description": description, - "source_chunk_id": source_chunk_id, - "content": content, - "content_vector": 
f"{content_vector.tolist()}", - } - await self.db.execute(merge_sql, data) - - # Query - - async def has_node(self, node_id: str) -> bool: - sql = SQL_TEMPLATES["has_entity"] - param = {"name": node_id, "workspace": self.db.workspace} - has = await self.db.query(sql, param) - return has["cnt"] != 0 - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - sql = SQL_TEMPLATES["has_relationship"] - param = { - "source_name": source_node_id, - "target_name": target_node_id, - "workspace": self.db.workspace, - } - has = await self.db.query(sql, param) - return has["cnt"] != 0 - - async def node_degree(self, node_id: str) -> int: - sql = SQL_TEMPLATES["node_degree"] - param = {"name": node_id, "workspace": self.db.workspace} - result = await self.db.query(sql, param) - return result["cnt"] - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) - return degree - - async def get_node(self, node_id: str) -> dict[str, str] | None: - sql = SQL_TEMPLATES["get_node"] - param = {"name": node_id, "workspace": self.db.workspace} - return await self.db.query(sql, param) - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - sql = SQL_TEMPLATES["get_edge"] - param = { - "source_name": source_node_id, - "target_name": target_node_id, - "workspace": self.db.workspace, - } - return await self.db.query(sql, param) - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - sql = SQL_TEMPLATES["get_node_edges"] - param = {"source_name": source_node_id, "workspace": self.db.workspace} - res = await self.db.query(sql, param, multirows=True) - if res: - data = [(i["source_name"], i["target_name"]) for i in res] - return data - else: - return [] - - async def index_done_callback(self) -> None: - # Ti handles persistence automatically - pass - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - drop_sql = """ - DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace; - DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace; - """ - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "graph data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - async def delete_node(self, node_id: str) -> None: - """Delete a node and all its related edges - - Args: - node_id: The ID of the node to delete - """ - # First delete all edges related to this node - await self.db.execute( - SQL_TEMPLATES["delete_node_edges"], - {"name": node_id, "workspace": self.db.workspace}, - ) - - # Then delete the node itself - await self.db.execute( - SQL_TEMPLATES["delete_node"], - {"name": node_id, "workspace": self.db.workspace}, - ) - - logger.debug( - f"Node {node_id} and its related edges have been deleted from the graph" - ) - - async def get_all_labels(self) -> list[str]: - """Get all entity types (labels) in the database - - Returns: - List of labels sorted alphabetically - """ - result = await self.db.query( - SQL_TEMPLATES["get_all_labels"], - {"workspace": self.db.workspace}, - multirows=True, - ) - - if not result: - return [] - - # Extract all labels - return [item["label"] for item in result] - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: - """ - Get a connected subgraph of nodes matching the specified label - Maximum number of nodes is limited by MAX_GRAPH_NODES 
environment variable (default: 1000) - - Args: - node_label: The node label to match - max_depth: Maximum depth of the subgraph - - Returns: - KnowledgeGraph object containing nodes and edges - """ - result = KnowledgeGraph() - MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - - # Get matching nodes - if node_label == "*": - # Handle special case, get all nodes - node_results = await self.db.query( - SQL_TEMPLATES["get_all_nodes"], - {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, - multirows=True, - ) - else: - # Get nodes matching the label - label_pattern = f"%{node_label}%" - node_results = await self.db.query( - SQL_TEMPLATES["get_matching_nodes"], - {"workspace": self.db.workspace, "label_pattern": label_pattern}, - multirows=True, - ) - - if not node_results: - logger.warning(f"No nodes found matching label {node_label}") - return result - - # Limit the number of returned nodes - if len(node_results) > MAX_GRAPH_NODES: - node_results = node_results[:MAX_GRAPH_NODES] - - # Extract node names for edge query - node_names = [node["name"] for node in node_results] - node_names_str = ",".join([f"'{name}'" for name in node_names]) - - # Add nodes to result - for node in node_results: - node_properties = { - k: v for k, v in node.items() if k not in ["id", "name", "entity_type"] - } - result.nodes.append( - KnowledgeGraphNode( - id=node["name"], - labels=[node["entity_type"]] - if node.get("entity_type") - else [node["name"]], - properties=node_properties, - ) - ) - - # Get related edges - edge_results = await self.db.query( - SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), - {"workspace": self.db.workspace}, - multirows=True, - ) - - if edge_results: - # Add edges to result - for edge in edge_results: - # Only include edges related to selected nodes - if ( - edge["source_name"] in node_names - and edge["target_name"] in node_names - ): - edge_id = f"{edge['source_name']}-{edge['target_name']}" - edge_properties = { - k: v - for k, v in edge.items() - if k not in ["id", "source_name", "target_name"] - } - - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="RELATED", - source=edge["source_name"], - target=edge["target_name"], - properties=edge_properties, - ) - ) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - return result - - async def remove_nodes(self, nodes: list[str]): - """Delete multiple nodes - - Args: - nodes: List of node IDs to delete - """ - for node_id in nodes: - await self.delete_node(node_id) - - async def remove_edges(self, edges: list[tuple[str, str]]): - """Delete multiple edges - - Args: - edges: List of edges to delete, each edge is a (source, target) tuple - """ - for source, target in edges: - await self.db.execute( - SQL_TEMPLATES["remove_multiple_edges"], - {"source": source, "target": target, "workspace": self.db.workspace}, - ) - - -N_T = { - NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", - NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES", - NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", -} -N_ID = { - NameSpace.KV_STORE_FULL_DOCS: "doc_id", - NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id", - NameSpace.VECTOR_STORE_CHUNKS: "chunk_id", - NameSpace.VECTOR_STORE_ENTITIES: "entity_id", - NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id", -} - - -def namespace_to_table_name(namespace: str) -> str: - 
for k, v in N_T.items(): - if is_namespace(namespace, k): - return v - - -def namespace_to_id(namespace: str) -> str: - for k, v in N_ID.items(): - if is_namespace(namespace, k): - return v - - -TABLES = { - "LIGHTRAG_DOC_FULL": { - "ddl": """ - CREATE TABLE LIGHTRAG_DOC_FULL ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `doc_id` VARCHAR(256) NOT NULL, - `workspace` varchar(1024), - `content` LONGTEXT, - `meta` JSON, - `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - `updatetime` TIMESTAMP DEFAULT NULL, - UNIQUE KEY (`doc_id`) - ); - """ - }, - "LIGHTRAG_DOC_CHUNKS": { - "ddl": """ - CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `chunk_id` VARCHAR(256) NOT NULL, - `full_doc_id` VARCHAR(256) NOT NULL, - `workspace` varchar(1024), - `chunk_order_index` INT, - `tokens` INT, - `content` LONGTEXT, - `content_vector` VECTOR, - `createtime` TIMESTAMP, - `updatetime` TIMESTAMP, - UNIQUE KEY (`chunk_id`) - ); - """ - }, - "LIGHTRAG_GRAPH_NODES": { - "ddl": """ - CREATE TABLE LIGHTRAG_GRAPH_NODES ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `entity_id` VARCHAR(256), - `workspace` varchar(1024), - `name` VARCHAR(2048), - `entity_type` VARCHAR(1024), - `description` LONGTEXT, - `source_chunk_id` VARCHAR(256), - `content` LONGTEXT, - `content_vector` VECTOR, - `createtime` TIMESTAMP, - `updatetime` TIMESTAMP, - KEY (`entity_id`) - ); - """ - }, - "LIGHTRAG_GRAPH_EDGES": { - "ddl": """ - CREATE TABLE LIGHTRAG_GRAPH_EDGES ( - `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `relation_id` VARCHAR(256), - `workspace` varchar(1024), - `source_name` VARCHAR(2048), - `target_name` VARCHAR(2048), - `weight` DECIMAL, - `keywords` TEXT, - `description` LONGTEXT, - `source_chunk_id` varchar(256), - `content` LONGTEXT, - `content_vector` VECTOR, - `createtime` TIMESTAMP, - `updatetime` TIMESTAMP, - KEY (`relation_id`) - ); - """ - }, - "LIGHTRAG_LLM_CACHE": { - "ddl": """ - CREATE TABLE LIGHTRAG_LLM_CACHE ( - id BIGINT PRIMARY KEY AUTO_INCREMENT, - send TEXT, - return TEXT, - model VARCHAR(1024), - createtime DATETIME DEFAULT CURRENT_TIMESTAMP, - updatetime DATETIME DEFAULT NULL - ); - """ - }, -} - - -SQL_TEMPLATES = { - # SQL for KVStorage - "get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace", - "get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace", - "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace", - "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace", - "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace", - # SQL for Merge operations (TiDB version with INSERT ... 
ON DUPLICATE KEY UPDATE) - "upsert_doc_full": """ - INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace) - VALUES (:id, :content, :workspace) - ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP - """, - "upsert_chunk": """ - INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace, createtime, updatetime) - VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp)) - ON DUPLICATE KEY UPDATE - content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index), - full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = FROM_UNIXTIME(:timestamp) - """, - # SQL for VectorStorage - "entities": """SELECT n.name as entity_name, UNIX_TIMESTAMP(n.createtime) as created_at FROM - (SELECT entity_id as id, name, createtime, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance - FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n - WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k - """, - "relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id, UNIX_TIMESTAMP(e.createtime) as created_at FROM - (SELECT source_name, target_name, createtime, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance - FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e - WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k - """, - "chunks": """SELECT c.id, UNIX_TIMESTAMP(c.createtime) as created_at FROM - (SELECT chunk_id as id, createtime, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c - WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k - """, - "has_entity": """ - SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace - """, - "has_relationship": """ - SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace - """, - "upsert_entity": """ - INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace, createtime, updatetime) - VALUES(:id, :name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp)) - ON DUPLICATE KEY UPDATE - content = VALUES(content), - content_vector = VALUES(content_vector), - updatetime = FROM_UNIXTIME(:timestamp) - """, - "upsert_relationship": """ - INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace, createtime, updatetime) - VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp)) - ON DUPLICATE KEY UPDATE - content = VALUES(content), - content_vector = VALUES(content_vector), - updatetime = FROM_UNIXTIME(:timestamp) - """, - # SQL for GraphStorage - "get_node": """ - SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector - FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace - """, - "get_edge": """ - SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector - FROM LIGHTRAG_GRAPH_EDGES WHERE 
source_name = :source_name AND target_name = :target_name AND workspace = :workspace - """, - "get_node_edges": """ - SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector - FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace - """, - "node_degree": """ - SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name) - """, - "upsert_node": """ - INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description) - VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description) - ON DUPLICATE KEY UPDATE - name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector), - workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, - source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description) - """, - "upsert_edge": """ - INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector, - workspace, weight, keywords, description, source_chunk_id) - VALUES(:source_name, :target_name, :content, :content_vector, - :workspace, :weight, :keywords, :description, :source_chunk_id) - ON DUPLICATE KEY UPDATE - source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content), - content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, - weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description), - source_chunk_id = VALUES(source_chunk_id) - """, - "delete_node": """ - DELETE FROM LIGHTRAG_GRAPH_NODES - WHERE name = :name AND workspace = :workspace - """, - "delete_node_edges": """ - DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace - """, - "get_all_labels": """ - SELECT DISTINCT entity_type as label - FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace - ORDER BY entity_type - """, - "get_matching_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES - WHERE name LIKE :label_pattern AND workspace = :workspace - ORDER BY name - """, - "get_all_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace - ORDER BY name - LIMIT :max_nodes - """, - "get_related_edges": """ - SELECT * FROM LIGHTRAG_GRAPH_EDGES - WHERE (source_name IN (:node_names) OR target_name IN (:node_names)) - AND workspace = :workspace - """, - "remove_multiple_edges": """ - DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE (source_name = :source AND target_name = :target) - AND workspace = :workspace - """, - # Drop tables - "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace", -}
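For reference, the deprecated TiDB backend removed above followed one consistent pattern throughout: every LightRAG namespace was mapped to a workspace-scoped table and id column, and all writes went through named-parameter SQL using MySQL-style `ON DUPLICATE KEY UPDATE` so re-ingestion stayed idempotent. The sketch below is illustrative only and is not part of this patch: the table, column, and SQL text mirror the deleted `SQL_TEMPLATES`, while the `db` handle, its `execute(sql, params)` signature, and the `upsert_full_doc` helper are assumed placeholders for whatever async TiDB/MySQL client is in use.

```python
# Minimal sketch of the upsert pattern used by the removed tidb_impl.py.
# Assumptions (not part of the original code): `db.execute(sql, params)` is an
# async call that accepts ":name"-style named parameters; namespace keys here
# are plain strings standing in for the NameSpace constants.
from typing import Any

# Namespace -> (table, id column), mirroring the deleted N_T / N_ID mappings.
NAMESPACE_TABLES: dict[str, tuple[str, str]] = {
    "full_docs": ("LIGHTRAG_DOC_FULL", "doc_id"),
    "text_chunks": ("LIGHTRAG_DOC_CHUNKS", "chunk_id"),
    "entities": ("LIGHTRAG_GRAPH_NODES", "entity_id"),
    "relationships": ("LIGHTRAG_GRAPH_EDGES", "relation_id"),
}

# Upsert SQL copied from the deleted SQL_TEMPLATES["upsert_doc_full"].
UPSERT_DOC_FULL = """
    INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
    VALUES (:id, :content, :workspace)
    ON DUPLICATE KEY UPDATE
        content = VALUES(content),
        workspace = VALUES(workspace),
        updatetime = CURRENT_TIMESTAMP
"""


async def upsert_full_doc(db: Any, workspace: str, doc_id: str, content: str) -> None:
    """Insert or update one full-document row, scoped to a workspace.

    The UNIQUE KEY on doc_id makes the INSERT idempotent: a second call with
    the same doc_id updates content and updatetime instead of failing.
    """
    await db.execute(
        UPSERT_DOC_FULL,
        {"id": doc_id, "content": content, "workspace": workspace},
    )
```

Because every table in the removed backend carried both a unique business key (doc_id, chunk_id, entity_id, relation_id) and a `workspace` column, this single insert-or-update statement was enough to make repeated pipeline runs safe without separate existence checks; the same structure was repeated for chunks, entities, and relationships.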