undo changes for graph engines
This commit is contained in:
parent
e68a89f737
commit
e87b77fda6
3 changed files with 155 additions and 218 deletions
|
|
@ -10,7 +10,7 @@ from kuzu.database import Database
|
|||
from datetime import datetime, timezone
|
||||
from contextlib import asynccontextmanager
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
|
|
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||
from cognee.tasks.temporal_graph.models import Timestamp
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -146,21 +146,15 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to initialize Kuzu database: {e}")
|
||||
raise e
|
||||
|
||||
def _get_connection(self) -> Connection:
|
||||
"""Get the connection to the Kuzu database."""
|
||||
if not self.connection:
|
||||
raise RuntimeError("Kuzu database connection not initialized")
|
||||
return self.connection
|
||||
|
||||
async def push_to_s3(self) -> None:
|
||||
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
|
||||
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
||||
|
||||
s3_file_storage = S3FileStorage("")
|
||||
|
||||
if self._get_connection():
|
||||
if self.connection:
|
||||
async with self.KUZU_ASYNC_LOCK:
|
||||
self._get_connection().execute("CHECKPOINT;")
|
||||
self.connection.execute("CHECKPOINT;")
|
||||
|
||||
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
|
||||
|
||||
|
|
@ -173,9 +167,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
except FileNotFoundError:
|
||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||
|
||||
async def query(
|
||||
self, query: str, params: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
||||
"""
|
||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||
|
||||
|
|
@ -198,32 +190,23 @@ class KuzuAdapter(GraphDBInterface):
|
|||
loop = asyncio.get_running_loop()
|
||||
params = params or {}
|
||||
|
||||
def blocking_query() -> List[Tuple[Any, ...]]:
|
||||
def blocking_query():
|
||||
try:
|
||||
if not self._get_connection():
|
||||
if not self.connection:
|
||||
logger.debug("Reconnecting to Kuzu database...")
|
||||
self._initialize_connection()
|
||||
|
||||
if not self._get_connection():
|
||||
raise RuntimeError("Failed to establish database connection")
|
||||
|
||||
result = self._get_connection().execute(query, params)
|
||||
result = self.connection.execute(query, params)
|
||||
rows = []
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
|
||||
# Handle QueryResult vs List[QueryResult] union type
|
||||
for single_result in result:
|
||||
while single_result.has_next():
|
||||
row = single_result.get_next()
|
||||
processed_rows = []
|
||||
for val in row:
|
||||
if hasattr(val, "as_py"):
|
||||
val = val.as_py()
|
||||
processed_rows.append(val)
|
||||
rows.append(tuple(processed_rows))
|
||||
|
||||
while result.has_next():
|
||||
row = result.get_next()
|
||||
processed_rows = []
|
||||
for val in row:
|
||||
if hasattr(val, "as_py"):
|
||||
val = val.as_py()
|
||||
processed_rows.append(val)
|
||||
rows.append(tuple(processed_rows))
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error(f"Query execution failed: {str(e)}")
|
||||
|
|
@ -232,7 +215,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
return await loop.run_in_executor(self.executor, blocking_query)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncGenerator[Optional[Connection], None]:
|
||||
async def get_session(self):
|
||||
"""
|
||||
Get a database session.
|
||||
|
||||
|
|
@ -241,7 +224,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
and on exit performs cleanup if necessary.
|
||||
"""
|
||||
try:
|
||||
yield self._get_connection()
|
||||
yield self.connection
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
|
@ -272,7 +255,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
def _edge_query_and_params(
|
||||
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
) -> Tuple[str, dict]:
|
||||
"""Build the edge creation query and parameters."""
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
query = """
|
||||
|
|
@ -322,9 +305,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
result = await self.query(query_str, {"id": node_id})
|
||||
return result[0][0] if result else False
|
||||
|
||||
async def add_node(
|
||||
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
async def add_node(self, node: DataPoint) -> None:
|
||||
"""
|
||||
Add a single node to the graph if it doesn't exist.
|
||||
|
||||
|
|
@ -338,32 +319,20 @@ class KuzuAdapter(GraphDBInterface):
|
|||
- node (DataPoint): The node to be added, represented as a DataPoint.
|
||||
"""
|
||||
try:
|
||||
if isinstance(node, str):
|
||||
# Handle string node ID with properties parameter
|
||||
node_properties = properties or {}
|
||||
core_properties = {
|
||||
"id": node,
|
||||
"name": str(node_properties.get("name", "")),
|
||||
"type": str(node_properties.get("type", "")),
|
||||
}
|
||||
# Use the passed properties, excluding core fields
|
||||
other_properties = {
|
||||
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||
}
|
||||
else:
|
||||
# Handle DataPoint object
|
||||
node_properties = node.model_dump()
|
||||
core_properties = {
|
||||
"id": str(node_properties.get("id", "")),
|
||||
"name": str(node_properties.get("name", "")),
|
||||
"type": str(node_properties.get("type", "")),
|
||||
}
|
||||
# Remove core fields from other properties
|
||||
other_properties = {
|
||||
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||
}
|
||||
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
|
||||
|
||||
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
|
||||
# Extract core fields with defaults if not present
|
||||
core_properties = {
|
||||
"id": str(properties.get("id", "")),
|
||||
"name": str(properties.get("name", "")),
|
||||
"type": str(properties.get("type", "")),
|
||||
}
|
||||
|
||||
# Remove core fields from other properties
|
||||
for key in core_properties:
|
||||
properties.pop(key, None)
|
||||
|
||||
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
|
||||
|
||||
# Add timestamps for new node
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
|
@ -391,7 +360,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to add node: {e}")
|
||||
raise
|
||||
|
||||
@record_graph_changes # type: ignore
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||
"""
|
||||
Add multiple nodes to the graph in a batch operation.
|
||||
|
|
@ -599,9 +568,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
)
|
||||
return result[0][0] if result else False
|
||||
|
||||
async def has_edges(
|
||||
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
|
||||
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
Check if multiple edges exist in a batch operation.
|
||||
|
||||
|
|
@ -632,7 +599,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
"to_id": str(to_node), # Ensure string type
|
||||
"relationship_name": str(edge_label), # Ensure string type
|
||||
}
|
||||
for from_node, to_node, edge_label, _ in edges
|
||||
for from_node, to_node, edge_label in edges
|
||||
]
|
||||
|
||||
# Batch check query with direct string comparison
|
||||
|
|
@ -648,21 +615,9 @@ class KuzuAdapter(GraphDBInterface):
|
|||
results = await self.query(query, {"edges": edge_params})
|
||||
|
||||
# Convert results back to tuples and ensure string types
|
||||
# Find the original edge properties for each existing edge
|
||||
# TODO: get review on this
|
||||
existing_edges = []
|
||||
for row in results:
|
||||
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
|
||||
# Find the original properties from the input edges
|
||||
original_props = {}
|
||||
for orig_from, orig_to, orig_rel, orig_props in edges:
|
||||
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
|
||||
original_props = orig_props
|
||||
break
|
||||
existing_edges.append((from_id, to_id, rel_name, original_props))
|
||||
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
|
||||
|
||||
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
|
||||
# TODO: otherwise, we can just return dummy properties since they are not used apparently
|
||||
return existing_edges
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -671,10 +626,10 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
async def add_edge(
|
||||
self,
|
||||
source_id: str,
|
||||
target_id: str,
|
||||
from_node: str,
|
||||
to_node: str,
|
||||
relationship_name: str,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
edge_properties: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Add an edge between two nodes.
|
||||
|
|
@ -686,23 +641,23 @@ class KuzuAdapter(GraphDBInterface):
|
|||
Parameters:
|
||||
-----------
|
||||
|
||||
- source_id (str): The identifier of the source node from which the edge originates.
|
||||
- target_id (str): The identifier of the target node to which the edge points.
|
||||
- from_node (str): The identifier of the source node from which the edge originates.
|
||||
- to_node (str): The identifier of the target node to which the edge points.
|
||||
- relationship_name (str): The label of the edge to be created, representing the
|
||||
relationship name.
|
||||
- properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge.
|
||||
(default None)
|
||||
- edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
|
||||
(default {})
|
||||
"""
|
||||
try:
|
||||
query, params = self._edge_query_and_params(
|
||||
source_id, target_id, relationship_name, properties or {}
|
||||
from_node, to_node, relationship_name, edge_properties
|
||||
)
|
||||
await self.query(query, params)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add edge: {e}")
|
||||
raise
|
||||
|
||||
@record_graph_changes # type: ignore
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||
"""
|
||||
Add multiple edges in a batch operation.
|
||||
|
|
@ -757,7 +712,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to add edges in batch: {e}")
|
||||
raise
|
||||
|
||||
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
|
||||
"""
|
||||
Get all edges connected to a node.
|
||||
|
||||
|
|
@ -772,8 +727,9 @@ class KuzuAdapter(GraphDBInterface):
|
|||
Returns:
|
||||
--------
|
||||
|
||||
- List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each
|
||||
tuple contains (source_id, relationship_name, target_id, edge_properties).
|
||||
- List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
|
||||
tuple contains (source_node, relationship_name, target_node), with source_node and
|
||||
target_node as dictionaries of node properties.
|
||||
"""
|
||||
query_str = """
|
||||
MATCH (n:Node)-[r]-(m:Node)
|
||||
|
|
@ -794,14 +750,12 @@ class KuzuAdapter(GraphDBInterface):
|
|||
"""
|
||||
try:
|
||||
results = await self.query(query_str, {"node_id": node_id})
|
||||
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
|
||||
edges = []
|
||||
for row in results:
|
||||
if row and len(row) == 3:
|
||||
source_node = self._parse_node_properties(row[0])
|
||||
relationship_name = row[1]
|
||||
target_node = self._parse_node_properties(row[2])
|
||||
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
|
||||
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings
|
||||
edges.append((source_node, row[1], target_node))
|
||||
return edges
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
||||
|
|
@ -1023,7 +977,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
return []
|
||||
|
||||
async def get_connections(
|
||||
self, node_id: Union[str, UUID]
|
||||
self, node_id: str
|
||||
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
|
||||
"""
|
||||
Get all nodes connected to a given node.
|
||||
|
|
@ -1065,9 +1019,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
}
|
||||
"""
|
||||
try:
|
||||
# Convert UUID to string if needed
|
||||
node_id_str = str(node_id)
|
||||
results = await self.query(query_str, {"node_id": node_id_str})
|
||||
results = await self.query(query_str, {"node_id": node_id})
|
||||
edges = []
|
||||
for row in results:
|
||||
if row and len(row) == 3:
|
||||
|
|
@ -1225,7 +1177,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]:
|
||||
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
|
||||
"""
|
||||
Get subgraph for a set of nodes based on type and names.
|
||||
|
||||
|
|
@ -1273,9 +1225,9 @@ class KuzuAdapter(GraphDBInterface):
|
|||
RETURN n.id, n.name, n.type, n.properties
|
||||
"""
|
||||
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
||||
nodes: List[Tuple[str, Dict[str, Any]]] = []
|
||||
nodes: List[Tuple[str, dict]] = []
|
||||
for node_id, name, typ, props in node_rows:
|
||||
data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ}
|
||||
data = {"id": node_id, "name": name, "type": typ}
|
||||
if props:
|
||||
try:
|
||||
data.update(json.loads(props))
|
||||
|
|
@ -1289,22 +1241,22 @@ class KuzuAdapter(GraphDBInterface):
|
|||
RETURN a.id, b.id, r.relationship_name, r.properties
|
||||
"""
|
||||
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
||||
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
|
||||
edges: List[Tuple[str, str, str, dict]] = []
|
||||
for from_id, to_id, rel_type, props in edge_rows:
|
||||
edge_data: Dict[str, Any] = {}
|
||||
data = {}
|
||||
if props:
|
||||
try:
|
||||
edge_data = json.loads(props)
|
||||
data = json.loads(props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
||||
|
||||
edges.append((from_id, to_id, rel_type, edge_data))
|
||||
edges.append((from_id, to_id, rel_type, data))
|
||||
|
||||
return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs
|
||||
return nodes, edges
|
||||
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
):
|
||||
"""
|
||||
Get filtered nodes and relationships based on attributes.
|
||||
|
||||
|
|
@ -1347,7 +1299,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
)
|
||||
return ([n[0] for n in nodes], [e[0] for e in edges])
|
||||
|
||||
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
|
||||
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
|
||||
"""
|
||||
Get metrics on graph structure and connectivity.
|
||||
|
||||
|
|
@ -1370,8 +1322,8 @@ class KuzuAdapter(GraphDBInterface):
|
|||
try:
|
||||
# Get basic graph data
|
||||
nodes, edges = await self.get_model_independent_graph_data()
|
||||
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
|
||||
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
|
||||
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
|
||||
num_edges = len(edges[0]["elements"]) if edges else 0
|
||||
|
||||
# Calculate mandatory metrics
|
||||
mandatory_metrics = {
|
||||
|
|
@ -1531,8 +1483,8 @@ class KuzuAdapter(GraphDBInterface):
|
|||
It raises exceptions for failures occurring during deletion processes.
|
||||
"""
|
||||
try:
|
||||
if self._get_connection():
|
||||
self._get_connection().close()
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
if self.db:
|
||||
self.db.close()
|
||||
|
|
@ -1563,7 +1515,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
occur during file deletions or initializations carefully.
|
||||
"""
|
||||
try:
|
||||
if self._get_connection():
|
||||
if self.connection:
|
||||
self.connection = None
|
||||
if self.db:
|
||||
self.db.close()
|
||||
|
|
@ -1579,30 +1531,20 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
# Reinitialize the database
|
||||
self._initialize_connection()
|
||||
|
||||
if not self._get_connection():
|
||||
raise RuntimeError("Failed to establish database connection")
|
||||
|
||||
# Verify the database is empty
|
||||
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
for single_result in result:
|
||||
_next = single_result.get_next()
|
||||
if not isinstance(_next, list):
|
||||
raise RuntimeError("Expected list of results")
|
||||
count = _next[0] if _next else 0
|
||||
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
|
||||
count = result.get_next()[0] if result.has_next() else 0
|
||||
if count > 0:
|
||||
logger.warning(
|
||||
f"Database still contains {count} nodes after clearing, forcing deletion"
|
||||
)
|
||||
self._get_connection().execute("MATCH (n:Node) DETACH DELETE n")
|
||||
self.connection.execute("MATCH (n:Node) DETACH DELETE n")
|
||||
logger.info("Database cleared successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during database clearing: {e}")
|
||||
raise
|
||||
|
||||
async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_document_subgraph(self, data_id: str):
|
||||
"""
|
||||
Get all nodes that should be deleted when removing a document.
|
||||
|
||||
|
|
@ -1674,7 +1616,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
"orphan_types": result[0][4],
|
||||
}
|
||||
|
||||
async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]:
|
||||
async def get_degree_one_nodes(self, node_type: str):
|
||||
"""
|
||||
Get all nodes that have only one connection.
|
||||
|
||||
|
|
@ -1827,8 +1769,8 @@ class KuzuAdapter(GraphDBInterface):
|
|||
ids: List[str] = []
|
||||
|
||||
if time_from and time_to:
|
||||
time_from_int = date_to_int(time_from)
|
||||
time_to_int = date_to_int(time_to)
|
||||
time_from = date_to_int(time_from)
|
||||
time_to = date_to_int(time_to)
|
||||
|
||||
cypher = f"""
|
||||
MATCH (n:Node)
|
||||
|
|
@ -1840,13 +1782,13 @@ class KuzuAdapter(GraphDBInterface):
|
|||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||
ELSE CAST(t_str AS INT64)
|
||||
END AS t
|
||||
WHERE t >= {time_from_int}
|
||||
AND t <= {time_to_int}
|
||||
WHERE t >= {time_from}
|
||||
AND t <= {time_to}
|
||||
RETURN n.id as id
|
||||
"""
|
||||
|
||||
elif time_from:
|
||||
time_from_int = date_to_int(time_from)
|
||||
time_from = date_to_int(time_from)
|
||||
|
||||
cypher = f"""
|
||||
MATCH (n:Node)
|
||||
|
|
@ -1858,12 +1800,12 @@ class KuzuAdapter(GraphDBInterface):
|
|||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||
ELSE CAST(t_str AS INT64)
|
||||
END AS t
|
||||
WHERE t >= {time_from_int}
|
||||
WHERE t >= {time_from}
|
||||
RETURN n.id as id
|
||||
"""
|
||||
|
||||
elif time_to:
|
||||
time_to_int = date_to_int(time_to)
|
||||
time_to = date_to_int(time_to)
|
||||
|
||||
cypher = f"""
|
||||
MATCH (n:Node)
|
||||
|
|
@ -1875,12 +1817,12 @@ class KuzuAdapter(GraphDBInterface):
|
|||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||
ELSE CAST(t_str AS INT64)
|
||||
END AS t
|
||||
WHERE t <= {time_to_int}
|
||||
WHERE t <= {time_to}
|
||||
RETURN n.id as id
|
||||
"""
|
||||
|
||||
else:
|
||||
return ", ".join(f"'{uid}'" for uid in ids)
|
||||
return ids
|
||||
|
||||
time_nodes = await self.query(cypher)
|
||||
time_ids_list = [item[0] for item in time_nodes]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import aiohttp
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ logger = get_logger()
|
|||
class UUIDEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that handles UUID objects."""
|
||||
|
||||
def default(self, obj: Union[UUID, Any]) -> Any:
|
||||
def default(self, obj):
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
return super().default(obj)
|
||||
|
|
@ -36,7 +36,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
|||
self.api_url = api_url
|
||||
self.username = username
|
||||
self.password = password
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._session = None
|
||||
self._schema_initialized = False
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
|
|
@ -45,13 +45,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
|||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
async def close(self):
|
||||
"""Close the adapter and its session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]:
|
||||
async def _make_request(self, endpoint: str, data: dict) -> dict:
|
||||
"""Make a request to the Kuzu API."""
|
||||
url = f"{self.api_url}{endpoint}"
|
||||
session = await self._get_session()
|
||||
|
|
@ -73,15 +73,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
|||
status=response.status,
|
||||
message=error_detail,
|
||||
)
|
||||
return await response.json() # type: ignore
|
||||
return await response.json()
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"API request failed: {str(e)}")
|
||||
logger.error(f"Request data: {data}")
|
||||
raise
|
||||
|
||||
async def query(
|
||||
self, query: str, params: Optional[dict[str, Any]] = None
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
||||
"""Execute a Kuzu query via the REST API."""
|
||||
try:
|
||||
# Initialize schema if needed
|
||||
|
|
@ -128,7 +126,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
|||
logger.error(f"Failed to check schema: {e}")
|
||||
return False
|
||||
|
||||
async def _create_schema(self) -> None:
|
||||
async def _create_schema(self):
|
||||
"""Create the required schema tables."""
|
||||
try:
|
||||
# Create Node table if it doesn't exist
|
||||
|
|
@ -182,7 +180,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
|||
logger.error(f"Failed to create schema: {e}")
|
||||
raise
|
||||
|
||||
async def _initialize_schema(self) -> None:
|
||||
async def _initialize_schema(self):
|
||||
"""Initialize the database schema if it doesn't exist."""
|
||||
if self._schema_initialized:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -8,11 +8,11 @@ from neo4j import AsyncSession
|
|||
from neo4j import AsyncGraphDatabase
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
|
||||
from typing import Optional, Any, List, Dict, Type, Tuple
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||
from cognee.tasks.temporal_graph.models import Timestamp
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
|
|
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
async def get_session(self) -> AsyncSession:
|
||||
"""
|
||||
Get a session for database operations.
|
||||
"""
|
||||
async with self.driver.session(database=self.graph_database_name) as session:
|
||||
yield session
|
||||
|
||||
@deadlock_retry() # type: ignore
|
||||
@deadlock_retry()
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -112,7 +112,6 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
async with self.get_session() as session:
|
||||
result = await session.run(query, parameters=params)
|
||||
data = await result.data()
|
||||
# TODO: why we don't return List[Dict[str, Any]]?
|
||||
return data
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||
|
|
@ -142,29 +141,21 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
return results[0]["node_exists"] if len(results) > 0 else False
|
||||
|
||||
async def add_node(
|
||||
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
async def add_node(self, node: DataPoint):
|
||||
"""
|
||||
Add a new node to the database based on the provided DataPoint object or string ID.
|
||||
Add a new node to the database based on the provided DataPoint object.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
|
||||
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
|
||||
- node (DataPoint): An instance of DataPoint representing the node to add.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The result of the query execution, typically the ID of the added node.
|
||||
"""
|
||||
if isinstance(node, str):
|
||||
# TODO: this was not handled in the original code, check if it is correct
|
||||
# Handle string node ID with properties parameter
|
||||
node_id = node
|
||||
node_label = "Node" # Default label for string nodes
|
||||
serialized_properties = self.serialize_properties(properties or {})
|
||||
else:
|
||||
# Handle DataPoint object
|
||||
node_id = str(node.id)
|
||||
node_label = type(node).__name__
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
|
||||
query = dedent(
|
||||
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
|
||||
|
|
@ -176,16 +167,16 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
|
||||
params = {
|
||||
"node_id": node_id,
|
||||
"node_label": node_label,
|
||||
"node_id": str(node.id),
|
||||
"node_label": type(node).__name__,
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
await self.query(query, params)
|
||||
return await self.query(query, params)
|
||||
|
||||
@record_graph_changes # type: ignore
|
||||
@override_distributed(queued_add_nodes) # type: ignore
|
||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||
@record_graph_changes
|
||||
@override_distributed(queued_add_nodes)
|
||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||
"""
|
||||
Add multiple nodes to the database in a single query.
|
||||
|
||||
|
|
@ -210,7 +201,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||
"""
|
||||
|
||||
node_params = [
|
||||
nodes = [
|
||||
{
|
||||
"node_id": str(node.id),
|
||||
"label": type(node).__name__,
|
||||
|
|
@ -219,9 +210,10 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
for node in nodes
|
||||
]
|
||||
|
||||
await self.query(query, dict(nodes=node_params))
|
||||
results = await self.query(query, dict(nodes=nodes))
|
||||
return results
|
||||
|
||||
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def extract_node(self, node_id: str):
|
||||
"""
|
||||
Retrieve a single node from the database by its ID.
|
||||
|
||||
|
|
@ -239,7 +231,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return results[0] if len(results) > 0 else None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
|
||||
async def extract_nodes(self, node_ids: List[str]):
|
||||
"""
|
||||
Retrieve multiple nodes from the database by their IDs.
|
||||
|
||||
|
|
@ -264,7 +256,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return [result["node"] for result in results]
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
async def delete_node(self, node_id: str):
|
||||
"""
|
||||
Remove a node from the database identified by its ID.
|
||||
|
||||
|
|
@ -281,7 +273,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
|
||||
params = {"node_id": node_id}
|
||||
|
||||
await self.query(query, params)
|
||||
return await self.query(query, params)
|
||||
|
||||
async def delete_nodes(self, node_ids: list[str]) -> None:
|
||||
"""
|
||||
|
|
@ -304,18 +296,18 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
await self.query(query, params)
|
||||
return await self.query(query, params)
|
||||
|
||||
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
|
||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
||||
"""
|
||||
Check if an edge exists between two nodes with the specified IDs and edge label.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- source_id (str): The ID of the node from which the edge originates.
|
||||
- target_id (str): The ID of the node to which the edge points.
|
||||
- relationship_name (str): The label of the edge to check for existence.
|
||||
- from_node (UUID): The ID of the node from which the edge originates.
|
||||
- to_node (UUID): The ID of the node to which the edge points.
|
||||
- edge_label (str): The label of the edge to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -323,28 +315,27 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
- bool: True if the edge exists, otherwise False.
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
|
||||
WHERE from_node.id = $source_id AND to_node.id = $target_id
|
||||
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
|
||||
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
|
||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_id": str(source_id),
|
||||
"target_id": str(target_id),
|
||||
"from_node_id": str(from_node),
|
||||
"to_node_id": str(to_node),
|
||||
}
|
||||
|
||||
edge_exists = await self.query(query, params)
|
||||
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
|
||||
return edge_exists
|
||||
|
||||
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
|
||||
async def has_edges(self, edges):
|
||||
"""
|
||||
Check if multiple edges exist based on provided edge criteria.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties)
|
||||
- edges: A list of edge specifications to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -378,24 +369,29 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
async def add_edge(
|
||||
self,
|
||||
source_id: str,
|
||||
target_id: str,
|
||||
from_node: UUID,
|
||||
to_node: UUID,
|
||||
relationship_name: str,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
edge_properties: Optional[Dict[str, Any]] = {},
|
||||
):
|
||||
"""
|
||||
Create a new edge between two nodes with specified properties.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- source_id (str): The ID of the source node of the edge.
|
||||
- target_id (str): The ID of the target node of the edge.
|
||||
- from_node (UUID): The ID of the source node of the edge.
|
||||
- to_node (UUID): The ID of the target node of the edge.
|
||||
- relationship_name (str): The type/label of the edge to create.
|
||||
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
||||
to the edge. (default None)
|
||||
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
||||
to the edge. (default {})
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The result of the query execution, typically indicating the created edge.
|
||||
"""
|
||||
serialized_properties = self.serialize_properties(properties or {})
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
|
||||
query = dedent(
|
||||
f"""\
|
||||
|
|
@ -409,13 +405,13 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
|
||||
params = {
|
||||
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
|
||||
"to_node": str(target_id),
|
||||
"from_node": str(from_node),
|
||||
"to_node": str(to_node),
|
||||
"relationship_name": relationship_name,
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
await self.query(query, params)
|
||||
return await self.query(query, params)
|
||||
|
||||
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -449,9 +445,9 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return flattened
|
||||
|
||||
@record_graph_changes # type: ignore
|
||||
@override_distributed(queued_add_edges) # type: ignore
|
||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||
@record_graph_changes
|
||||
@override_distributed(queued_add_edges)
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
"""
|
||||
Add multiple edges between nodes in a single query.
|
||||
|
||||
|
|
@ -482,10 +478,10 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
) YIELD rel
|
||||
RETURN rel"""
|
||||
|
||||
edge_params = [
|
||||
edges = [
|
||||
{
|
||||
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
|
||||
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
"properties": self._flatten_edge_properties(
|
||||
{
|
||||
|
|
@ -499,12 +495,13 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
]
|
||||
|
||||
try:
|
||||
await self.query(query, dict(edges=edge_params))
|
||||
results = await self.query(query, dict(edges=edges))
|
||||
return results
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||
async def get_edges(self, node_id: str):
|
||||
"""
|
||||
Retrieve all edges connected to a specified node.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue