Remove unused chunk-based node/edge retrieval methods

(cherry picked from commit 807d2461d3)
yangdx 2025-11-06 18:17:10 +08:00 committed by Raphaël MANSUY
parent ce702ccb2f
commit 211dbc3f78
5 changed files with 225 additions and 495 deletions


@ -19,7 +19,6 @@ from typing import (
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
from .constants import (
GRAPH_FIELD_SEP,
DEFAULT_TOP_K,
DEFAULT_CHUNK_TOP_K,
DEFAULT_MAX_ENTITY_TOKENS,
@ -551,56 +550,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = edges if edges is not None else []
return result
@abstractmethod
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
@abstractmethod
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
# Default implementation iterates through all nodes and their edges, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_edges = []
all_labels = await self.get_all_labels()
processed_edges = set()
for label in all_labels:
edges = await self.get_node_edges(label)
if edges:
for src_id, tgt_id in edges:
# Avoid processing the same edge twice in an undirected graph
edge_tuple = tuple(sorted((src_id, tgt_id)))
if edge_tuple in processed_edges:
continue
processed_edges.add(edge_tuple)
edge = await self.get_edge(src_id, tgt_id)
if edge and "source_id" in edge:
source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
# Add source and target to the edge dict for easier processing later
edge_with_nodes = edge.copy()
edge_with_nodes["source"] = src_id
edge_with_nodes["target"] = tgt_id
all_edges.append(edge_with_nodes)
return all_edges
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.

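The removed default implementation above hinges on one membership test: a node or edge belongs to a chunk when its "source_id" property, split on GRAPH_FIELD_SEP, overlaps the requested chunk IDs. A minimal standalone sketch of that test follows; the "<SEP>" value is an assumption for illustration, since the real constant lives in lightrag.constants and is not shown in this diff.

GRAPH_FIELD_SEP = "<SEP>"  # assumed delimiter; the real value comes from lightrag.constants

def belongs_to_chunks(properties: dict, chunk_ids: list[str]) -> bool:
    """True if a node/edge "source_id" references any of the given chunk IDs."""
    source_id = properties.get("source_id")
    if not source_id:
        return False
    return not set(source_id.split(GRAPH_FIELD_SEP)).isdisjoint(chunk_ids)

# belongs_to_chunks({"source_id": "chunk-1<SEP>chunk-7"}, ["chunk-7"]) -> True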

@ -8,7 +8,7 @@ import configparser
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..kg.shared_storage import get_data_init_lock
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm
if not pm.is_installed("neo4j"):
@ -101,9 +101,10 @@ class MemgraphStorage(BaseGraphStorage):
raise
async def finalize(self):
if self._driver is not None:
await self._driver.close()
self._driver = None
async with get_graph_db_lock():
if self._driver is not None:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
await self.finalize()
@ -132,7 +133,6 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
@ -146,10 +146,7 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
await result.consume() # Ensure the result is consumed even on error
raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@ -173,7 +170,6 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = (
@ -194,10 +190,7 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
await result.consume() # Ensure the result is consumed even on error
raise
async def get_node(self, node_id: str) -> dict[str, str] | None:
@ -319,7 +312,6 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = f"""
@ -336,10 +328,7 @@ class MemgraphStorage(BaseGraphStorage):
return labels
except Exception as e:
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
await result.consume() # Ensure the result is consumed even on error
raise
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@ -363,7 +352,6 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = None
try:
workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -401,10 +389,7 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
)
if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
await results.consume() # Ensure results are consumed even on error
raise
except Exception as e:
logger.error(
@ -434,7 +419,6 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = f"""
@ -467,10 +451,7 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
)
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
await result.consume() # Ensure the result is consumed even on error
raise
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
@ -761,21 +742,22 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query)
await result.consume()
logger.info(
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
async with get_graph_db_lock():
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query)
await result.consume()
logger.info(
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes.
@ -1048,7 +1030,6 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
result = None
try:
workspace_label = self._get_workspace_label()
async with self._driver.session(
@ -1075,8 +1056,6 @@ class MemgraphStorage(BaseGraphStorage):
return labels
except Exception as e:
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
if result is not None:
await result.consume()
return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@ -1099,7 +1078,6 @@ class MemgraphStorage(BaseGraphStorage):
if not query_lower:
return []
result = None
try:
workspace_label = self._get_workspace_label()
async with self._driver.session(
@ -1133,6 +1111,4 @@ class MemgraphStorage(BaseGraphStorage):
return labels
except Exception as e:
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
if result is not None:
await result.consume()
return []

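finalize() and drop() in the Memgraph storage above are now serialized behind the shared graph-database lock. A minimal sketch of that teardown pattern, assuming get_graph_db_lock() returns an async context manager (as the import from ..kg.shared_storage suggests); the stand-in lock and driver below are illustrative, not part of the diff.

import asyncio

_graph_db_lock = asyncio.Lock()

def get_graph_db_lock() -> asyncio.Lock:
    # Stand-in for lightrag.kg.shared_storage.get_graph_db_lock
    return _graph_db_lock

class _Storage:
    def __init__(self, driver):
        self._driver = driver  # any driver object exposing an async close()

    async def finalize(self) -> None:
        # Serialize teardown so no concurrent task uses or closes the driver mid-query.
        async with get_graph_db_lock():
            if self._driver is not None:
                await self._driver.close()
                self._driver = None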

@ -1031,45 +1031,6 @@ class MongoGraphStorage(BaseGraphStorage):
return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
if not chunk_ids:
return []
cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
return [doc async for doc in cursor]
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
if not chunk_ids:
return []
cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
edges = []
async for edge in cursor:
edge["source"] = edge["source_node_id"]
edge["target"] = edge["target_node_id"]
edges.append(edge)
return edges
#
# -------------------------------------------------------------------------
# UPSERTS

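The MongoDB overrides removed above relied on a single filter: documents whose source_ids array shares at least one element with the requested chunk IDs. A standalone sketch of that filter, assuming source_ids is stored as an array field (which is what $in against a list implies):

def chunk_filter(chunk_ids: list[str]) -> dict:
    # $in against an array-valued field matches when the arrays intersect.
    return {"source_ids": {"$in": chunk_ids}}

# e.g. collection.find(chunk_filter(["chunk-1", "chunk-2"]))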

@ -16,7 +16,7 @@ import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..kg.shared_storage import get_data_init_lock
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm
if not pm.is_installed("neo4j"):
@ -340,9 +340,10 @@ class Neo4JStorage(BaseGraphStorage):
async def finalize(self):
"""Close the Neo4j driver and release all resources"""
if self._driver:
await self._driver.close()
self._driver = None
async with get_graph_db_lock():
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits"""
@ -352,20 +353,6 @@ class Neo4JStorage(BaseGraphStorage):
# Neo4J handles persistence automatically
pass
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def has_node(self, node_id: str) -> bool:
"""
Check if a node with the given label exists in the database
@ -384,7 +371,6 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id)
@ -395,24 +381,9 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
if result is not None:
await result.consume() # Ensure results are consumed even on error
await result.consume() # Ensure results are consumed even on error
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if an edge exists between two nodes
@ -432,7 +403,6 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
query = (
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
@ -450,24 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
if result is not None:
await result.consume() # Ensure results are consumed even on error
await result.consume() # Ensure results are consumed even on error
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties
@ -521,20 +476,6 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
@ -571,20 +512,6 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Make sure to consume the result fully
return nodes
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
@ -633,20 +560,6 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
@ -731,20 +644,6 @@ class Neo4JStorage(BaseGraphStorage):
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
@ -832,20 +731,6 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
@ -896,20 +781,6 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume()
return edges_dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.
@ -928,7 +799,6 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = None
try:
workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -966,10 +836,7 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
)
if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
await results.consume() # Ensure results are consumed even on error
raise
except Exception as e:
logger.error(
@ -977,20 +844,6 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
@ -1739,7 +1592,6 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
query = f"""
MATCH (n:`{workspace_label}`)
@ -1764,8 +1616,7 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting popular labels: {str(e)}"
)
if result is not None:
await result.consume()
await result.consume()
raise
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@ -1912,23 +1763,24 @@ class Neo4JStorage(BaseGraphStorage):
- On success: {"status": "success", "message": "workspace data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
workspace_label = self._get_workspace_label()
try:
async with self._driver.session(database=self._DATABASE) as session:
# Delete all nodes and relationships in current workspace only
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query)
await result.consume() # Ensure result is fully consumed
async with get_graph_db_lock():
workspace_label = self._get_workspace_label()
try:
async with self._driver.session(database=self._DATABASE) as session:
# Delete all nodes and relationships in current workspace only
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query)
await result.consume() # Ensure result is fully consumed
# logger.debug(
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
# )
return {
"status": "success",
"message": f"workspace '{workspace_label}' data dropped",
}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
# logger.debug(
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
# )
return {
"status": "success",
"message": f"workspace '{workspace_label}' data dropped",
}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}

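The identical blocks stripped throughout the Neo4j section above are all the same tenacity retry decorator, previously applied to read-only methods. For reference, a standalone sketch of what that decorator did: up to three attempts with exponential backoff on transient Neo4j and connection errors. The exception tuple is copied from the removed code, while the decorated function is only illustrative.

from neo4j import exceptions as neo4jExceptions
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

@retry(
    stop=stop_after_attempt(3),                          # give up after 3 attempts
    wait=wait_exponential(multiplier=1, min=4, max=10),  # back off between 4s and 10s
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.SessionExpired,
            ConnectionResetError,
            OSError,
            AttributeError,
        )
    ),
)
async def flaky_read() -> None:
    """Illustrative stand-in for the read-only storage methods that carried this decorator."""
    ...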

@ -33,7 +33,7 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..kg.shared_storage import get_data_init_lock
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
import pipmaster as pm
@ -77,9 +77,6 @@ class PostgreSQLDB:
self.hnsw_m = config.get("hnsw_m")
self.hnsw_ef = config.get("hnsw_ef")
self.ivfflat_lists = config.get("ivfflat_lists")
self.vchordrq_build_options = config.get("vchordrq_build_options")
self.vchordrq_probes = config.get("vchordrq_probes")
self.vchordrq_epsilon = config.get("vchordrq_epsilon")
# Server settings
self.server_settings = config.get("server_settings")
@ -365,8 +362,7 @@ class PostgreSQLDB:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
if self.vector_index_type == "VCHORDRQ":
await self.configure_vchordrq(connection)
return await operation(connection)
@staticmethod
@ -383,7 +379,7 @@ class PostgreSQLDB:
async def configure_age_extension(connection: asyncpg.Connection) -> None:
"""Create AGE extension if it doesn't exist for graph operations."""
try:
await connection.execute("CREATE EXTENSION IF NOT EXISTS AGE CASCADE") # type: ignore
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
logger.info("PostgreSQL, AGE extension enabled")
except Exception as e:
logger.warning(f"Could not create AGE extension: {e}")
@ -412,14 +408,6 @@ class PostgreSQLDB:
):
pass
async def configure_vchordrq(self, connection: asyncpg.Connection) -> None:
"""Configure VCHORDRQ extension for vector similarity search."""
try:
await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'")
await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}")
except Exception as e:
logger.error(f"Failed to set vchordrq.probes or vchordrq.epsilon: {e}")
async def _migrate_llm_cache_schema(self):
"""Migrate LLM cache schema: add new columns and remove deprecated mode field"""
try:
@ -1154,12 +1142,19 @@ class PostgreSQLDB:
f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}"
)
try:
if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]:
await self._create_vector_indexes()
if self.vector_index_type == "HNSW":
await self._create_hnsw_vector_indexes()
elif self.vector_index_type == "IVFFLAT":
await self._create_ivfflat_vector_indexes()
elif self.vector_index_type == "FLAT":
logger.warning(
"FLAT index type is not supported by pgvector. Skipping vector index creation. "
"Please use 'HNSW' or 'IVFFLAT' instead."
)
else:
logger.warning(
"Doesn't support this vector index type: {self.vector_index_type}. "
"Supported types: HNSW, IVFFLAT, VCHORDRQ"
"Supported types: HNSW, IVFFLAT"
)
except Exception as e:
logger.error(
@ -1366,39 +1361,21 @@ class PostgreSQLDB:
except Exception as e:
logger.warning(f"Failed to create index {index['name']}: {e}")
async def _create_vector_indexes(self):
async def _create_hnsw_vector_indexes(self):
vdb_tables = [
"LIGHTRAG_VDB_CHUNKS",
"LIGHTRAG_VDB_ENTITY",
"LIGHTRAG_VDB_RELATION",
]
create_sql = {
"HNSW": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING hnsw (content_vector vector_cosine_ops)
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
""",
"IVFFLAT": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING ivfflat (content_vector vector_cosine_ops)
WITH (lists = {self.ivfflat_lists})
""",
"VCHORDRQ": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING vchordrq (content_vector vector_cosine_ops)
{f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''}
""",
}
embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
for k in vdb_tables:
vector_index_name = (
f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
)
vector_index_name = f"idx_{k.lower()}_hnsw_cosine"
check_vector_index_sql = f"""
SELECT 1 FROM pg_indexes
WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
WHERE indexname = '{vector_index_name}'
AND tablename = '{k.lower()}'
"""
try:
vector_index_exists = await self.query(check_vector_index_sql)
@ -1407,24 +1384,64 @@ class PostgreSQLDB:
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
await self.execute(alter_sql)
logger.debug(f"Ensured vector dimension for {k}")
logger.info(
f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
)
await self.execute(
create_sql[self.vector_index_type].format(
vector_index_name=vector_index_name, k=k
)
)
create_vector_index_sql = f"""
CREATE INDEX {vector_index_name}
ON {k} USING hnsw (content_vector vector_cosine_ops)
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
"""
logger.info(f"Creating hnsw index {vector_index_name} on table {k}")
await self.execute(create_vector_index_sql)
logger.info(
f"Successfully created vector index {vector_index_name} on table {k}"
)
else:
logger.info(
f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
f"HNSW vector index {vector_index_name} already exists on table {k}"
)
except Exception as e:
logger.error(f"Failed to create vector index on table {k}, Got: {e}")
async def _create_ivfflat_vector_indexes(self):
vdb_tables = [
"LIGHTRAG_VDB_CHUNKS",
"LIGHTRAG_VDB_ENTITY",
"LIGHTRAG_VDB_RELATION",
]
embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
for k in vdb_tables:
index_name = f"idx_{k.lower()}_ivfflat_cosine"
check_index_sql = f"""
SELECT 1 FROM pg_indexes
WHERE indexname = '{index_name}' AND tablename = '{k.lower()}'
"""
try:
exists = await self.query(check_index_sql)
if not exists:
# Only set vector dimension when index doesn't exist
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
await self.execute(alter_sql)
logger.debug(f"Ensured vector dimension for {k}")
create_sql = f"""
CREATE INDEX {index_name}
ON {k} USING ivfflat (content_vector vector_cosine_ops)
WITH (lists = {self.ivfflat_lists})
"""
logger.info(f"Creating ivfflat index {index_name} on table {k}")
await self.execute(create_sql)
logger.info(
f"Successfully created ivfflat index {index_name} on table {k}"
)
else:
logger.info(
f"Ivfflat vector index {index_name} already exists on table {k}"
)
except Exception as e:
logger.error(f"Failed to create ivfflat index on {k}: {e}")
async def query(
self,
sql: str,
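With illustrative settings (hnsw_m=16, hnsw_ef=64, ivfflat_lists=100, EMBEDDING_DIM=1024 are assumed values; the real ones come from the PostgreSQL config handled below), the two index-creation paths above emit DDL along these lines for each VDB table:

table = "LIGHTRAG_VDB_CHUNKS"  # same pattern for _ENTITY and _RELATION

# Column dimension is fixed first, only when the index does not already exist.
alter_sql = f"ALTER TABLE {table} ALTER COLUMN content_vector TYPE VECTOR(1024)"

hnsw_sql = f"""
    CREATE INDEX idx_{table.lower()}_hnsw_cosine
    ON {table} USING hnsw (content_vector vector_cosine_ops)
    WITH (m = 16, ef_construction = 64)
"""

ivfflat_sql = f"""
    CREATE INDEX idx_{table.lower()}_ivfflat_cosine
    ON {table} USING ivfflat (content_vector vector_cosine_ops)
    WITH (lists = 100)
"""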
@ -1579,20 +1596,6 @@ class ClientManager:
config.get("postgres", "ivfflat_lists", fallback="100"),
)
),
"vchordrq_build_options": os.environ.get(
"POSTGRES_VCHORDRQ_BUILD_OPTIONS",
config.get("postgres", "vchordrq_build_options", fallback=""),
),
"vchordrq_probes": os.environ.get(
"POSTGRES_VCHORDRQ_PROBES",
config.get("postgres", "vchordrq_probes", fallback=""),
),
"vchordrq_epsilon": float(
os.environ.get(
"POSTGRES_VCHORDRQ_EPSILON",
config.get("postgres", "vchordrq_epsilon", fallback="1.9"),
)
),
# Server settings for Supabase
"server_settings": os.environ.get(
"POSTGRES_SERVER_SETTINGS",
@ -1699,9 +1702,10 @@ class PGKVStorage(BaseKVStorage):
self.workspace = "default"
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async with get_storage_lock():
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@ -2143,21 +2147,22 @@ class PGKVStorage(BaseKVStorage):
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
async with get_storage_lock():
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@final
@ -2192,9 +2197,10 @@ class PGVectorStorage(BaseVectorStorage):
self.workspace = "default"
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async with get_storage_lock():
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
def _upsert_chunks(
self, item: dict[str, Any], current_time: datetime.datetime
@ -2530,21 +2536,22 @@ class PGVectorStorage(BaseVectorStorage):
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
async with get_storage_lock():
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@final
@ -2579,9 +2586,10 @@ class PGDocStatusStorage(DocStatusStorage):
self.workspace = "default"
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async with get_storage_lock():
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
@ -3156,21 +3164,22 @@ class PGDocStatusStorage(DocStatusStorage):
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
async with get_storage_lock():
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
class PGGraphQueryException(Exception):
@ -3302,9 +3311,10 @@ class PGGraphStorage(BaseGraphStorage):
)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async with get_graph_db_lock():
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def index_done_callback(self) -> None:
# PG handles persistence automatically
@ -3558,13 +3568,17 @@ class PGGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties"""
result = await self.get_nodes_batch(node_ids=[node_id])
label = self._normalize_node_id(node_id)
result = await self.get_nodes_batch(node_ids=[label])
if result and node_id in result:
return result[node_id]
return None
async def node_degree(self, node_id: str) -> int:
result = await self.node_degrees_batch(node_ids=[node_id])
label = self._normalize_node_id(node_id)
result = await self.node_degrees_batch(node_ids=[label])
if result and node_id in result:
return result[node_id]
@ -3577,11 +3591,12 @@ class PGGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes"""
result = await self.get_edges_batch(
[{"src": source_node_id, "tgt": target_node_id}]
)
if result and (source_node_id, target_node_id) in result:
return result[(source_node_id, target_node_id)]
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
if result and (src_label, tgt_label) in result:
return result[(src_label, tgt_label)]
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@ -3779,17 +3794,13 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
seen: set[str] = set()
unique_ids: list[str] = []
lookup: dict[str, str] = {}
requested: set[str] = set()
seen = set()
unique_ids = []
for nid in node_ids:
if nid not in seen:
seen.add(nid)
unique_ids.append(nid)
requested.add(nid)
lookup[nid] = nid
lookup[self._normalize_node_id(nid)] = nid
nid_norm = self._normalize_node_id(nid)
if nid_norm not in seen:
seen.add(nid_norm)
unique_ids.append(nid_norm)
# Build result dictionary
nodes_dict = {}
@ -3828,18 +3839,10 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
f"Failed to parse node string in batch: {node_dict}"
)
node_key = result["node_id"]
original_key = lookup.get(node_key)
if original_key is None:
logger.warning(
f"[{self.workspace}] Node {node_key} not found in lookup map"
)
original_key = node_key
if original_key in requested:
nodes_dict[original_key] = node_dict
nodes_dict[result["node_id"]] = node_dict
return nodes_dict
@ -3862,17 +3865,13 @@ class PGGraphStorage(BaseGraphStorage):
if not node_ids:
return {}
seen: set[str] = set()
seen = set()
unique_ids: list[str] = []
lookup: dict[str, str] = {}
requested: set[str] = set()
for nid in node_ids:
if nid not in seen:
seen.add(nid)
unique_ids.append(nid)
requested.add(nid)
lookup[nid] = nid
lookup[self._normalize_node_id(nid)] = nid
n = self._normalize_node_id(nid)
if n not in seen:
seen.add(n)
unique_ids.append(n)
out_degrees = {}
in_degrees = {}
@ -3924,16 +3923,8 @@ class PGGraphStorage(BaseGraphStorage):
node_id = row["node_id"]
if not node_id:
continue
node_key = node_id
original_key = lookup.get(node_key)
if original_key is None:
logger.warning(
f"[{self.workspace}] Node {node_key} not found in lookup map"
)
original_key = node_key
if original_key in requested:
out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
degrees_dict = {}
for node_id in node_ids:
@ -4062,7 +4053,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
f"Failed to parse edge properties string: {edge_props}"
)
continue
@ -4078,7 +4069,7 @@ class PGGraphStorage(BaseGraphStorage):
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
f"Failed to parse edge properties string: {edge_props}"
)
continue
@ -4704,20 +4695,21 @@ class PGGraphStorage(BaseGraphStorage):
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n)
DETACH DELETE n
$$) AS (result agtype)"""
async with get_graph_db_lock():
try:
drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n)
DETACH DELETE n
$$) AS (result agtype)"""
await self._query(drop_query, readonly=False)
return {
"status": "success",
"message": f"workspace '{self.workspace}' graph data dropped",
}
except Exception as e:
logger.error(f"[{self.workspace}] Error dropping graph: {e}")
return {"status": "error", "message": str(e)}
await self._query(drop_query, readonly=False)
return {
"status": "success",
"message": f"workspace '{self.workspace}' graph data dropped",
}
except Exception as e:
logger.error(f"[{self.workspace}] Error dropping graph: {e}")
return {"status": "error", "message": str(e)}
# Note: Order matters! More specific namespaces (e.g., "full_entities") must come before