Remove unused chunk-based node/edge retrieval methods

(cherry picked from commit 807d2461d3)
Authored by yangdx on 2025-11-06 18:17:10 +08:00, committed by Raphaël MANSUY
parent ce702ccb2f
commit 211dbc3f78
5 changed files with 225 additions and 495 deletions

View file

@@ -19,7 +19,6 @@ from typing import (
 from .utils import EmbeddingFunc
 from .types import KnowledgeGraph
 from .constants import (
-    GRAPH_FIELD_SEP,
     DEFAULT_TOP_K,
     DEFAULT_CHUNK_TOP_K,
     DEFAULT_MAX_ENTITY_TOKENS,
@@ -551,56 +550,6 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             result[node_id] = edges if edges is not None else []
         return result
 
-    @abstractmethod
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-
-    @abstractmethod
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        # Default implementation iterates through all nodes and their edges, which is inefficient.
-        # This method should be overridden by subclasses for better performance.
-        all_edges = []
-        all_labels = await self.get_all_labels()
-        processed_edges = set()
-
-        for label in all_labels:
-            edges = await self.get_node_edges(label)
-            if edges:
-                for src_id, tgt_id in edges:
-                    # Avoid processing the same edge twice in an undirected graph
-                    edge_tuple = tuple(sorted((src_id, tgt_id)))
-                    if edge_tuple in processed_edges:
-                        continue
-                    processed_edges.add(edge_tuple)
-
-                    edge = await self.get_edge(src_id, tgt_id)
-                    if edge and "source_id" in edge:
-                        source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
-                        if not source_ids.isdisjoint(chunk_ids):
-                            # Add source and target to the edge dict for easier processing later
-                            edge_with_nodes = edge.copy()
-                            edge_with_nodes["source"] = src_id
-                            edge_with_nodes["target"] = tgt_id
-                            all_edges.append(edge_with_nodes)
-
-        return all_edges
-
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """Insert a new node or update an existing node in the graph.

View file

@@ -8,7 +8,7 @@ import configparser
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..kg.shared_storage import get_data_init_lock
+from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
 
 import pipmaster as pm
 
 if not pm.is_installed("neo4j"):
@@ -101,9 +101,10 @@ class MemgraphStorage(BaseGraphStorage):
             raise
 
     async def finalize(self):
-        if self._driver is not None:
-            await self._driver.close()
-            self._driver = None
+        async with get_graph_db_lock():
+            if self._driver is not None:
+                await self._driver.close()
+                self._driver = None
 
     async def __aexit__(self, exc_type, exc, tb):
         await self.finalize()
@@ -132,7 +133,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
@@ -146,10 +146,7 @@
                 logger.error(
                     f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@@ -173,7 +170,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = (
@@ -194,10 +190,7 @@
                 logger.error(
                     f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
@@ -319,7 +312,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""
@@ -336,10 +328,7 @@
                 return labels
             except Exception as e:
                 logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -363,7 +352,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            results = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@@ -401,10 +389,7 @@
                     logger.error(
                         f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
                     )
-                    if results is not None:
-                        await (
-                            results.consume()
-                        )  # Ensure results are consumed even on error
+                    await results.consume()  # Ensure results are consumed even on error
                     raise
             except Exception as e:
                 logger.error(
@@ -434,7 +419,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""
@@ -467,10 +451,7 @@
                 logger.error(
                     f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await (
-                        result.consume()
-                    )  # Ensure the result is consumed even on error
+                await result.consume()  # Ensure the result is consumed even on error
                 raise
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
@@ -761,21 +742,22 @@
             raise RuntimeError(
                 "Memgraph driver is not initialized. Call 'await initialize()' first."
             )
-        try:
-            async with self._driver.session(database=self._DATABASE) as session:
-                workspace_label = self._get_workspace_label()
-                query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
-                result = await session.run(query)
-                await result.consume()
-                logger.info(
-                    f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
-                )
-                return {"status": "success", "message": "workspace data dropped"}
-        except Exception as e:
-            logger.error(
-                f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
-            )
-            return {"status": "error", "message": str(e)}
+        async with get_graph_db_lock():
+            try:
+                async with self._driver.session(database=self._DATABASE) as session:
+                    workspace_label = self._get_workspace_label()
+                    query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
+                    result = await session.run(query)
+                    await result.consume()
+                    logger.info(
+                        f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
+                    )
+                    return {"status": "success", "message": "workspace data dropped"}
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
                )
                return {"status": "error", "message": str(e)}
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         """Get the total degree (sum of relationships) of two nodes.
@@ -1048,7 +1030,6 @@
                 "Memgraph driver is not initialized. Call 'await initialize()' first."
             )
 
-        result = None
         try:
             workspace_label = self._get_workspace_label()
             async with self._driver.session(
@@ -1075,8 +1056,6 @@
                 return labels
         except Exception as e:
             logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
-            if result is not None:
-                await result.consume()
             return []
 
     async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@@ -1099,7 +1078,6 @@
         if not query_lower:
             return []
 
-        result = None
         try:
            workspace_label = self._get_workspace_label()
            async with self._driver.session(
@@ -1133,6 +1111,4 @@
                 return labels
         except Exception as e:
             logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
-            if result is not None:
-                await result.consume()
             return []
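Note: the lifecycle changes above share one idea — finalize() and drop() now serialize against graph writers via get_graph_db_lock(), so the driver cannot be closed (or a workspace deleted) while another task holds a session. A minimal sketch of the pattern, assuming the lock behaves like an asyncio lock used as an async context manager:

    import asyncio

    _graph_db_lock = asyncio.Lock()  # stand-in for get_graph_db_lock()

    class Storage:
        def __init__(self, driver):
            self._driver = driver

        async def finalize(self):
            # Writers acquire the same lock, so close() cannot interleave
            # with an in-flight upsert or drop.
            async with _graph_db_lock:
                if self._driver is not None:
                    await self._driver.close()
                    self._driver = None

The stand-in is deliberately simplified: the real shared_storage lock also coordinates across worker processes, which a plain asyncio.Lock does not.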

View file

@@ -1031,45 +1031,6 @@ class MongoGraphStorage(BaseGraphStorage):
 
         return result
 
-    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all nodes that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
-
-        Returns:
-            list[dict]: A list of nodes, where each node is a dictionary of its properties.
-                An empty list if no matching nodes are found.
-        """
-        if not chunk_ids:
-            return []
-
-        cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
-        return [doc async for doc in cursor]
-
-    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
-        """Get all edges that are associated with the given chunk_ids.
-
-        Args:
-            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
-
-        Returns:
-            list[dict]: A list of edges, where each edge is a dictionary of its properties.
-                An empty list if no matching edges are found.
-        """
-        if not chunk_ids:
-            return []
-
-        cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
-        edges = []
-        async for edge in cursor:
-            edge["source"] = edge["source_node_id"]
-            edge["target"] = edge["target_node_id"]
-            edges.append(edge)
-        return edges
-
     #
     # -------------------------------------------------------------------------
     # UPSERTS
View file

@@ -16,7 +16,7 @@ import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from ..kg.shared_storage import get_data_init_lock
+from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
 
 import pipmaster as pm
 
 if not pm.is_installed("neo4j"):
@@ -340,9 +340,10 @@ class Neo4JStorage(BaseGraphStorage):
 
     async def finalize(self):
         """Close the Neo4j driver and release all resources"""
-        if self._driver:
-            await self._driver.close()
-            self._driver = None
+        async with get_graph_db_lock():
+            if self._driver:
+                await self._driver.close()
+                self._driver = None
 
     async def __aexit__(self, exc_type, exc, tb):
         """Ensure driver is closed when context manager exits"""
@@ -352,20 +353,6 @@
         # Neo4J handles persistence automatically
         pass
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def has_node(self, node_id: str) -> bool:
         """
         Check if a node with the given label exists in the database
@@ -384,7 +371,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
                 result = await session.run(query, entity_id=node_id)
@@ -395,24 +381,9 @@
                 logger.error(
                     f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await result.consume()  # Ensure results are consumed even on error
+                await result.consume()  # Ensure results are consumed even on error
                 raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """
         Check if an edge exists between two nodes
@@ -432,7 +403,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 query = (
                     f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
@@ -450,24 +420,9 @@
                 logger.error(
                     f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                 )
-                if result is not None:
-                    await result.consume()  # Ensure results are consumed even on error
+                await result.consume()  # Ensure results are consumed even on error
                 raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier, return only node properties
 
@@ -521,20 +476,6 @@
             )
             raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
         """
         Retrieve multiple nodes in one query using UNWIND.
@@ -571,20 +512,6 @@
             await result.consume()  # Make sure to consume the result fully
             return nodes
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def node_degree(self, node_id: str) -> int:
         """Get the degree (number of relationships) of a node with the given label.
         If multiple nodes have the same label, returns the degree of the first node.
@@ -633,20 +560,6 @@
             )
             raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
         """
         Retrieve the degree for multiple nodes in a single query using UNWIND.
@@ -731,20 +644,6 @@
             edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
         return edge_degrees
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
@@ -832,20 +731,6 @@
             )
             raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_edges_batch(
         self, pairs: list[dict[str, str]]
     ) -> dict[tuple[str, str], dict]:
@@ -896,20 +781,6 @@
             await result.consume()
             return edges_dict
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """Retrieves all edges (relationships) for a particular node identified by its label.
 
@@ -928,7 +799,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            results = None
             try:
                 workspace_label = self._get_workspace_label()
                 query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@@ -966,10 +836,7 @@
                     logger.error(
                         f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
                     )
-                    if results is not None:
-                        await (
-                            results.consume()
-                        )  # Ensure results are consumed even on error
+                    await results.consume()  # Ensure results are consumed even on error
                     raise
             except Exception as e:
                 logger.error(
@@ -977,20 +844,6 @@
             )
             raise
 
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type(
-            (
-                neo4jExceptions.ServiceUnavailable,
-                neo4jExceptions.TransientError,
-                neo4jExceptions.SessionExpired,
-                ConnectionResetError,
-                OSError,
-                AttributeError,
-            )
-        ),
-    )
     async def get_nodes_edges_batch(
         self, node_ids: list[str]
     ) -> dict[str, list[tuple[str, str]]]:
@@ -1739,7 +1592,6 @@
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            result = None
             try:
                 query = f"""
                     MATCH (n:`{workspace_label}`)
@@ -1764,8 +1616,7 @@
                 logger.error(
                     f"[{self.workspace}] Error getting popular labels: {str(e)}"
                 )
-                if result is not None:
-                    await result.consume()
+                await result.consume()
                 raise
 
     async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@@ -1912,23 +1763,24 @@
         - On success: {"status": "success", "message": "workspace data dropped"}
         - On failure: {"status": "error", "message": "<error details>"}
         """
-        workspace_label = self._get_workspace_label()
-        try:
-            async with self._driver.session(database=self._DATABASE) as session:
-                # Delete all nodes and relationships in current workspace only
-                query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
-                result = await session.run(query)
-                await result.consume()  # Ensure result is fully consumed
-
-                # logger.debug(
-                #     f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
-                # )
-                return {
-                    "status": "success",
-                    "message": f"workspace '{workspace_label}' data dropped",
-                }
-        except Exception as e:
-            logger.error(
-                f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
-            )
-            return {"status": "error", "message": str(e)}
+        async with get_graph_db_lock():
+            workspace_label = self._get_workspace_label()
+            try:
+                async with self._driver.session(database=self._DATABASE) as session:
+                    # Delete all nodes and relationships in current workspace only
+                    query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
+                    result = await session.run(query)
+                    await result.consume()  # Ensure result is fully consumed
+
+                    # logger.debug(
+                    #     f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
+                    # )
+                    return {
+                        "status": "success",
+                        "message": f"workspace '{workspace_label}' data dropped",
+                    }
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
                )
                return {"status": "error", "message": str(e)}
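Note: the fourteen deleted decorator blocks were one identical tenacity policy pasted before each read method. If a deployment wants the behavior back, defining the policy once is enough; a sketch under that assumption, reusing the imports this file already relies on:

    from neo4j import exceptions as neo4jExceptions
    from tenacity import (
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential,
    )

    # One shared policy instead of a copied decorator per method.
    neo4j_read_retry = retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.SessionExpired,
                ConnectionResetError,
                OSError,
                AttributeError,
            )
        ),
    )

    @neo4j_read_retry
    async def has_node(self, node_id: str) -> bool:
        ...

Retrying on OSError and especially AttributeError is unusually broad and can mask real bugs, which may be part of why the blanket decorators were dropped rather than consolidated.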

View file

@@ -33,7 +33,7 @@ from ..base import (
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
-from ..kg.shared_storage import get_data_init_lock
+from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock
 
 import pipmaster as pm
 
@@ -77,9 +77,6 @@ class PostgreSQLDB:
         self.hnsw_m = config.get("hnsw_m")
         self.hnsw_ef = config.get("hnsw_ef")
         self.ivfflat_lists = config.get("ivfflat_lists")
-        self.vchordrq_build_options = config.get("vchordrq_build_options")
-        self.vchordrq_probes = config.get("vchordrq_probes")
-        self.vchordrq_epsilon = config.get("vchordrq_epsilon")
 
         # Server settings
         self.server_settings = config.get("server_settings")
@@ -365,8 +362,7 @@
             await self.configure_age(connection, graph_name)
         elif with_age and not graph_name:
             raise ValueError("Graph name is required when with_age is True")
-        if self.vector_index_type == "VCHORDRQ":
-            await self.configure_vchordrq(connection)
+
         return await operation(connection)
 
     @staticmethod
@@ -383,7 +379,7 @@
     async def configure_age_extension(connection: asyncpg.Connection) -> None:
         """Create AGE extension if it doesn't exist for graph operations."""
         try:
-            await connection.execute("CREATE EXTENSION IF NOT EXISTS AGE CASCADE")  # type: ignore
+            await connection.execute("CREATE EXTENSION IF NOT EXISTS age")  # type: ignore
             logger.info("PostgreSQL, AGE extension enabled")
         except Exception as e:
             logger.warning(f"Could not create AGE extension: {e}")
@@ -412,14 +408,6 @@
         ):
             pass
 
-    async def configure_vchordrq(self, connection: asyncpg.Connection) -> None:
-        """Configure VCHORDRQ extension for vector similarity search."""
-        try:
-            await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'")
-            await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}")
-        except Exception as e:
-            logger.error(f"Failed to set vchordrq.probes or vchordrq.epsilon: {e}")
-
     async def _migrate_llm_cache_schema(self):
         """Migrate LLM cache schema: add new columns and remove deprecated mode field"""
         try:
@@ -1154,12 +1142,19 @@
             f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}"
         )
         try:
-            if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]:
-                await self._create_vector_indexes()
+            if self.vector_index_type == "HNSW":
+                await self._create_hnsw_vector_indexes()
+            elif self.vector_index_type == "IVFFLAT":
+                await self._create_ivfflat_vector_indexes()
+            elif self.vector_index_type == "FLAT":
+                logger.warning(
+                    "FLAT index type is not supported by pgvector. Skipping vector index creation. "
+                    "Please use 'HNSW' or 'IVFFLAT' instead."
+                )
             else:
                 logger.warning(
                     "Doesn't support this vector index type: {self.vector_index_type}. "
-                    "Supported types: HNSW, IVFFLAT, VCHORDRQ"
+                    "Supported types: HNSW, IVFFLAT"
                 )
         except Exception as e:
             logger.error(
@@ -1366,39 +1361,21 @@
             except Exception as e:
                 logger.warning(f"Failed to create index {index['name']}: {e}")
 
-    async def _create_vector_indexes(self):
+    async def _create_hnsw_vector_indexes(self):
         vdb_tables = [
             "LIGHTRAG_VDB_CHUNKS",
             "LIGHTRAG_VDB_ENTITY",
             "LIGHTRAG_VDB_RELATION",
         ]
-        create_sql = {
-            "HNSW": f"""
-                CREATE INDEX {{vector_index_name}}
-                ON {{k}} USING hnsw (content_vector vector_cosine_ops)
-                WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
-            """,
-            "IVFFLAT": f"""
-                CREATE INDEX {{vector_index_name}}
-                ON {{k}} USING ivfflat (content_vector vector_cosine_ops)
-                WITH (lists = {self.ivfflat_lists})
-            """,
-            "VCHORDRQ": f"""
-                CREATE INDEX {{vector_index_name}}
-                ON {{k}} USING vchordrq (content_vector vector_cosine_ops)
-                {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''}
-            """,
-        }
         embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
+
         for k in vdb_tables:
-            vector_index_name = (
-                f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
-            )
+            vector_index_name = f"idx_{k.lower()}_hnsw_cosine"
             check_vector_index_sql = f"""
                 SELECT 1 FROM pg_indexes
-                WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
+                WHERE indexname = '{vector_index_name}'
+                AND tablename = '{k.lower()}'
             """
             try:
                 vector_index_exists = await self.query(check_vector_index_sql)
@@ -1407,24 +1384,64 @@
                     alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
                     await self.execute(alter_sql)
                     logger.debug(f"Ensured vector dimension for {k}")
-                    logger.info(
-                        f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
-                    )
-                    await self.execute(
-                        create_sql[self.vector_index_type].format(
-                            vector_index_name=vector_index_name, k=k
-                        )
-                    )
+
+                    create_vector_index_sql = f"""
+                        CREATE INDEX {vector_index_name}
+                        ON {k} USING hnsw (content_vector vector_cosine_ops)
+                        WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
+                    """
+                    logger.info(f"Creating hnsw index {vector_index_name} on table {k}")
+                    await self.execute(create_vector_index_sql)
                     logger.info(
                         f"Successfully created vector index {vector_index_name} on table {k}"
                     )
                 else:
                     logger.info(
-                        f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
+                        f"HNSW vector index {vector_index_name} already exists on table {k}"
                     )
             except Exception as e:
                 logger.error(f"Failed to create vector index on table {k}, Got: {e}")
+
+    async def _create_ivfflat_vector_indexes(self):
+        vdb_tables = [
+            "LIGHTRAG_VDB_CHUNKS",
+            "LIGHTRAG_VDB_ENTITY",
+            "LIGHTRAG_VDB_RELATION",
+        ]
+        embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
+
+        for k in vdb_tables:
+            index_name = f"idx_{k.lower()}_ivfflat_cosine"
+            check_index_sql = f"""
+                SELECT 1 FROM pg_indexes
+                WHERE indexname = '{index_name}' AND tablename = '{k.lower()}'
+            """
+            try:
+                exists = await self.query(check_index_sql)
+                if not exists:
+                    # Only set vector dimension when index doesn't exist
+                    alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
+                    await self.execute(alter_sql)
+                    logger.debug(f"Ensured vector dimension for {k}")
+
+                    create_sql = f"""
+                        CREATE INDEX {index_name}
+                        ON {k} USING ivfflat (content_vector vector_cosine_ops)
+                        WITH (lists = {self.ivfflat_lists})
+                    """
+                    logger.info(f"Creating ivfflat index {index_name} on table {k}")
+                    await self.execute(create_sql)
+                    logger.info(
+                        f"Successfully created ivfflat index {index_name} on table {k}"
+                    )
+                else:
+                    logger.info(
+                        f"Ivfflat vector index {index_name} already exists on table {k}"
+                    )
            except Exception as e:
                logger.error(f"Failed to create ivfflat index on {k}: {e}")
 
     async def query(
         self,
         sql: str,
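Note: for reference, the DDL the two new helpers emit looks like this for the chunks table. The lists=100 and 1024-dim fallbacks match values visible in this file; m=16 and ef_construction=64 are assumed example values:

    # Rendered index DDL (illustrative parameter values).
    hnsw_sql = """
        CREATE INDEX idx_lightrag_vdb_chunks_hnsw_cosine
        ON LIGHTRAG_VDB_CHUNKS USING hnsw (content_vector vector_cosine_ops)
        WITH (m = 16, ef_construction = 64)
    """
    ivfflat_sql = """
        CREATE INDEX idx_lightrag_vdb_chunks_ivfflat_cosine
        ON LIGHTRAG_VDB_CHUNKS USING ivfflat (content_vector vector_cosine_ops)
        WITH (lists = 100)
    """

The split also reflects how the two pgvector access methods differ: HNSW can be built on an empty table and tuned at query time, while IVFFlat should be created after data is loaded so its lists clustering is meaningful.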
@@ -1579,20 +1596,6 @@ class ClientManager:
                     config.get("postgres", "ivfflat_lists", fallback="100"),
                 )
             ),
-            "vchordrq_build_options": os.environ.get(
-                "POSTGRES_VCHORDRQ_BUILD_OPTIONS",
-                config.get("postgres", "vchordrq_build_options", fallback=""),
-            ),
-            "vchordrq_probes": os.environ.get(
-                "POSTGRES_VCHORDRQ_PROBES",
-                config.get("postgres", "vchordrq_probes", fallback=""),
-            ),
-            "vchordrq_epsilon": float(
-                os.environ.get(
-                    "POSTGRES_VCHORDRQ_EPSILON",
-                    config.get("postgres", "vchordrq_epsilon", fallback="1.9"),
-                )
-            ),
             # Server settings for Supabase
             "server_settings": os.environ.get(
                 "POSTGRES_SERVER_SETTINGS",
@@ -1699,9 +1702,10 @@ class PGKVStorage(BaseKVStorage):
             self.workspace = "default"
 
     async def finalize(self):
-        if self.db is not None:
-            await ClientManager.release_client(self.db)
-            self.db = None
+        async with get_storage_lock():
+            if self.db is not None:
+                await ClientManager.release_client(self.db)
+                self.db = None
 
     ################ QUERY METHODS ################
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -2143,21 +2147,22 @@
 
     async def drop(self) -> dict[str, str]:
         """Drop the storage"""
-        try:
-            table_name = namespace_to_table_name(self.namespace)
-            if not table_name:
-                return {
-                    "status": "error",
-                    "message": f"Unknown namespace: {self.namespace}",
-                }
-
-            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
-                table_name=table_name
-            )
-            await self.db.execute(drop_sql, {"workspace": self.workspace})
-            return {"status": "success", "message": "data dropped"}
-        except Exception as e:
-            return {"status": "error", "message": str(e)}
+        async with get_storage_lock():
+            try:
+                table_name = namespace_to_table_name(self.namespace)
+                if not table_name:
+                    return {
+                        "status": "error",
+                        "message": f"Unknown namespace: {self.namespace}",
+                    }
+
+                drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
+                    table_name=table_name
+                )
+                await self.db.execute(drop_sql, {"workspace": self.workspace})
+                return {"status": "success", "message": "data dropped"}
            except Exception as e:
                return {"status": "error", "message": str(e)}
 
 
 @final
@@ -2192,9 +2197,10 @@ class PGVectorStorage(BaseVectorStorage):
             self.workspace = "default"
 
     async def finalize(self):
-        if self.db is not None:
-            await ClientManager.release_client(self.db)
-            self.db = None
+        async with get_storage_lock():
+            if self.db is not None:
+                await ClientManager.release_client(self.db)
+                self.db = None
 
     def _upsert_chunks(
         self, item: dict[str, Any], current_time: datetime.datetime
@@ -2530,21 +2536,22 @@
 
     async def drop(self) -> dict[str, str]:
         """Drop the storage"""
-        try:
-            table_name = namespace_to_table_name(self.namespace)
-            if not table_name:
-                return {
-                    "status": "error",
-                    "message": f"Unknown namespace: {self.namespace}",
-                }
-
-            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
-                table_name=table_name
-            )
-            await self.db.execute(drop_sql, {"workspace": self.workspace})
-            return {"status": "success", "message": "data dropped"}
-        except Exception as e:
-            return {"status": "error", "message": str(e)}
+        async with get_storage_lock():
+            try:
+                table_name = namespace_to_table_name(self.namespace)
+                if not table_name:
+                    return {
+                        "status": "error",
+                        "message": f"Unknown namespace: {self.namespace}",
+                    }
+
+                drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
+                    table_name=table_name
+                )
+                await self.db.execute(drop_sql, {"workspace": self.workspace})
+                return {"status": "success", "message": "data dropped"}
            except Exception as e:
                return {"status": "error", "message": str(e)}
 
 
 @final
@@ -2579,9 +2586,10 @@ class PGDocStatusStorage(DocStatusStorage):
             self.workspace = "default"
 
     async def finalize(self):
-        if self.db is not None:
-            await ClientManager.release_client(self.db)
-            self.db = None
+        async with get_storage_lock():
+            if self.db is not None:
+                await ClientManager.release_client(self.db)
+                self.db = None
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
@@ -3156,21 +3164,22 @@
 
     async def drop(self) -> dict[str, str]:
         """Drop the storage"""
-        try:
-            table_name = namespace_to_table_name(self.namespace)
-            if not table_name:
-                return {
-                    "status": "error",
-                    "message": f"Unknown namespace: {self.namespace}",
-                }
-
-            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
-                table_name=table_name
-            )
-            await self.db.execute(drop_sql, {"workspace": self.workspace})
-            return {"status": "success", "message": "data dropped"}
-        except Exception as e:
-            return {"status": "error", "message": str(e)}
+        async with get_storage_lock():
+            try:
+                table_name = namespace_to_table_name(self.namespace)
+                if not table_name:
+                    return {
+                        "status": "error",
+                        "message": f"Unknown namespace: {self.namespace}",
+                    }
+
+                drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
+                    table_name=table_name
+                )
+                await self.db.execute(drop_sql, {"workspace": self.workspace})
+                return {"status": "success", "message": "data dropped"}
            except Exception as e:
                return {"status": "error", "message": str(e)}
 
 
 class PGGraphQueryException(Exception):
@@ -3302,9 +3311,10 @@ class PGGraphStorage(BaseGraphStorage):
         )
 
     async def finalize(self):
-        if self.db is not None:
-            await ClientManager.release_client(self.db)
-            self.db = None
+        async with get_graph_db_lock():
+            if self.db is not None:
+                await ClientManager.release_client(self.db)
+                self.db = None
 
     async def index_done_callback(self) -> None:
         # PG handles persistence automatically
@@ -3558,13 +3568,17 @@
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier, return only node properties"""
-        result = await self.get_nodes_batch(node_ids=[node_id])
+        label = self._normalize_node_id(node_id)
+        result = await self.get_nodes_batch(node_ids=[label])
+
         if result and node_id in result:
             return result[node_id]
         return None
 
     async def node_degree(self, node_id: str) -> int:
-        result = await self.node_degrees_batch(node_ids=[node_id])
+        label = self._normalize_node_id(node_id)
+        result = await self.node_degrees_batch(node_ids=[label])
+
         if result and node_id in result:
             return result[node_id]
@@ -3577,11 +3591,12 @@
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
         """Get edge properties between two nodes"""
-        result = await self.get_edges_batch(
-            [{"src": source_node_id, "tgt": target_node_id}]
-        )
-        if result and (source_node_id, target_node_id) in result:
-            return result[(source_node_id, target_node_id)]
+        src_label = self._normalize_node_id(source_node_id)
+        tgt_label = self._normalize_node_id(target_node_id)
+
+        result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
+        if result and (src_label, tgt_label) in result:
+            return result[(src_label, tgt_label)]
         return None
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -3779,17 +3794,13 @@
         if not node_ids:
             return {}
 
-        seen: set[str] = set()
-        unique_ids: list[str] = []
-        lookup: dict[str, str] = {}
-        requested: set[str] = set()
+        seen = set()
+        unique_ids = []
         for nid in node_ids:
-            if nid not in seen:
-                seen.add(nid)
-                unique_ids.append(nid)
-                requested.add(nid)
-                lookup[nid] = nid
-                lookup[self._normalize_node_id(nid)] = nid
+            nid_norm = self._normalize_node_id(nid)
+            if nid_norm not in seen:
+                seen.add(nid_norm)
+                unique_ids.append(nid_norm)
 
         # Build result dictionary
         nodes_dict = {}
@@ -3828,18 +3839,10 @@
                         node_dict = json.loads(node_dict)
                     except json.JSONDecodeError:
                         logger.warning(
-                            f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
+                            f"Failed to parse node string in batch: {node_dict}"
                         )
 
-                node_key = result["node_id"]
-                original_key = lookup.get(node_key)
-                if original_key is None:
-                    logger.warning(
-                        f"[{self.workspace}] Node {node_key} not found in lookup map"
-                    )
-                    original_key = node_key
-                if original_key in requested:
-                    nodes_dict[original_key] = node_dict
+                nodes_dict[result["node_id"]] = node_dict
 
         return nodes_dict
@@ -3862,17 +3865,13 @@
         if not node_ids:
             return {}
 
-        seen: set[str] = set()
+        seen = set()
         unique_ids: list[str] = []
-        lookup: dict[str, str] = {}
-        requested: set[str] = set()
         for nid in node_ids:
-            if nid not in seen:
-                seen.add(nid)
-                unique_ids.append(nid)
-                requested.add(nid)
-                lookup[nid] = nid
-                lookup[self._normalize_node_id(nid)] = nid
+            n = self._normalize_node_id(nid)
+            if n not in seen:
+                seen.add(n)
+                unique_ids.append(n)
 
         out_degrees = {}
         in_degrees = {}
@@ -3924,16 +3923,8 @@
                 node_id = row["node_id"]
                 if not node_id:
                     continue
-                node_key = node_id
-                original_key = lookup.get(node_key)
-                if original_key is None:
-                    logger.warning(
-                        f"[{self.workspace}] Node {node_key} not found in lookup map"
-                    )
-                    original_key = node_key
-                if original_key in requested:
-                    out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
-                    in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
+                out_degrees[node_id] = int(row.get("out_degree", 0) or 0)
+                in_degrees[node_id] = int(row.get("in_degree", 0) or 0)
 
         degrees_dict = {}
         for node_id in node_ids:
@@ -4062,7 +4053,7 @@
                     edge_props = json.loads(edge_props)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
+                        f"Failed to parse edge properties string: {edge_props}"
                     )
                     continue
@@ -4078,7 +4069,7 @@
                     edge_props = json.loads(edge_props)
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
+                        f"Failed to parse edge properties string: {edge_props}"
                    )
                    continue
@@ -4704,20 +4695,21 @@
 
     async def drop(self) -> dict[str, str]:
         """Drop the storage"""
-        try:
-            drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-                          MATCH (n)
-                          DETACH DELETE n
-                          $$) AS (result agtype)"""
-
-            await self._query(drop_query, readonly=False)
-            return {
-                "status": "success",
-                "message": f"workspace '{self.workspace}' graph data dropped",
-            }
-        except Exception as e:
-            logger.error(f"[{self.workspace}] Error dropping graph: {e}")
-            return {"status": "error", "message": str(e)}
+        async with get_graph_db_lock():
+            try:
+                drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
+                              MATCH (n)
+                              DETACH DELETE n
+                              $$) AS (result agtype)"""
+
+                await self._query(drop_query, readonly=False)
+                return {
+                    "status": "success",
+                    "message": f"workspace '{self.workspace}' graph data dropped",
+                }
            except Exception as e:
                logger.error(f"[{self.workspace}] Error dropping graph: {e}")
                return {"status": "error", "message": str(e)}
 
 
 # Note: Order matters! More specific namespaces (e.g., "full_entities") must come before
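Note: every drop() in this commit keeps the same return contract while gaining a lock, so callers branch on the status dict rather than on exceptions. A usage sketch, where storage stands for any of the storage instances above:

    # Hypothetical caller of the drop() contract shown in this diff.
    result = await storage.drop()
    if result["status"] != "success":
        raise RuntimeError(f"drop failed: {result['message']}")
    print(result["message"])  # e.g. "data dropped"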