diff --git a/lightrag/base.py b/lightrag/base.py
index 16156bfb..81b71ea1 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -19,7 +19,6 @@ from typing import (
 from .utils import EmbeddingFunc
 from .types import KnowledgeGraph
 from .constants import (
-    GRAPH_FIELD_SEP,
     DEFAULT_TOP_K,
     DEFAULT_CHUNK_TOP_K,
     DEFAULT_MAX_ENTITY_TOKENS,
@@ -551,56 +550,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             result[node_id] = edges if edges is not None else []
         return result
 
-    @abstractmethod
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-
-    @abstractmethod
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        # Default implementation iterates through all nodes and their edges, which is inefficient.
-        # This method should be overridden by subclasses for better performance.
-        all_edges = []
-        all_labels = await self.get_all_labels()
-        processed_edges = set()
-
-        for label in all_labels:
-            edges = await self.get_node_edges(label)
-            if edges:
-                for src_id, tgt_id in edges:
-                    # Avoid processing the same edge twice in an undirected graph
-                    edge_tuple = tuple(sorted((src_id, tgt_id)))
-                    if edge_tuple in processed_edges:
-                        continue
-                    processed_edges.add(edge_tuple)
-
-                    edge = await self.get_edge(src_id, tgt_id)
-                    if edge and "source_id" in edge:
-                        source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
-                        if not source_ids.isdisjoint(chunk_ids):
-                            # Add source and target to the edge dict for easier processing later
-                            edge_with_nodes = edge.copy()
-                            edge_with_nodes["source"] = src_id
-                            edge_with_nodes["target"] = tgt_id
-                            all_edges.append(edge_with_nodes)
-        return all_edges
-
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """Insert a new node or update an existing node in the graph.
diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py
index 6fd6841c..d81c2ebd 100644
--- a/lightrag/kg/memgraph_impl.py
+++ b/lightrag/kg/memgraph_impl.py
@@ -8,7 +8,7 @@ import configparser
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..kg.shared_storage import get_data_init_lock
+from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
 import pipmaster as pm
 
 if not pm.is_installed("neo4j"):
@@ -101,9 +101,10 @@ class MemgraphStorage(BaseGraphStorage):
             raise
 
     async def finalize(self):
-        if self._driver is not None:
-            await self._driver.close()
-            self._driver = None
+        async with get_graph_db_lock():
+            if self._driver is not None:
+                await self._driver.close()
+                self._driver = None
 
     async def __aexit__(self, exc_type, exc, tb):
         await self.finalize()
@@ -132,7 +133,6 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
@@ -146,10 +146,7 @@ class MemgraphStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@@ -173,7 +170,6 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = (
@@ -194,10 +190,7 @@ class MemgraphStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
@@ -319,7 +312,6 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""
@@ -336,10 +328,7 @@ class MemgraphStorage(BaseGraphStorage):
                 return labels
             except Exception as e:
                 logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -363,7 +352,6 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            results = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@@ -401,10 +389,7 @@ class MemgraphStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
                 )
-                if results is not None:
-                    await (
-                        results.consume()
-                    )  # Ensure results are consumed even on error
+                await results.consume()  # Ensure results are consumed even on error
                 raise
         except Exception as e:
             logger.error(
@@ -434,7 +419,6 @@ class MemgraphStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""
@@ -467,10 +451,7 @@ class MemgraphStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
@@ -761,21 +742,22 @@ class MemgraphStorage(BaseGraphStorage):
             raise RuntimeError(
                 "Memgraph driver is not initialized. Call 'await initialize()' first."
             )
-        try:
-            async with self._driver.session(database=self._DATABASE) as session:
-                workspace_label = self._get_workspace_label()
-                query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
-                result = await session.run(query)
-                await result.consume()
-                logger.info(
-                    f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
+        async with get_graph_db_lock():
+            try:
+                async with self._driver.session(database=self._DATABASE) as session:
+                    workspace_label = self._get_workspace_label()
+                    query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
+                    result = await session.run(query)
+                    await result.consume()
+                    logger.info(
+                        f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
+                    )
+                    return {"status": "success", "message": "workspace data dropped"}
+            except Exception as e:
+                logger.error(
+                    f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
                 )
-                return {"status": "success", "message": "workspace data dropped"}
-        except Exception as e:
-            logger.error(
-                f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
-            )
-            return {"status": "error", "message": str(e)}
+                return {"status": "error", "message": str(e)}
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         """Get the total degree (sum of relationships) of two nodes.
@@ -1048,7 +1030,6 @@ class MemgraphStorage(BaseGraphStorage):
                 "Memgraph driver is not initialized. Call 'await initialize()' first."
             )
-        result = None
         try:
             workspace_label = self._get_workspace_label()
             async with self._driver.session(
@@ -1075,8 +1056,6 @@ class MemgraphStorage(BaseGraphStorage):
                 return labels
         except Exception as e:
             logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
-            if result is not None:
-                await result.consume()
             return []
 
     async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@@ -1099,7 +1078,6 @@ class MemgraphStorage(BaseGraphStorage):
         if not query_lower:
             return []
 
-        result = None
         try:
             workspace_label = self._get_workspace_label()
             async with self._driver.session(
@@ -1133,6 +1111,4 @@ class MemgraphStorage(BaseGraphStorage):
                 return labels
         except Exception as e:
             logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
-            if result is not None:
-                await result.consume()
             return []
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 138e98b1..e11e6411 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -1031,45 +1031,6 @@ class MongoGraphStorage(BaseGraphStorage):
 
         return result
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-        if not chunk_ids:
-            return []
-
-        cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
-        return [doc async for doc in cursor]
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        if not chunk_ids:
-            return []
-
-        cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
-
-        edges = []
-        async for edge in cursor:
-            edge["source"] = edge["source_node_id"]
-            edge["target"] = edge["target_node_id"]
-            edges.append(edge)
-
-        return edges
-
     #
     # -------------------------------------------------------------------------
     # UPSERTS
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index d3d6c4eb..76fa11f2 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -16,7 +16,7 @@ import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..kg.shared_storage import get_data_init_lock
+from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
 import pipmaster as pm
 
 if not pm.is_installed("neo4j"):
@@ -340,9 +340,10 @@ class Neo4JStorage(BaseGraphStorage):
 
     async def finalize(self):
         """Close the Neo4j driver and release all resources"""
-        if self._driver:
-            await self._driver.close()
-            self._driver = None
+        async with get_graph_db_lock():
+            if self._driver:
+                await self._driver.close()
+                self._driver = None
 
     async def __aexit__(self, exc_type, exc, tb):
         """Ensure driver is closed when context manager exits"""
@@ -352,20 +353,6 @@ class Neo4JStorage(BaseGraphStorage):
         # Neo4J handles persistence automatically
         pass
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def has_node(self, node_id: str) -> bool:
         """
         Check if a node with the given label exists in the database
@@ -384,7 +371,6 @@ class Neo4JStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
                 result = await session.run(query, entity_id=node_id)
@@ -395,24 +381,9 @@ class Neo4JStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await result.consume()  # Ensure results are consumed even on error
+                await result.consume()  # Ensure results are consumed even on error
                 raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """
         Check if an edge exists between two nodes
@@ -432,7 +403,6 @@ class Neo4JStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 query = (
                     f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
@@ -450,24 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
                 logger.error(
                     f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await result.consume()  # Ensure results are consumed even on error
+                await result.consume()  # Ensure results are consumed even on error
                 raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier, return only node properties
 
@@ -521,20 +476,6 @@ class Neo4JStorage(BaseGraphStorage):
             )
             raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
         """
         Retrieve multiple nodes in one query using UNWIND.
@@ -571,20 +512,6 @@ class Neo4JStorage(BaseGraphStorage):
             await result.consume()  # Make sure to consume the result fully
             return nodes
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def node_degree(self, node_id: str) -> int:
         """Get the degree (number of relationships) of a node with the given label.
         If multiple nodes have the same label, returns the degree of the first node.
@@ -633,20 +560,6 @@ class Neo4JStorage(BaseGraphStorage):
             )
             raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
         """
         Retrieve the degree for multiple nodes in a single query using UNWIND.
@@ -731,20 +644,6 @@ class Neo4JStorage(BaseGraphStorage): edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) return edge_degrees - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: @@ -832,20 +731,6 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) async def get_edges_batch( self, pairs: list[dict[str, str]] ) -> dict[tuple[str, str], dict]: @@ -896,20 +781,6 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() return edges_dict - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Retrieves all edges (relationships) for a particular node identified by its label. @@ -928,7 +799,6 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - results = None try: workspace_label = self._get_workspace_label() query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) @@ -966,10 +836,7 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" ) - if results is not None: - await ( - results.consume() - ) # Ensure results are consumed even on error + await results.consume() # Ensure results are consumed even on error raise except Exception as e: logger.error( @@ -977,20 +844,6 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) async def get_nodes_edges_batch( self, node_ids: list[str] ) -> dict[str, list[tuple[str, str]]]: @@ -1739,7 +1592,6 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - result = None try: query = f""" MATCH (n:`{workspace_label}`) @@ -1764,8 +1616,7 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting popular labels: {str(e)}" ) - if result is not None: - await result.consume() + await result.consume() raise async def search_labels(self, query: str, limit: int = 50) -> list[str]: @@ -1912,23 +1763,24 @@ class Neo4JStorage(BaseGraphStorage): - On success: {"status": "success", "message": "workspace data dropped"} - On failure: {"status": "error", "message": ""} """ - workspace_label = self._get_workspace_label() - try: - async with self._driver.session(database=self._DATABASE) as session: - # Delete 
all nodes and relationships in current workspace only - query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" - result = await session.run(query) - await result.consume() # Ensure result is fully consumed + async with get_graph_db_lock(): + workspace_label = self._get_workspace_label() + try: + async with self._driver.session(database=self._DATABASE) as session: + # Delete all nodes and relationships in current workspace only + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() # Ensure result is fully consumed - # logger.debug( - # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" - # ) - return { - "status": "success", - "message": f"workspace '{workspace_label}' data dropped", - } - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" - ) - return {"status": "error", "message": str(e)} + # logger.debug( + # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" + # ) + return { + "status": "success", + "message": f"workspace '{workspace_label}' data dropped", + } + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 38d66a57..2a7c6158 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -33,7 +33,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger -from ..kg.shared_storage import get_data_init_lock +from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock import pipmaster as pm @@ -77,9 +77,6 @@ class PostgreSQLDB: self.hnsw_m = config.get("hnsw_m") self.hnsw_ef = config.get("hnsw_ef") self.ivfflat_lists = config.get("ivfflat_lists") - self.vchordrq_build_options = config.get("vchordrq_build_options") - self.vchordrq_probes = config.get("vchordrq_probes") - self.vchordrq_epsilon = config.get("vchordrq_epsilon") # Server settings self.server_settings = config.get("server_settings") @@ -365,8 +362,7 @@ class PostgreSQLDB: await self.configure_age(connection, graph_name) elif with_age and not graph_name: raise ValueError("Graph name is required when with_age is True") - if self.vector_index_type == "VCHORDRQ": - await self.configure_vchordrq(connection) + return await operation(connection) @staticmethod @@ -383,7 +379,7 @@ class PostgreSQLDB: async def configure_age_extension(connection: asyncpg.Connection) -> None: """Create AGE extension if it doesn't exist for graph operations.""" try: - await connection.execute("CREATE EXTENSION IF NOT EXISTS AGE CASCADE") # type: ignore + await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore logger.info("PostgreSQL, AGE extension enabled") except Exception as e: logger.warning(f"Could not create AGE extension: {e}") @@ -412,14 +408,6 @@ class PostgreSQLDB: ): pass - async def configure_vchordrq(self, connection: asyncpg.Connection) -> None: - """Configure VCHORDRQ extension for vector similarity search.""" - try: - await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'") - await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}") - except Exception as e: - logger.error(f"Failed 
to set vchordrq.probes or vchordrq.epsilon: {e}") - async def _migrate_llm_cache_schema(self): """Migrate LLM cache schema: add new columns and remove deprecated mode field""" try: @@ -1154,12 +1142,19 @@ class PostgreSQLDB: f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}" ) try: - if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]: - await self._create_vector_indexes() + if self.vector_index_type == "HNSW": + await self._create_hnsw_vector_indexes() + elif self.vector_index_type == "IVFFLAT": + await self._create_ivfflat_vector_indexes() + elif self.vector_index_type == "FLAT": + logger.warning( + "FLAT index type is not supported by pgvector. Skipping vector index creation. " + "Please use 'HNSW' or 'IVFFLAT' instead." + ) else: logger.warning( "Doesn't support this vector index type: {self.vector_index_type}. " - "Supported types: HNSW, IVFFLAT, VCHORDRQ" + "Supported types: HNSW, IVFFLAT" ) except Exception as e: logger.error( @@ -1366,39 +1361,21 @@ class PostgreSQLDB: except Exception as e: logger.warning(f"Failed to create index {index['name']}: {e}") - async def _create_vector_indexes(self): + async def _create_hnsw_vector_indexes(self): vdb_tables = [ "LIGHTRAG_VDB_CHUNKS", "LIGHTRAG_VDB_ENTITY", "LIGHTRAG_VDB_RELATION", ] - create_sql = { - "HNSW": f""" - CREATE INDEX {{vector_index_name}} - ON {{k}} USING hnsw (content_vector vector_cosine_ops) - WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef}) - """, - "IVFFLAT": f""" - CREATE INDEX {{vector_index_name}} - ON {{k}} USING ivfflat (content_vector vector_cosine_ops) - WITH (lists = {self.ivfflat_lists}) - """, - "VCHORDRQ": f""" - CREATE INDEX {{vector_index_name}} - ON {{k}} USING vchordrq (content_vector vector_cosine_ops) - {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''} - """, - } - embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) + for k in vdb_tables: - vector_index_name = ( - f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine" - ) + vector_index_name = f"idx_{k.lower()}_hnsw_cosine" check_vector_index_sql = f""" SELECT 1 FROM pg_indexes - WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}' + WHERE indexname = '{vector_index_name}' + AND tablename = '{k.lower()}' """ try: vector_index_exists = await self.query(check_vector_index_sql) @@ -1407,24 +1384,64 @@ class PostgreSQLDB: alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" await self.execute(alter_sql) logger.debug(f"Ensured vector dimension for {k}") - logger.info( - f"Creating {self.vector_index_type} index {vector_index_name} on table {k}" - ) - await self.execute( - create_sql[self.vector_index_type].format( - vector_index_name=vector_index_name, k=k - ) - ) + + create_vector_index_sql = f""" + CREATE INDEX {vector_index_name} + ON {k} USING hnsw (content_vector vector_cosine_ops) + WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef}) + """ + logger.info(f"Creating hnsw index {vector_index_name} on table {k}") + await self.execute(create_vector_index_sql) logger.info( f"Successfully created vector index {vector_index_name} on table {k}" ) else: logger.info( - f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}" + f"HNSW vector index {vector_index_name} already exists on table {k}" ) except Exception as e: logger.error(f"Failed to create vector index on table {k}, Got: {e}") + async def _create_ivfflat_vector_indexes(self): + vdb_tables = [ + "LIGHTRAG_VDB_CHUNKS", + 
"LIGHTRAG_VDB_ENTITY", + "LIGHTRAG_VDB_RELATION", + ] + + embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) + + for k in vdb_tables: + index_name = f"idx_{k.lower()}_ivfflat_cosine" + check_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{index_name}' AND tablename = '{k.lower()}' + """ + try: + exists = await self.query(check_index_sql) + if not exists: + # Only set vector dimension when index doesn't exist + alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" + await self.execute(alter_sql) + logger.debug(f"Ensured vector dimension for {k}") + + create_sql = f""" + CREATE INDEX {index_name} + ON {k} USING ivfflat (content_vector vector_cosine_ops) + WITH (lists = {self.ivfflat_lists}) + """ + logger.info(f"Creating ivfflat index {index_name} on table {k}") + await self.execute(create_sql) + logger.info( + f"Successfully created ivfflat index {index_name} on table {k}" + ) + else: + logger.info( + f"Ivfflat vector index {index_name} already exists on table {k}" + ) + except Exception as e: + logger.error(f"Failed to create ivfflat index on {k}: {e}") + async def query( self, sql: str, @@ -1579,20 +1596,6 @@ class ClientManager: config.get("postgres", "ivfflat_lists", fallback="100"), ) ), - "vchordrq_build_options": os.environ.get( - "POSTGRES_VCHORDRQ_BUILD_OPTIONS", - config.get("postgres", "vchordrq_build_options", fallback=""), - ), - "vchordrq_probes": os.environ.get( - "POSTGRES_VCHORDRQ_PROBES", - config.get("postgres", "vchordrq_probes", fallback=""), - ), - "vchordrq_epsilon": float( - os.environ.get( - "POSTGRES_VCHORDRQ_EPSILON", - config.get("postgres", "vchordrq_epsilon", fallback="1.9"), - ) - ), # Server settings for Supabase "server_settings": os.environ.get( "POSTGRES_SERVER_SETTINGS", @@ -1699,9 +1702,10 @@ class PGKVStorage(BaseKVStorage): self.workspace = "default" async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_storage_lock(): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None ################ QUERY METHODS ################ async def get_by_id(self, id: str) -> dict[str, Any] | None: @@ -2143,21 +2147,22 @@ class PGKVStorage(BaseKVStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + async with get_storage_lock(): + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -2192,9 +2197,10 @@ class PGVectorStorage(BaseVectorStorage): self.workspace = "default" async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_storage_lock(): 
+ if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None def _upsert_chunks( self, item: dict[str, Any], current_time: datetime.datetime @@ -2530,21 +2536,22 @@ class PGVectorStorage(BaseVectorStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + async with get_storage_lock(): + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -2579,9 +2586,10 @@ class PGDocStatusStorage(DocStatusStorage): self.workspace = "default" async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_storage_lock(): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" @@ -3156,21 +3164,22 @@ class PGDocStatusStorage(DocStatusStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + async with get_storage_lock(): + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} class PGGraphQueryException(Exception): @@ -3302,9 +3311,10 @@ class PGGraphStorage(BaseGraphStorage): ) async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_graph_db_lock(): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def index_done_callback(self) -> None: # PG handles persistence automatically @@ -3558,13 +3568,17 @@ class PGGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties""" - result = await self.get_nodes_batch(node_ids=[node_id]) + label = self._normalize_node_id(node_id) + + result = await 
self.get_nodes_batch(node_ids=[label]) if result and node_id in result: return result[node_id] return None async def node_degree(self, node_id: str) -> int: - result = await self.node_degrees_batch(node_ids=[node_id]) + label = self._normalize_node_id(node_id) + + result = await self.node_degrees_batch(node_ids=[label]) if result and node_id in result: return result[node_id] @@ -3577,11 +3591,12 @@ class PGGraphStorage(BaseGraphStorage): self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: """Get edge properties between two nodes""" - result = await self.get_edges_batch( - [{"src": source_node_id, "tgt": target_node_id}] - ) - if result and (source_node_id, target_node_id) in result: - return result[(source_node_id, target_node_id)] + src_label = self._normalize_node_id(source_node_id) + tgt_label = self._normalize_node_id(target_node_id) + + result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}]) + if result and (src_label, tgt_label) in result: + return result[(src_label, tgt_label)] return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: @@ -3779,17 +3794,13 @@ class PGGraphStorage(BaseGraphStorage): if not node_ids: return {} - seen: set[str] = set() - unique_ids: list[str] = [] - lookup: dict[str, str] = {} - requested: set[str] = set() + seen = set() + unique_ids = [] for nid in node_ids: - if nid not in seen: - seen.add(nid) - unique_ids.append(nid) - requested.add(nid) - lookup[nid] = nid - lookup[self._normalize_node_id(nid)] = nid + nid_norm = self._normalize_node_id(nid) + if nid_norm not in seen: + seen.add(nid_norm) + unique_ids.append(nid_norm) # Build result dictionary nodes_dict = {} @@ -3828,18 +3839,10 @@ class PGGraphStorage(BaseGraphStorage): node_dict = json.loads(node_dict) except json.JSONDecodeError: logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" + f"Failed to parse node string in batch: {node_dict}" ) - node_key = result["node_id"] - original_key = lookup.get(node_key) - if original_key is None: - logger.warning( - f"[{self.workspace}] Node {node_key} not found in lookup map" - ) - original_key = node_key - if original_key in requested: - nodes_dict[original_key] = node_dict + nodes_dict[result["node_id"]] = node_dict return nodes_dict @@ -3862,17 +3865,13 @@ class PGGraphStorage(BaseGraphStorage): if not node_ids: return {} - seen: set[str] = set() + seen = set() unique_ids: list[str] = [] - lookup: dict[str, str] = {} - requested: set[str] = set() for nid in node_ids: - if nid not in seen: - seen.add(nid) - unique_ids.append(nid) - requested.add(nid) - lookup[nid] = nid - lookup[self._normalize_node_id(nid)] = nid + n = self._normalize_node_id(nid) + if n not in seen: + seen.add(n) + unique_ids.append(n) out_degrees = {} in_degrees = {} @@ -3924,16 +3923,8 @@ class PGGraphStorage(BaseGraphStorage): node_id = row["node_id"] if not node_id: continue - node_key = node_id - original_key = lookup.get(node_key) - if original_key is None: - logger.warning( - f"[{self.workspace}] Node {node_key} not found in lookup map" - ) - original_key = node_key - if original_key in requested: - out_degrees[original_key] = int(row.get("out_degree", 0) or 0) - in_degrees[original_key] = int(row.get("in_degree", 0) or 0) + out_degrees[node_id] = int(row.get("out_degree", 0) or 0) + in_degrees[node_id] = int(row.get("in_degree", 0) or 0) degrees_dict = {} for node_id in node_ids: @@ -4062,7 +4053,7 @@ class PGGraphStorage(BaseGraphStorage): edge_props = 
json.loads(edge_props) except json.JSONDecodeError: logger.warning( - f"[{self.workspace}]Failed to parse edge properties string: {edge_props}" + f"Failed to parse edge properties string: {edge_props}" ) continue @@ -4078,7 +4069,7 @@ class PGGraphStorage(BaseGraphStorage): edge_props = json.loads(edge_props) except json.JSONDecodeError: logger.warning( - f"[{self.workspace}] Failed to parse edge properties string: {edge_props}" + f"Failed to parse edge properties string: {edge_props}" ) continue @@ -4704,20 +4695,21 @@ class PGGraphStorage(BaseGraphStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n) - DETACH DELETE n - $$) AS (result agtype)""" + async with get_graph_db_lock(): + try: + drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n) + DETACH DELETE n + $$) AS (result agtype)""" - await self._query(drop_query, readonly=False) - return { - "status": "success", - "message": f"workspace '{self.workspace}' graph data dropped", - } - except Exception as e: - logger.error(f"[{self.workspace}] Error dropping graph: {e}") - return {"status": "error", "message": str(e)} + await self._query(drop_query, readonly=False) + return { + "status": "success", + "message": f"workspace '{self.workspace}' graph data dropped", + } + except Exception as e: + logger.error(f"[{self.workspace}] Error dropping graph: {e}") + return {"status": "error", "message": str(e)} # Note: Order matters! More specific namespaces (e.g., "full_entities") must come before