<!-- .github/pull_request_template.md -->
## Description
This PR introduces triplet embeddings via a new `create_triplet_embeddings` memify pipeline. The pipeline reads the graph in batches, extracts properties from graph elements based on their datapoint types, and generates combined triplet embeddings. These embeddings are stored in the vector database as a new collection.
### Changes in This PR
- Added a new `create_triplet_embeddings` memify pipeline.
- Added a new `get_triplet_datapoints` memify task.
- Introduced a new `triplet_completion` search type.
- Added full test coverage:
  - Unit tests: memify task, pipeline, and retriever
  - Integration tests: memify task, pipeline, and retriever
  - End-to-end tests: updated session history tests and multi-DB search tests; added tests for `triplet_completion` and memify pipeline execution
### Acceptance Criteria and Testing
Scenario 1:
- Run the default add and cognify pipelines.
- Run the `create_triplet_embeddings` memify pipeline.
- Verify the vector DB contains a non-empty `Triplet_text` collection.
- Use the new `triplet_completion` search type and confirm it works correctly.

Scenario 2:
- Run the default add and cognify pipelines.
- Do not run the triplet embeddings memify pipeline.
- Attempt to use the `triplet_completion` search type.
- You should receive an error indicating that the triplet embeddings memify pipeline must be executed first.
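
For reviewers, here is a minimal sketch of Scenario 1, assuming the usual top-level cognee entry points (`cognee.add`, `cognee.cognify`, `cognee.search`) and a `cognee.memify()` call for the new pipeline; the exact invocation of `create_triplet_embeddings` in the merged code may differ:

```python
import asyncio

import cognee
from cognee import SearchType


async def main():
    # Scenario 1: run the default add and cognify pipelines first.
    await cognee.add("Kuzu is an embedded graph database.")
    await cognee.cognify()

    # Hypothetical call: the entry point for the new create_triplet_embeddings
    # memify pipeline may be named or wired differently in the merged code.
    await cognee.memify()

    # With the pipeline run, the Triplet_text collection should be non-empty
    # and the new search type should return results.
    results = await cognee.search(
        query_type=SearchType.TRIPLET_COMPLETION,
        query_text="What kind of database is Kuzu?",
    )
    print(results)


asyncio.run(main())
```

Scenario 2 is the same flow with the memify step omitted, after which the `TRIPLET_COMPLETION` search should fail with an error saying the triplet embeddings memify pipeline must be executed first.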
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit
* **New Features**
  * Triplet-based search with LLM-powered completions (TRIPLET_COMPLETION)
  * Batch triplet retrieval and a triplet embeddings pipeline for extraction, indexing, and optional background processing
  * Context retrieval from triplet embeddings with optional caching and conversation-history support
  * New Triplet data type exposed for indexing and search
* **Examples**
  * End-to-end example demonstrating triplet embeddings extraction and TRIPLET_COMPLETION search
* **Tests**
  * Unit and integration tests covering triplet extraction, retrieval, embedding pipeline, and completion flows
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
"""Adapter for Kuzu graph database."""
|
|
|
|
import os
|
|
import json
|
|
import asyncio
|
|
import tempfile
|
|
from uuid import UUID, uuid5, NAMESPACE_OID
|
|
from kuzu import Connection
|
|
from kuzu.database import Database
|
|
from datetime import datetime, timezone
|
|
from contextlib import asynccontextmanager
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
|
|
|
from cognee.exceptions import CogneeValidationError
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.infrastructure.utils.run_sync import run_sync
|
|
from cognee.infrastructure.files.storage import get_file_storage
|
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
|
GraphDBInterface,
|
|
record_graph_changes,
|
|
)
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.modules.storage.utils import JSONEncoder
|
|
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
|
from cognee.tasks.temporal_graph.models import Timestamp
|
|
from cognee.infrastructure.databases.cache.config import get_cache_config
|
|
|
|
logger = get_logger()
|
|
|
|
cache_config = get_cache_config()
|
|
if cache_config.shared_kuzu_lock:
|
|
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
|
|
|
|
|
class KuzuAdapter(GraphDBInterface):
|
|
"""
|
|
Adapter for Kuzu graph database operations with improved consistency and async support.
|
|
|
|
This class facilitates operations for working with the Kuzu graph database, supporting
|
|
both direct database queries and a structured asynchronous interface for node and edge
|
|
management. It contains methods for querying, adding, and deleting nodes and edges as
|
|
well as for graph metrics and data extraction.
|
|
"""
|
|
|
|
def __init__(self, db_path: str):
|
|
"""Initialize Kuzu database connection and schema."""
|
|
self.open_connections = 0
|
|
self._is_closed = False
|
|
self.db_path = db_path # Path for the database directory
|
|
self.db: Optional[Database] = None
|
|
self.connection: Optional[Connection] = None
|
|
if cache_config.shared_kuzu_lock:
|
|
self.redis_lock = get_cache_engine(
|
|
lock_key="kuzu-lock-" + str(uuid5(NAMESPACE_OID, db_path))
|
|
)
|
|
else:
|
|
self.executor = ThreadPoolExecutor()
|
|
self._initialize_connection()
|
|
self.KUZU_ASYNC_LOCK = asyncio.Lock()
|
|
self._connection_change_lock = asyncio.Lock()
|
|
|
|
def _initialize_connection(self) -> None:
|
|
"""Initialize the Kuzu database connection and schema."""
|
|
|
|
def _install_json_extension():
|
|
"""
|
|
Function handles installing of the json extension for the current Kuzu version.
|
|
This has to be done with an empty graph db before connecting to an existing database otherwise
|
|
missing json extension errors will be raised.
|
|
"""
|
|
try:
|
|
with tempfile.NamedTemporaryFile(mode="w", delete=True) as temp_file:
|
|
temp_graph_file = temp_file.name
|
|
tmp_db = Database(
|
|
temp_graph_file,
|
|
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
|
|
max_db_size=4096 * 1024 * 1024,
|
|
)
|
|
tmp_db.init_database()
|
|
connection = Connection(tmp_db)
|
|
connection.execute("INSTALL JSON;")
|
|
except Exception as e:
|
|
logger.info(f"JSON extension already installed or not needed: {e}")
|
|
|
|
_install_json_extension()
|
|
|
|
try:
|
|
if "s3://" in self.db_path:
|
|
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
|
|
self.temp_graph_file = temp_file.name
|
|
|
|
run_sync(self.pull_from_s3())
|
|
|
|
self.db = Database(
|
|
self.temp_graph_file,
|
|
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
|
|
max_db_size=4096 * 1024 * 1024,
|
|
)
|
|
else:
|
|
# Ensure the parent directory exists before creating the database
|
|
db_dir = os.path.dirname(self.db_path)
|
|
|
|
# If db_path is just a filename, db_dir will be empty string
|
|
# In this case, use the directory containing the db_path or current directory
|
|
if not db_dir:
|
|
# If no directory in path, use the absolute path's directory
|
|
abs_path = os.path.abspath(self.db_path)
|
|
db_dir = os.path.dirname(abs_path)
|
|
|
|
file_storage = get_file_storage(db_dir)
|
|
|
|
run_sync(file_storage.ensure_directory_exists())
|
|
|
|
try:
|
|
self.db = Database(
|
|
self.db_path,
|
|
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
|
|
max_db_size=4096 * 1024 * 1024,
|
|
)
|
|
except RuntimeError:
|
|
from .kuzu_migrate import read_kuzu_storage_version
|
|
import kuzu
|
|
|
|
kuzu_db_version = read_kuzu_storage_version(self.db_path)
|
|
if (
|
|
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
|
|
) and kuzu_db_version != kuzu.__version__:
|
|
# Try to migrate kuzu database to latest version
|
|
from .kuzu_migrate import kuzu_migration
|
|
|
|
kuzu_migration(
|
|
new_db=self.db_path + "_new",
|
|
old_db=self.db_path,
|
|
new_version=kuzu.__version__,
|
|
old_version=kuzu_db_version,
|
|
overwrite=True,
|
|
)
|
|
|
|
self.db = Database(
|
|
self.db_path,
|
|
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
|
|
max_db_size=4096 * 1024 * 1024,
|
|
)
|
|
|
|
self.db.init_database()
|
|
self.connection = Connection(self.db)
|
|
|
|
try:
|
|
self.connection.execute("LOAD EXTENSION JSON;")
|
|
logger.info("Loaded JSON extension")
|
|
except Exception as e:
|
|
logger.info(f"JSON extension already loaded or unavailable: {e}")
|
|
|
|
# Create node table with essential fields and timestamp
|
|
self.connection.execute("""
|
|
CREATE NODE TABLE IF NOT EXISTS Node(
|
|
id STRING PRIMARY KEY,
|
|
name STRING,
|
|
type STRING,
|
|
created_at TIMESTAMP,
|
|
updated_at TIMESTAMP,
|
|
properties STRING
|
|
)
|
|
""")
|
|
# Create relationship table with timestamp
|
|
self.connection.execute("""
|
|
CREATE REL TABLE IF NOT EXISTS EDGE(
|
|
FROM Node TO Node,
|
|
relationship_name STRING,
|
|
created_at TIMESTAMP,
|
|
updated_at TIMESTAMP,
|
|
properties STRING
|
|
)
|
|
""")
|
|
logger.debug("Kuzu database initialized successfully")
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize Kuzu database: {e}")
|
|
raise e
|
|
|
|
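    # Construction sketch (hypothetical path, for illustration only; in cognee the
    # adapter is normally obtained through the graph engine factory rather than
    # instantiated directly):
    #   adapter = KuzuAdapter("/tmp/cognee/graph_db")
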
    async def push_to_s3(self) -> None:
        if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
            from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

            s3_file_storage = S3FileStorage("")

            if self.connection:
                async with self.KUZU_ASYNC_LOCK:
                    self.connection.execute("CHECKPOINT;")

            s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)

    async def pull_from_s3(self) -> None:
        from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

        s3_file_storage = S3FileStorage("")
        try:
            s3_file_storage.s3.get(self.db_path, self.temp_graph_file, recursive=True)
        except FileNotFoundError:
            logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")

    async def is_empty(self) -> bool:
        query = """
            MATCH (n)
            RETURN true
            LIMIT 1;
        """
        query_result = await self.query(query)
        return len(query_result) == 0

    async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
        """
        Execute a Kuzu query asynchronously with automatic reconnection.

        This method runs a database query while managing potential reconnections. It handles
        parameters in a dictionary and processes results to return structured data. The method
        raises any exceptions encountered during query execution.

        Parameters:
        -----------

            - query (str): The Kuzu query string to be executed.
            - params (Optional[dict]): A dictionary of parameters for the query, if applicable.
              (default None)

        Returns:
        --------

            - List[Tuple]: A list of tuples representing the query results.
        """
        loop = asyncio.get_running_loop()
        params = params or {}

        def blocking_query():
            lock_acquired = False
            try:
                if cache_config.shared_kuzu_lock:
                    self.redis_lock.acquire_lock()
                    lock_acquired = True
                if not self.connection:
                    logger.info("Reconnecting to Kuzu database...")
                    self._initialize_connection()

                result = self.connection.execute(query, params)
                rows = []

                while result.has_next():
                    row = result.get_next()
                    processed_rows = []
                    for val in row:
                        if hasattr(val, "as_py"):
                            val = val.as_py()
                        processed_rows.append(val)
                    rows.append(tuple(processed_rows))

                return rows
            except Exception as e:
                logger.error(f"Query execution failed: {str(e)}")
                raise
            finally:
                if cache_config.shared_kuzu_lock and lock_acquired:
                    try:
                        self.close()
                    finally:
                        self.redis_lock.release_lock()

        if cache_config.shared_kuzu_lock:
            async with self._connection_change_lock:
                self.open_connections += 1
                logger.info(f"Open connections after open: {self.open_connections}")
            try:
                result = blocking_query()
            finally:
                self.open_connections -= 1
                logger.info(f"Open connections after close: {self.open_connections}")
            return result
        else:
            result = await loop.run_in_executor(self.executor, blocking_query)
            return result

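    # Example call of query() (illustrative values only):
    #   rows = await adapter.query(
    #       "MATCH (n:Node) WHERE n.type = $t RETURN n.id", {"t": "Entity"}
    #   )
    #   # rows -> [("node-id-1",), ("node-id-2",), ...]
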
    def close(self):
        if self.connection:
            del self.connection
            self.connection = None
        if self.db:
            del self.db
            self.db = None
        self._is_closed = True
        logger.info("Kuzu database closed successfully")

    def reopen(self):
        if self._is_closed:
            self._is_closed = False
            self._initialize_connection()
            logger.info("Kuzu database re-opened successfully")

    @asynccontextmanager
    async def get_session(self):
        """
        Get a database session.

        This provides API-compatible session management for Kuzu, even though it does not
        have built-in session management like other databases. It yields the current connection
        and on exit performs cleanup if necessary.
        """
        try:
            yield self.connection
        finally:
            pass

    def _parse_node(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a raw node result (with JSON properties) into a dictionary."""
        if data.get("properties"):
            try:
                props = json.loads(data["properties"])
                # Remove the JSON field and merge its contents
                data.pop("properties")
                data.update(props)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse properties JSON for node {data.get('id')}")
        return data

    def _parse_node_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            if isinstance(data, dict) and "properties" in data and data["properties"]:
                props = json.loads(data["properties"])
                data.update(props)
                del data["properties"]
            return data
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse properties JSON for node {data.get('id')}")
            return data

    # Helper method for building edge queries

    def _edge_query_and_params(
        self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
    ) -> Tuple[str, dict]:
        """Build the edge creation query and parameters."""
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
        query = """
            MATCH (from:Node), (to:Node)
            WHERE from.id = $from_id AND to.id = $to_id
            MERGE (from)-[r:EDGE {
                relationship_name: $relationship_name
            }]->(to)
            ON CREATE SET
                r.created_at = timestamp($created_at),
                r.updated_at = timestamp($updated_at),
                r.properties = $properties
            ON MATCH SET
                r.updated_at = timestamp($updated_at),
                r.properties = $properties
        """
        params = {
            "from_id": from_node,
            "to_id": to_node,
            "relationship_name": relationship_name,
            "created_at": now,
            "updated_at": now,
            "properties": json.dumps(properties, cls=JSONEncoder),
        }
        return query, params

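    # Illustration of the JSON-properties merge done by _parse_node (hypothetical data):
    #   _parse_node({"id": "n1", "name": "Ada", "properties": '{"age": 36}'})
    #   # -> {"id": "n1", "name": "Ada", "age": 36}
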
    # Node Operations

    async def has_node(self, node_id: str) -> bool:
        """
        Check if a node exists.

        This method checks for the existence of a node in the database by its identifier. It
        returns a boolean indicating whether the node is present or not.

        Parameters:
        -----------

            - node_id (str): The identifier of the node to check.

        Returns:
        --------

            - bool: True if the node exists, False otherwise.
        """
        query_str = "MATCH (n:Node) WHERE n.id = $id RETURN COUNT(n) > 0"
        result = await self.query(query_str, {"id": node_id})
        return result[0][0] if result else False

    async def add_node(self, node: DataPoint) -> None:
        """
        Add a single node to the graph if it doesn't exist.

        This method constructs and executes a query to add a node to the graph, ensuring that it
        is not duplicated by checking its existence first. An error is raised if the operation
        fails.

        Parameters:
        -----------

            - node (DataPoint): The node to be added, represented as a DataPoint.
        """
        try:
            properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)

            # Extract core fields with defaults if not present
            core_properties = {
                "id": str(properties.get("id", "")),
                "name": str(properties.get("name", "")),
                "type": str(properties.get("type", "")),
            }

            # Remove core fields from other properties
            for key in core_properties:
                properties.pop(key, None)

            core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)

            # Add timestamps for new node
            now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
            fields = []
            params = {}
            for key, value in core_properties.items():
                if value is not None:
                    param_name = f"param_{key}"
                    fields.append(f"{key}: ${param_name}")
                    params[param_name] = value

            # Add timestamp fields
            fields.extend(
                ["created_at: timestamp($created_at)", "updated_at: timestamp($updated_at)"]
            )
            params.update({"created_at": now, "updated_at": now})

            merge_query = f"""
                MERGE (n:Node {{id: $param_id}})
                ON CREATE SET n += {{{", ".join(fields)}}}
            """
            await self.query(merge_query, params)

        except Exception as e:
            logger.error(f"Failed to add node: {e}")
            raise

    @record_graph_changes
    async def add_nodes(self, nodes: List[DataPoint]) -> None:
        """
        Add multiple nodes to the graph in a batch operation.

        This method allows for the addition of multiple nodes in a single operation to enhance
        performance. It processes a list of nodes and constructs the necessary query for
        insertion. Errors encountered during the addition will be logged and raised.

        Parameters:
        -----------

            - nodes (List[DataPoint]): A list of nodes to be added to the graph, each
              represented as a DataPoint.
        """
        if not nodes:
            return

        try:
            now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")

            # Prepare all nodes data
            node_params = []
            for node in nodes:
                properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)

                core_properties = {
                    "id": str(properties.get("id", "")),
                    "name": str(properties.get("name", "")),
                    "type": str(properties.get("type", "")),
                }

                # Remove core fields from other properties
                for key in core_properties:
                    properties.pop(key, None)

                node_params.append(
                    {
                        **core_properties,
                        "properties": json.dumps(properties, cls=JSONEncoder),
                        "created_at": now,
                        "updated_at": now,
                    }
                )

            if node_params:
                # Batch merge nodes
                merge_query = """
                    UNWIND $nodes AS node
                    MERGE (n:Node {id: node.id})
                    ON CREATE SET
                        n.name = node.name,
                        n.type = node.type,
                        n.properties = node.properties,
                        n.created_at = timestamp(node.created_at),
                        n.updated_at = timestamp(node.updated_at)
                    ON MATCH SET
                        n.name = node.name,
                        n.type = node.type,
                        n.properties = node.properties,
                        n.updated_at = timestamp(node.updated_at)
                """
                await self.query(merge_query, {"nodes": node_params})
                logger.debug(f"Processed {len(node_params)} nodes in batch")

        except Exception as e:
            logger.error(f"Failed to add nodes in batch: {e}")
            raise

    async def delete_node(self, node_id: str) -> None:
        """
        Delete a node and its relationships.

        This method removes a node identified by its ID along with all associated relationships.
        It encapsulates the delete operation for simplicity in usage.

        Parameters:
        -----------

            - node_id (str): The identifier of the node to be deleted.
        """
        query_str = "MATCH (n:Node) WHERE n.id = $id DETACH DELETE n"
        await self.query(query_str, {"id": node_id})

    async def delete_nodes(self, node_ids: List[str]) -> None:
        """
        Delete multiple nodes at once.

        This method facilitates the deletion of a list of nodes, identified by their IDs,
        concurrently. It ensures efficiency by using a single query to detach-delete all
        nodes in the list.

        Parameters:
        -----------

            - node_ids (List[str]): A list of identifiers for the nodes to be deleted.
        """
        query_str = "MATCH (n:Node) WHERE n.id IN $ids DETACH DELETE n"
        await self.query(query_str, {"ids": node_ids})

    async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """
        Extract a node by its ID.

        This method retrieves a node's data by its identifier and returns it as a dictionary. If
        the node is not found or an error occurs, it returns None.

        Parameters:
        -----------

            - node_id (str): The identifier of the node to be extracted.

        Returns:
        --------

            - Optional[Dict[str, Any]]: A dictionary of the node's properties if found,
              otherwise None.
        """
        query_str = """
            MATCH (n:Node)
            WHERE n.id = $id
            RETURN {
                id: n.id,
                name: n.name,
                type: n.type,
                properties: n.properties
            }
        """
        try:
            result = await self.query(query_str, {"id": node_id})
            if result and result[0]:
                node_data = self._parse_node(result[0][0])
                return node_data
            return None
        except Exception as e:
            logger.error(f"Failed to extract node {node_id}: {e}")
            return None

    async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Extract multiple nodes by their IDs.

        This method retrieves a list of nodes identified by their IDs and returns their data as
        a list of dictionaries. It handles possible retrieval errors internally and will return
        an empty list if no nodes are found.

        Parameters:
        -----------

            - node_ids (List[str]): A list of identifiers for the nodes to be extracted.

        Returns:
        --------

            - List[Dict[str, Any]]: A list of dictionaries containing the properties of the
              extracted nodes.
        """
        query_str = """
            MATCH (n:Node)
            WHERE n.id IN $node_ids
            RETURN {
                id: n.id,
                name: n.name,
                type: n.type,
                properties: n.properties
            }
        """
        try:
            results = await self.query(query_str, {"node_ids": node_ids})
            # Parse each node using the same helper function
            nodes = [self._parse_node(row[0]) for row in results if row[0]]
            return nodes
        except Exception as e:
            logger.error(f"Failed to extract nodes: {e}")
            return []

    # Edge Operations

    async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
        """
        Check if an edge exists between nodes with the given relationship name.

        This method verifies the existence of a directed edge defined by the relationship name
        between two specified nodes. It returns a boolean value indicating presence or absence
        of the edge.

        Parameters:
        -----------

            - from_node (str): The identifier of the source node.
            - to_node (str): The identifier of the target node.
            - edge_label (str): The label of the edge representing the relationship name.

        Returns:
        --------

            - bool: True if the edge exists, False otherwise.
        """
        query_str = """
            MATCH (from:Node)-[r:EDGE]->(to:Node)
            WHERE from.id = $from_id AND to.id = $to_id AND r.relationship_name = $edge_label
            RETURN COUNT(r) > 0
        """
        result = await self.query(
            query_str, {"from_id": from_node, "to_id": to_node, "edge_label": edge_label}
        )
        return result[0][0] if result else False

    async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
        """
        Check if multiple edges exist in a batch operation.

        This method checks for the presence of specified edges in the database and returns a
        list of edges that exist. It is beneficial for efficiency in checking multiple edges
        simultaneously.

        Parameters:
        -----------

            - edges (List[Tuple[str, str, str]]): A list of edges where each edge is represented
              as a tuple of (from_node, to_node, edge_label).

        Returns:
        --------

            - List[Tuple[str, str, str]]: A list of tuples representing the existing edges from
              the provided list.
        """
        if not edges:
            return []

        try:
            # Transform edges into format needed for batch query
            edge_params = [
                {
                    "from_id": str(from_node),  # Ensure string type
                    "to_id": str(to_node),  # Ensure string type
                    "relationship_name": str(edge_label),  # Ensure string type
                }
                for from_node, to_node, edge_label in edges
            ]

            # Batch check query with direct string comparison
            query = """
                UNWIND $edges AS edge
                MATCH (from:Node)-[r:EDGE]->(to:Node)
                WHERE from.id = edge.from_id
                AND to.id = edge.to_id
                AND r.relationship_name = edge.relationship_name
                RETURN from.id, to.id, r.relationship_name
            """

            results = await self.query(query, {"edges": edge_params})

            # Convert results back to tuples and ensure string types
            existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]

            logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
            return existing_edges

        except Exception as e:
            logger.error(f"Failed to check edges in batch: {e}")
            return []

    async def add_edge(
        self,
        from_node: str,
        to_node: str,
        relationship_name: str,
        edge_properties: Dict[str, Any] = {},
    ) -> None:
        """
        Add an edge between two nodes.

        This method constructs and executes a query to create a directed edge between two
        specified nodes with certain properties. It will raise an error if the addition fails
        during execution.

        Parameters:
        -----------

            - from_node (str): The identifier of the source node from which the edge originates.
            - to_node (str): The identifier of the target node to which the edge points.
            - relationship_name (str): The label of the edge to be created, representing the
              relationship name.
            - edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
              (default {})
        """
        try:
            query, params = self._edge_query_and_params(
                from_node, to_node, relationship_name, edge_properties
            )
            await self.query(query, params)
        except Exception as e:
            logger.error(f"Failed to add edge: {e}")
            raise

    @record_graph_changes
    async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
        """
        Add multiple edges in a batch operation.

        This method enables efficient insertion of multiple edges at once by processing a list
        of edge details. It improves performance for batch operations compared to adding edges
        individually. Errors during execution are logged and raised as necessary.

        Parameters:
        -----------

            - edges (List[Tuple[str, str, str, Dict[str, Any]]]): A list of edges represented as
              tuples of (from_node, to_node, relationship_name, edge_properties).
        """
        if not edges:
            return

        try:
            now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")

            edge_params = [
                {
                    "from_id": from_node,
                    "to_id": to_node,
                    "relationship_name": relationship_name,
                    "properties": json.dumps(properties, cls=JSONEncoder),
                    "created_at": now,
                    "updated_at": now,
                }
                for from_node, to_node, relationship_name, properties in edges
            ]

            query = """
                UNWIND $edges AS edge
                MATCH (from:Node), (to:Node)
                WHERE from.id = edge.from_id AND to.id = edge.to_id
                MERGE (from)-[r:EDGE {
                    relationship_name: edge.relationship_name
                }]->(to)
                ON CREATE SET
                    r.created_at = timestamp(edge.created_at),
                    r.updated_at = timestamp(edge.updated_at),
                    r.properties = edge.properties
                ON MATCH SET
                    r.updated_at = timestamp(edge.updated_at),
                    r.properties = edge.properties
            """

            await self.query(query, {"edges": edge_params})

        except Exception as e:
            logger.error(f"Failed to add edges in batch: {e}")
            raise

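    # Shape of the `edges` argument accepted by add_edges (hypothetical values):
    #   await adapter.add_edges([
    #       ("node-a-id", "node-b-id", "mentions", {"weight": 0.8}),
    #   ])
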
    async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
        """
        Get all edges connected to a node.

        This method retrieves all edges that are linked to a specified node and returns them in
        a structured format. If an error occurs or no edges exist, an empty list is returned.

        Parameters:
        -----------

            - node_id (str): The identifier of the node for which to retrieve edges.

        Returns:
        --------

            - List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
              tuple contains (source_node, relationship_name, target_node), with source_node and
              target_node as dictionaries of node properties.
        """
        query_str = """
            MATCH (n:Node)-[r]-(m:Node)
            WHERE n.id = $node_id
            RETURN {
                id: n.id,
                name: n.name,
                type: n.type,
                properties: n.properties
            },
            r.relationship_name,
            {
                id: m.id,
                name: m.name,
                type: m.type,
                properties: m.properties
            }
        """
        try:
            results = await self.query(query_str, {"node_id": node_id})
            edges = []
            for row in results:
                if row and len(row) == 3:
                    source_node = self._parse_node_properties(row[0])
                    target_node = self._parse_node_properties(row[2])
                    edges.append((source_node, row[1], target_node))
            return edges
        except Exception as e:
            logger.error(f"Failed to get edges for node {node_id}: {e}")
            return []

    # Neighbor Operations

    async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
        """
        Get all neighboring nodes.

        This method simply calls the get_neighbours method for API compatibility and retrieves
        connected nodes neighboring the specified node. It returns a list of neighbor nodes'
        properties as dictionaries.

        Parameters:
        -----------

            - node_id (str): The identifier of the node for which to find neighbors.

        Returns:
        --------

            - List[Dict[str, Any]]: A list of dictionaries representing neighboring nodes'
              properties.
        """
        return await self.get_neighbours(node_id)

    async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a single node by ID.

        This method retrieves the properties of a node identified by its ID and returns them as
        a dictionary. If the node does not exist, None is returned.

        Parameters:
        -----------

            - node_id (str): The identifier of the node to retrieve.

        Returns:
        --------

            - Optional[Dict[str, Any]]: A dictionary containing the properties of the node if
              found, otherwise None.
        """
        query_str = """
            MATCH (n:Node)
            WHERE n.id = $id
            RETURN {
                id: n.id,
                name: n.name,
                type: n.type,
                properties: n.properties
            }
        """
        try:
            result = await self.query(query_str, {"id": node_id})
            if result and result[0]:
                return self._parse_node(result[0][0])
            return None
        except Exception as e:
            logger.error(f"Failed to get node {node_id}: {e}")
            return None

    async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Get multiple nodes by their IDs.

        This method retrieves properties for multiple nodes identified by their IDs and returns
        them as a list of dictionaries. An empty list is returned if no nodes are found or an
        error occurs.

        Parameters:
        -----------

            - node_ids (List[str]): A list of identifiers for the nodes to be retrieved.

        Returns:
        --------

            - List[Dict[str, Any]]: A list of dictionaries containing properties of each
              retrieved node.
        """
        query_str = """
            MATCH (n:Node)
            WHERE n.id IN $node_ids
            RETURN {
                id: n.id,
                name: n.name,
                type: n.type,
                properties: n.properties
            }
        """
        try:
            results = await self.query(query_str, {"node_ids": node_ids})
            return [self._parse_node(row[0]) for row in results if row[0]]
        except Exception as e:
            logger.error(f"Failed to get nodes: {e}")
            return []

    async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
        """
        Get all neighbouring nodes.

        This method retrieves all neighboring nodes connected to a specified node and returns
        them as a list of dictionaries. It may return an empty list if no neighbors exist or an
        error occurs.

        Parameters:
        -----------

            - node_id (str): The identifier of the node for which to find neighbors.

        Returns:
        --------

            - List[Dict[str, Any]]: A list of dictionaries representing neighboring nodes'
              properties.
        """
        query_str = """
            MATCH (n)-[r]-(m)
            WHERE n.id = $id
            RETURN DISTINCT properties(m)
        """
        try:
            result = await self.query(query_str, {"id": node_id})
            return [row[0] for row in result] if result else []
        except Exception as e:
            logger.error(f"Failed to get neighbours for node {node_id}: {e}")
            return []

    async def get_predecessors(
        self, node_id: Union[str, UUID], edge_label: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Get all predecessor nodes.

        This method retrieves all nodes that are predecessors of the specified node. If an edge
        label is provided, it filters the results accordingly. It returns a list of dictionaries
        containing properties of these predecessor nodes.

        Parameters:
        -----------

            - node_id (Union[str, UUID]): The identifier of the specified node.
            - edge_label (Optional[str]): An optional label to filter the edges by relationship
              name. (default None)

        Returns:
        --------

            - List[Dict[str, Any]]: A list of dictionaries representing all predecessor nodes'
              properties.
        """
        try:
            if edge_label:
                query_str = """
                    MATCH (n)<-[r:EDGE]-(m)
                    WHERE n.id = $id AND r.relationship_name = $edge_label
                    RETURN properties(m)
                """
                params = {"id": str(node_id), "edge_label": edge_label}
            else:
                query_str = """
                    MATCH (n)<-[r:EDGE]-(m)
                    WHERE n.id = $id
                    RETURN properties(m)
                """
                params = {"id": str(node_id)}
            result = await self.query(query_str, params)
            return [row[0] for row in result] if result else []
        except Exception as e:
            logger.error(f"Failed to get predecessors for node {node_id}: {e}")
            return []

    async def get_successors(
        self, node_id: Union[str, UUID], edge_label: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Get all successor nodes.

        This method retrieves all nodes that are successors of the specified node. An edge label
        can be provided to filter the results. It returns a list of dictionaries detailing these
        successor nodes' properties.

        Parameters:
        -----------

            - node_id (Union[str, UUID]): The identifier of the specified node.
            - edge_label (Optional[str]): An optional label to filter the edges by relationship
              name. (default None)

        Returns:
        --------

            - List[Dict[str, Any]]: A list of dictionaries representing all successor nodes'
              properties.
        """
        try:
            if edge_label:
                query_str = """
                    MATCH (n)-[r:EDGE]->(m)
                    WHERE n.id = $id AND r.relationship_name = $edge_label
                    RETURN properties(m)
                """
                params = {"id": str(node_id), "edge_label": edge_label}
            else:
                query_str = """
                    MATCH (n)-[r:EDGE]->(m)
                    WHERE n.id = $id
                    RETURN properties(m)
                """
                params = {"id": str(node_id)}
            result = await self.query(query_str, params)
            return [row[0] for row in result] if result else []
        except Exception as e:
            logger.error(f"Failed to get successors for node {node_id}: {e}")
            return []

    async def get_connections(
        self, node_id: str
    ) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
        """
        Get all nodes connected to a given node.

        This method retrieves all nodes directly connected to a specified node along with the
        relationships between them, returning structured data in a list of tuples. Each tuple
        contains source and target node properties along with the relationship information.

        Parameters:
        -----------

            - node_id (str): The identifier of the node for which to retrieve connections.

        Returns:
        --------

            - List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]: A list of tuples
              containing (source_node, relationship_name, target_node) with dictionaries for
              source_node and target_node properties.
        """
        query_str = """
            MATCH (n:Node)-[r:EDGE]-(m:Node)
            WHERE n.id = $node_id
            RETURN {
                id: n.id,
                name: n.name,
                type: n.type,
                properties: n.properties
            },
            {
                relationship_name: r.relationship_name,
                properties: r.properties
            },
            {
                id: m.id,
                name: m.name,
                type: m.type,
                properties: m.properties
            }
        """
        try:
            results = await self.query(query_str, {"node_id": node_id})
            edges = []
            for row in results:
                if row and len(row) == 3:
                    processed_rows = []
                    for i, item in enumerate(row):
                        if isinstance(item, dict):
                            if "properties" in item and item["properties"]:
                                try:
                                    props = json.loads(item["properties"])
                                    item.update(props)
                                    del item["properties"]
                                except json.JSONDecodeError:
                                    logger.warning(
                                        f"Failed to parse JSON properties for node/edge {i}"
                                    )
                        processed_rows.append(item)
                    edges.append(tuple(processed_rows))
            return edges if edges else []  # Always return a list, even if empty
        except Exception as e:
            logger.error(f"Failed to get connections for node {node_id}: {e}")
            return []  # Return empty list on error

    async def remove_connection_to_predecessors_of(
        self, node_ids: List[str], edge_label: str
    ) -> None:
        """
        Remove all incoming edges of specified type for given nodes.

        This method disconnects predecessor relationships of a specific type for the specified
        nodes, managing edges in a single operation effectively.

        Parameters:
        -----------

            - node_ids (List[str]): A list of identifiers for the nodes whose relationships are
              to be removed.
            - edge_label (str): The label of the edge to be removed.
        """
        query_str = """
            MATCH (n)<-[r:EDGE]-(m)
            WHERE n.id IN $node_ids AND r.relationship_name = $edge_label
            DELETE r
        """
        await self.query(query_str, {"node_ids": node_ids, "edge_label": edge_label})

    async def remove_connection_to_successors_of(
        self, node_ids: List[str], edge_label: str
    ) -> None:
        """
        Remove all outgoing edges of specified type for given nodes.

        This method disconnects successor relationships of a specified type for the specified
        nodes in a single efficient operation.

        Parameters:
        -----------

            - node_ids (List[str]): A list of identifiers for the nodes whose relationships are
              to be removed.
            - edge_label (str): The label of the edge to be removed.
        """
        query_str = """
            MATCH (n)-[r:EDGE]->(m)
            WHERE n.id IN $node_ids AND r.relationship_name = $edge_label
            DELETE r
        """
        await self.query(query_str, {"node_ids": node_ids, "edge_label": edge_label})

    # Graph-wide Operations

    async def get_graph_data(
        self,
    ) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, str, str, Dict[str, Any]]]]:
        """
        Get all nodes and edges in the graph.

        This method fetches the entire graph's structure, including all nodes and their
        properties as well as relationships and their details, returning them in a structured
        format. Errors during query execution will result in raised exceptions.

        Returns:
        --------

            - Tuple[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, str, str, Dict[str, Any]]]]:
              A tuple with two elements: a list of tuples of (node_id, properties) and a list of
              tuples of (source_id, target_id, relationship_name, properties).
        """

        import time

        start_time = time.time()

        try:
            nodes_query = """
                MATCH (n:Node)
                RETURN n.id, {
                    name: n.name,
                    type: n.type,
                    properties: n.properties
                }
            """
            nodes = await self.query(nodes_query)
            formatted_nodes = []
            for n in nodes:
                if n[0]:
                    node_id = str(n[0])
                    props = n[1]
                    if props.get("properties"):
                        try:
                            additional_props = json.loads(props["properties"])
                            props.update(additional_props)
                            del props["properties"]
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse properties JSON for node {node_id}")
                    formatted_nodes.append((node_id, props))
            if not formatted_nodes:
                logger.warning("No nodes found in the database")
                return [], []

            edges_query = """
                MATCH (n:Node)-[r]->(m:Node)
                RETURN n.id, m.id, r.relationship_name, r.properties
            """
            edges = await self.query(edges_query)
            formatted_edges = []
            for e in edges:
                if e and len(e) >= 3:
                    source_id = str(e[0])
                    target_id = str(e[1])
                    rel_type = str(e[2])
                    props = {}
                    if len(e) > 3 and e[3]:
                        try:
                            props = json.loads(e[3])
                        except (json.JSONDecodeError, TypeError):
                            logger.warning(
                                f"Failed to parse edge properties for {source_id}->{target_id}"
                            )
                    formatted_edges.append((source_id, target_id, rel_type, props))

            if formatted_nodes and not formatted_edges:
                logger.debug("No edges found, creating self-referential edges for nodes")
                for node_id, _ in formatted_nodes:
                    formatted_edges.append(
                        (
                            node_id,
                            node_id,
                            "SELF",
                            {
                                "relationship_name": "SELF",
                                "relationship_type": "SELF",
                                "vector_distance": 0.0,
                            },
                        )
                    )

            retrieval_time = time.time() - start_time
            logger.info(
                f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
            )
            return formatted_nodes, formatted_edges
        except Exception as e:
            logger.error(f"Failed to get graph data: {e}")
            raise

    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
        """
        Get subgraph for a set of nodes based on type and names.

        This method queries for nodes of a specific type and their corresponding neighbors,
        returning both nodes and edges connecting them. It's useful for analyzing a targeted
        subset of the graph.

        Parameters:
        -----------

            - node_type (Type[Any]): Type of nodes to retrieve as specified by the user.
            - node_name (List[str]): List of names corresponding to the nodes to be retrieved.

        Returns:
        --------

            - Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]: A tuple
              containing a list of nodes and a list of edges related to those nodes.
        """
        label = node_type.__name__
        primary_query = """
            UNWIND $names AS wantedName
            MATCH (n:Node)
            WHERE n.type = $label AND n.name = wantedName
            RETURN DISTINCT n.id
        """
        primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
        primary_ids = [row[0] for row in primary_rows]
        if not primary_ids:
            return [], []

        neighbor_query = """
            MATCH (n:Node)-[:EDGE]-(nbr:Node)
            WHERE n.id IN $ids
            RETURN DISTINCT nbr.id
        """
        nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
        neighbor_ids = [row[0] for row in nbr_rows]

        all_ids = list({*primary_ids, *neighbor_ids})

        nodes_query = """
            MATCH (n:Node)
            WHERE n.id IN $ids
            RETURN n.id, n.name, n.type, n.properties
        """
        node_rows = await self.query(nodes_query, {"ids": all_ids})
        nodes: List[Tuple[str, dict]] = []
        for node_id, name, typ, props in node_rows:
            data = {"id": node_id, "name": name, "type": typ}
            if props:
                try:
                    data.update(json.loads(props))
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse JSON props for node {node_id}")
            nodes.append((node_id, data))

        edges_query = """
            MATCH (a:Node)-[r:EDGE]-(b:Node)
            WHERE a.id IN $ids AND b.id IN $ids
            RETURN a.id, b.id, r.relationship_name, r.properties
        """
        edge_rows = await self.query(edges_query, {"ids": all_ids})
        edges: List[Tuple[str, str, str, dict]] = []
        for from_id, to_id, rel_type, props in edge_rows:
            data = {}
            if props:
                try:
                    data = json.loads(props)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")

            edges.append((from_id, to_id, rel_type, data))

        return nodes, edges

    async def get_filtered_graph_data(
        self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
    ):
        """
        Get filtered nodes and relationships based on attributes.

        This method accepts attribute filters and retrieves nodes and relationships that match
        the specified conditions. It allows complex filtering across node properties and edge
        attributes.

        Parameters:
        -----------

            - attribute_filters (List[Dict[str, List[Union[str, int]]]]): A list of dictionaries
              specifying attributes and their corresponding values for filtering nodes and
              edges.

        Returns:
        --------

            A tuple containing a list of filtered node properties and a list of filtered edge
            properties.
        """
        where_clauses = []
        params = {}

        for i, filter_dict in enumerate(attribute_filters):
            for attr, values in filter_dict.items():
                param_name = f"values_{i}_{attr}"
                where_clauses.append(f"n.{attr} IN ${param_name}")
                params[param_name] = values

        where_clause = " AND ".join(where_clauses)
        nodes_query = f"""
            MATCH (n:Node)
            WHERE {where_clause}
            RETURN n.id, {{
                name: n.name,
                type: n.type,
                properties: n.properties
            }}
        """
        edges_query = f"""
            MATCH (n1:Node)-[r:EDGE]->(n2:Node)
            WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}
            RETURN n1.id, n2.id, r.relationship_name, r.properties
        """
        nodes, edges = await asyncio.gather(
            self.query(nodes_query, params), self.query(edges_query, params)
        )
        formatted_nodes = []
        for n in nodes:
            if n[0]:
                node_id = str(n[0])
                props = n[1]
                if props.get("properties"):
                    try:
                        additional_props = json.loads(props["properties"])
                        props.update(additional_props)
                        del props["properties"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse properties JSON for node {node_id}")
                formatted_nodes.append((node_id, props))
        if not formatted_nodes:
            logger.warning("No nodes found in the database")
            return [], []

        formatted_edges = []
        for e in edges:
            if e and len(e) >= 3:
                source_id = str(e[0])
                target_id = str(e[1])
                rel_type = str(e[2])
                props = {}
                if len(e) > 3 and e[3]:
                    try:
                        props = json.loads(e[3])
                    except (json.JSONDecodeError, TypeError):
                        logger.warning(
                            f"Failed to parse edge properties for {source_id}->{target_id}"
                        )
                formatted_edges.append((source_id, target_id, rel_type, props))
        return formatted_nodes, formatted_edges

    async def get_id_filtered_graph_data(self, target_ids: list[str]):
        """
        Retrieve graph data filtered by specific node IDs, including their direct neighbors
        and only edges where one endpoint matches those IDs.

        Returns:
            nodes: List[dict] -> Each dict includes "id" and all node properties
            edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
        """
        import time

        start_time = time.time()

        try:
            if not target_ids:
                logger.warning("No target IDs provided for ID-filtered graph retrieval.")
                return [], []

            if not all(isinstance(x, str) for x in target_ids):
                raise CogneeValidationError("target_ids must be a list of strings")

            query = """
                MATCH (n:Node)-[r]->(m:Node)
                WHERE n.id IN $target_ids OR m.id IN $target_ids
                RETURN n.id, {
                    name: n.name,
                    type: n.type,
                    properties: n.properties
                }, m.id, {
                    name: m.name,
                    type: m.type,
                    properties: m.properties
                }, r.relationship_name, r.properties
            """

            result = await self.query(query, {"target_ids": target_ids})

            if not result:
                logger.info("No data returned for the supplied IDs")
                return [], []

            nodes_dict = {}
            edges = []

            for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
                if n_props.get("properties"):
                    try:
                        additional_props = json.loads(n_props["properties"])
                        n_props.update(additional_props)
                        del n_props["properties"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse properties JSON for node {n_id}")

                if m_props.get("properties"):
                    try:
                        additional_props = json.loads(m_props["properties"])
                        m_props.update(additional_props)
                        del m_props["properties"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse properties JSON for node {m_id}")

                nodes_dict[n_id] = (n_id, n_props)
                nodes_dict[m_id] = (m_id, m_props)

                edge_props = {}
                if r_props_raw:
                    try:
                        edge_props = json.loads(r_props_raw)
                    except (json.JSONDecodeError, TypeError):
                        logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")

                source_id = edge_props.get("source_node_id", n_id)
                target_id = edge_props.get("target_node_id", m_id)
                edges.append((source_id, target_id, r_type, edge_props))

            retrieval_time = time.time() - start_time
            logger.info(
                f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
            )

            return list(nodes_dict.values()), edges

        except Exception as e:
            logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
            raise

    async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
        """
        Get metrics on graph structure and connectivity.

        This method computes various metrics around the graph, such as node and edge counts,
        mean degree, and connected component sizes. Optionally, it can include additional
        metrics based on user request.

        Parameters:
        -----------

            - include_optional: A boolean flag indicating whether to include optional metrics in
              the output. (default False)

        Returns:
        --------

            - Dict[str, Any]: A dictionary containing various metrics related to the graph.
        """

        try:
            # Get basic graph data
            nodes, edges = await self.get_model_independent_graph_data()
            num_nodes = len(nodes[0]["nodes"]) if nodes else 0
            num_edges = len(edges[0]["elements"]) if edges else 0

            # Calculate mandatory metrics
            mandatory_metrics = {
                "num_nodes": num_nodes,
                "num_edges": num_edges,
                "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
                "edge_density": num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
                "num_connected_components": await self._get_num_connected_components(),
                "sizes_of_connected_components": await self._get_size_of_connected_components(),
            }

            if include_optional:
                # Calculate optional metrics
                shortest_path_lengths = await self._get_shortest_path_lengths()
                optional_metrics = {
                    "num_selfloops": await self._count_self_loops(),
                    "diameter": max(shortest_path_lengths) if shortest_path_lengths else -1,
                    "avg_shortest_path_length": sum(shortest_path_lengths)
                    / len(shortest_path_lengths)
                    if shortest_path_lengths
                    else -1,
                    "avg_clustering": await self._get_avg_clustering(),
                }
            else:
                optional_metrics = {
                    "num_selfloops": -1,
                    "diameter": -1,
                    "avg_shortest_path_length": -1,
                    "avg_clustering": -1,
                }

            return {**mandatory_metrics, **optional_metrics}

        except Exception as e:
            logger.error(f"Failed to get graph metrics: {e}")
            return {
                "num_nodes": 0,
                "num_edges": 0,
                "mean_degree": 0,
                "edge_density": 0,
                "num_connected_components": 0,
                "sizes_of_connected_components": [],
                "num_selfloops": -1,
                "diameter": -1,
                "avg_shortest_path_length": -1,
                "avg_clustering": -1,
            }

    async def _get_num_connected_components(self) -> int:
        """Get the number of connected components in the graph."""
        query = """
            MATCH (n:Node)
            WITH n.id AS node_id
            MATCH path = (n)-[:EDGE*1..3]-(m)
            WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
            WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
            RETURN SIZE(components) AS num_components
        """
        result = await self.query(query)
        return result[0][0] if result else 0

    async def _get_size_of_connected_components(self) -> List[int]:
        """Get the sizes of all connected components in the graph."""
        query = """
            MATCH (n:Node)
            WITH n.id AS node_id
            MATCH path = (n)-[:EDGE*1..3]-(m)
            WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
            WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
            UNWIND components AS component
            RETURN SIZE(component) AS component_size
        """
        result = await self.query(query)
        return [row[0] for row in result] if result else []

    async def _get_shortest_path_lengths(self) -> List[int]:
        """Get the lengths of shortest paths between all pairs of nodes."""
        query = """
            MATCH (n:Node), (m:Node)
            WHERE n.id < m.id
            MATCH path = (n)-[:EDGE*]-(m)
            RETURN MIN(LENGTH(path)) AS length
        """
        result = await self.query(query)
        return [row[0] for row in result if row[0] is not None] if result else []

    async def _count_self_loops(self) -> int:
        """Count the number of self-loops in the graph."""
        query = """
            MATCH (n:Node)-[r:EDGE]->(n)
            RETURN COUNT(r) AS count
        """
        result = await self.query(query)
        return result[0][0] if result else 0

    async def _get_avg_clustering(self) -> float:
        """Calculate the average clustering coefficient of the graph."""
        query = """
            MATCH (n:Node)-[:EDGE]-(neighbor)
            WITH n, COUNT(DISTINCT neighbor) as degree
            MATCH (n)-[:EDGE]-(n1)-[:EDGE]-(n2)-[:EDGE]-(n)
            WHERE n1 <> n2
            RETURN AVG(CASE WHEN degree <= 1 THEN 0 ELSE COUNT(DISTINCT n2) / (degree * (degree-1)) END) AS avg_clustering
        """
        result = await self.query(query)
        return result[0][0] if result and result[0][0] is not None else -1

    async def get_disconnected_nodes(self) -> List[str]:
        """
        Get nodes that are not connected to any other node.

        This method retrieves identifiers of nodes that lack any relationships in the graph,
        indicating they are standalone. It will return an empty list if no disconnected nodes
        exist.

        Returns:
        --------

            - List[str]: A list of identifiers for disconnected nodes.
        """
        query_str = """
            MATCH (n:Node)
            WHERE NOT EXISTS((n)-[]-())
            RETURN n.id
        """
        result = await self.query(query_str)
        return [str(row[0]) for row in result]

# Graph Meta-Data Operations
|
|
|
|
async def get_model_independent_graph_data(self) -> Dict[str, List[str]]:
|
|
"""
|
|
Get graph data independent of any specific data model.
|
|
|
|
This method returns a representation of the graph that includes distinct node labels and
|
|
relationship types, making it easier to analyze the graph's structure without tying it
|
|
to a specific implementation.
|
|
|
|
Returns:
|
|
--------
|
|
|
|
- Dict[str, List[str]]: A dictionary summarizing the node labels and relationship
|
|
types present in the graph.
|
|
"""
|
|
node_labels = await self.query("MATCH (n:Node) RETURN DISTINCT labels(n)")
|
|
rel_types = await self.query("MATCH ()-[r:EDGE]->() RETURN DISTINCT r.relationship_name")
|
|
return {
|
|
"node_labels": [label[0] for label in node_labels],
|
|
"relationship_types": [rel[0] for rel in rel_types],
|
|
}
|
|
|
|
    async def delete_graph(self) -> None:
        """
        Delete all data from the graph database.

        This method deletes all nodes and relationships from the graph database.
        It raises exceptions for failures occurring during deletion processes.
        """
        try:
            # Close open handles before removing the database files on disk.
            if self.connection:
                self.connection.close()
                self.connection = None
            if self.db:
                self.db.close()
                self.db = None

            db_dir = os.path.dirname(self.db_path)
            db_name = os.path.basename(self.db_path)
            file_storage = get_file_storage(db_dir)

            if await file_storage.is_file(db_name):
                await file_storage.remove(db_name)
                await file_storage.remove(f"{db_name}.lock")
            else:
                await file_storage.remove_all(db_name)

            logger.info(f"Deleted Kuzu database files at {self.db_path}")

        except Exception as e:
            logger.error(f"Failed to delete graph data: {e}")
            raise

    async def get_document_subgraph(self, data_id: str):
        """
        Get all nodes that should be deleted when removing a document.

        This method runs a single query that identifies the document, its chunks, and any
        entities, summaries, and entity types that would become orphaned once the document
        is removed, and returns them as a dictionary.

        Parameters:
        -----------

        - data_id (str): The identifier for the document to query against.

        Returns:
        --------

        A dictionary containing details of the document and associated nodes that need to be
        deleted, or None if no related nodes are found.
        """
        query = """
            MATCH (doc:Node)
            WHERE (doc.type = 'TextDocument' OR doc.type = 'PdfDocument' OR doc.type = 'AudioDocument' OR doc.type = 'ImageDocument' OR doc.type = 'UnstructuredDocument') AND doc.id = $data_id

            OPTIONAL MATCH (doc)<-[e1:EDGE]-(chunk:Node)
            WHERE e1.relationship_name = 'is_part_of' AND chunk.type = 'DocumentChunk'

            OPTIONAL MATCH (chunk)-[e2:EDGE]->(entity:Node)
            WHERE e2.relationship_name = 'contains' AND entity.type = 'Entity'
            AND NOT EXISTS {
                MATCH (entity)<-[e3:EDGE]-(otherChunk:Node)-[e4:EDGE]->(otherDoc:Node)
                WHERE e3.relationship_name = 'contains'
                AND e4.relationship_name = 'is_part_of'
                AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument' OR otherDoc.type = 'AudioDocument' OR otherDoc.type = 'ImageDocument' OR otherDoc.type = 'UnstructuredDocument')
                AND otherDoc.id <> doc.id
            }

            OPTIONAL MATCH (chunk)<-[e5:EDGE]-(made_node:Node)
            WHERE e5.relationship_name = 'made_from' AND made_node.type = 'TextSummary'

            OPTIONAL MATCH (entity)-[e6:EDGE]->(type:Node)
            WHERE e6.relationship_name = 'is_a' AND type.type = 'EntityType'
            AND NOT EXISTS {
                MATCH (type)<-[e7:EDGE]-(otherEntity:Node)-[e8:EDGE]-(otherChunk:Node)-[e9:EDGE]-(otherDoc:Node)
                WHERE e7.relationship_name = 'is_a'
                AND e8.relationship_name = 'contains'
                AND e9.relationship_name = 'is_part_of'
                AND otherEntity.type = 'Entity'
                AND otherChunk.type = 'DocumentChunk'
                AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument' OR otherDoc.type = 'AudioDocument' OR otherDoc.type = 'ImageDocument' OR otherDoc.type = 'UnstructuredDocument')
                AND otherDoc.id <> doc.id
            }

            RETURN
                COLLECT(DISTINCT doc) as document,
                COLLECT(DISTINCT chunk) as chunks,
                COLLECT(DISTINCT entity) as orphan_entities,
                COLLECT(DISTINCT made_node) as made_from_nodes,
                COLLECT(DISTINCT type) as orphan_types
        """
        result = await self.query(query, {"data_id": str(data_id)})
        if not result or not result[0]:
            return None

        # Convert tuple to dictionary
        return {
            "document": result[0][0],
            "chunks": result[0][1],
            "orphan_entities": result[0][2],
            "made_from_nodes": result[0][3],
            "orphan_types": result[0][4],
        }

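    # Example (hedged): a document-deletion flow might consume the subgraph like
    # this. Illustrative only; `graph` and the per-node delete step are
    # assumptions, not part of this adapter:
    #
    #     subgraph = await graph.get_document_subgraph(document_id)
    #     if subgraph:
    #         for chunk in subgraph["chunks"]:
    #             ...  # remove each chunk, then orphan_entities, made_from_nodes,
    #                  # and orphan_types, and finally the document itself
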
    async def get_degree_one_nodes(self, node_type: str):
        """
        Get all nodes that have only one connection.

        This method retrieves nodes of the given type that are connected to exactly one
        other node. It raises a ValueError if node_type is not one of the supported types.

        Parameters:
        -----------

        - node_type (str): The type of nodes to filter by, must be 'Entity' or 'EntityType'.

        Returns:
        --------

        A list of nodes of the specified type that have exactly one connection.
        """
        if not node_type or node_type not in ["Entity", "EntityType"]:
            raise ValueError("node_type must be either 'Entity' or 'EntityType'")

        query = f"""
            MATCH (n:Node)
            WHERE n.type = '{node_type}'
            WITH n, COUNT {{ MATCH (n)--() }} as degree
            WHERE degree = 1
            RETURN n
        """
        result = await self.query(query)
        return [record[0] for record in result] if result else []

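    # Design note (hedged): node_type is interpolated via an f-string, which is
    # safe here only because it is validated against a fixed whitelist above. If
    # parameter binding is preferred, an equivalent sketch would be:
    #
    #     query = """
    #         MATCH (n:Node)
    #         WHERE n.type = $node_type
    #         WITH n, COUNT { MATCH (n)--() } as degree
    #         WHERE degree = 1
    #         RETURN n
    #     """
    #     result = await self.query(query, {"node_type": node_type})
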
    async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
        """
        Retrieve the IDs of the most recent CogneeUserInteraction nodes.

        Parameters:
        -----------

        - limit (int): The maximum number of interaction IDs to return.

        Returns:
        --------

        - List[str]: A list of interaction IDs, sorted by created_at descending.
        """
        query = """
            MATCH (n)
            WHERE n.type = 'CogneeUserInteraction'
            RETURN n.id as id
            ORDER BY n.created_at DESC
            LIMIT $limit
        """
        rows = await self.query(query, {"limit": limit})

        return [row[0] for row in rows] if rows else []

    async def apply_feedback_weight(
        self,
        node_ids: List[str],
        weight: float,
    ) -> None:
        """
        Increment `feedback_weight` inside r.properties JSON for edges where
        relationship_name = 'used_graph_element_to_answer'.
        """
        # Step 1: fetch matching edges
        query = """
        MATCH (n:Node)-[r:EDGE]->()
        WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
        RETURN r.properties, n.id
        """
        results = await self.query(query, {"node_ids": node_ids})

        # Step 2: update the JSON client-side, since the properties column is
        # stored as a serialized string rather than a native map
        updates = []
        for props_json, source_id in results or []:
            try:
                props = json.loads(props_json) if props_json else {}
            except json.JSONDecodeError:
                props = {}

            props["feedback_weight"] = props.get("feedback_weight", 0) + weight
            updates.append((source_id, json.dumps(props)))

        # Step 3: write back, one source node at a time
        for node_id, new_props in updates:
            update_query = """
            MATCH (n:Node)-[r:EDGE]->()
            WHERE n.id = $node_id AND r.relationship_name = 'used_graph_element_to_answer'
            SET r.properties = $props
            """
            await self.query(update_query, {"node_id": node_id, "props": new_props})

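    # Example (hedged): a feedback loop might boost the graph elements used by
    # the most recent interactions. Illustrative only; `graph` stands for an
    # initialized adapter instance:
    #
    #     interaction_ids = await graph.get_last_user_interaction_ids(limit=5)
    #     await graph.apply_feedback_weight(node_ids=interaction_ids, weight=1.0)
    #
    # Repeated calls accumulate, since the weight is added onto any existing
    # `feedback_weight` value in the edge's properties JSON.
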
    async def collect_events(self, ids: List[str]) -> Any:
        """
        Collect all Event-type nodes reachable within 1..2 hops
        from the given node IDs.

        Args:
            ids: List of node IDs (strings), or a pre-quoted ID string as
                 produced by collect_time_ids

        Returns:
            List of events
        """
        event_collection_cypher = """UNWIND [{quoted}] AS uid
            MATCH (start {{id: uid}})
            MATCH (start)-[*1..2]-(event)
            WHERE event.type = 'Event'
            WITH DISTINCT event
            RETURN collect(event) AS events;
        """

        # Quote the IDs for the UNWIND list; formatting a Python list directly
        # would produce a nested list literal inside the Cypher text.
        quoted = ids if isinstance(ids, str) else ", ".join(f"'{uid}'" for uid in ids)
        query = event_collection_cypher.format(quoted=quoted)
        result = await self.query(query)

        events = []
        if not result or not result[0] or not result[0][0]:
            return [{"events": events}]

        for node in result[0][0]:
            props = json.loads(node["properties"])

            event = {
                "id": node["id"],
                "name": node["name"],
                "description": props.get("description"),
            }

            if props.get("location"):
                event["location"] = props["location"]

            events.append(event)

        return [{"events": events}]

    async def collect_time_ids(
        self,
        time_from: Optional[Timestamp] = None,
        time_to: Optional[Timestamp] = None,
    ) -> str:
        """
        Collect IDs of Timestamp nodes between time_from and time_to.

        Args:
            time_from: Lower bound (inclusive), optional
            time_to: Upper bound (inclusive), optional

        Returns:
            A string of quoted IDs: "'id1', 'id2', 'id3'"
            (ready for use in a Cypher UNWIND clause), or an empty string
            when neither bound is given.
        """
        if time_from is None and time_to is None:
            return ""

        # Build the time-range predicates once; the JSON extraction logic is
        # identical for every combination of bounds.
        conditions: List[str] = []
        if time_from is not None:
            conditions.append(f"t >= {date_to_int(time_from)}")
        if time_to is not None:
            conditions.append(f"t <= {date_to_int(time_to)}")

        cypher = f"""
            MATCH (n:Node)
            WHERE n.type = 'Timestamp'
            // Extract time_at from the JSON string and cast to INT64
            WITH n, json_extract(n.properties, '$.time_at') AS t_str
            WITH n,
                 CASE
                     WHEN t_str IS NULL OR t_str = '' THEN NULL
                     ELSE CAST(t_str AS INT64)
                 END AS t
            WHERE {" AND ".join(conditions)}
            RETURN n.id as id
        """

        time_nodes = await self.query(cypher)
        time_ids_list = [item[0] for item in time_nodes] if time_nodes else []

        return ", ".join(f"'{uid}'" for uid in time_ids_list)

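    # Example (hedged): the two helpers are designed to compose; the quoted ID
    # string from collect_time_ids feeds the UNWIND list in collect_events.
    # Illustrative only; `graph`, `start`, and `end` are assumed names:
    #
    #     time_ids = await graph.collect_time_ids(time_from=start, time_to=end)
    #     if time_ids:
    #         events = await graph.collect_events(time_ids)
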
    async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
        """
        Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.

        Parameters:
        -----------

        - offset (int): Number of triplets to skip before returning results.
        - limit (int): Maximum number of triplets to return.

        Returns:
        --------

        - list[dict[str, Any]]: A list of triplets, where each triplet is a dictionary
          with keys: 'start_node', 'relationship_properties', 'end_node'.

        Raises:
        -------

        - ValueError: If offset or limit are negative.
        - Exception: Re-raises any exceptions from query execution.
        """
        if offset < 0:
            raise ValueError(f"Offset must be non-negative, got {offset}")
        if limit < 0:
            raise ValueError(f"Limit must be non-negative, got {limit}")

        # Note: without an ORDER BY, the SKIP/LIMIT pagination order is
        # engine-dependent and not guaranteed to be stable across queries.
        query = """
            MATCH (start_node:Node)-[relationship:EDGE]->(end_node:Node)
            RETURN {
                start_node: {
                    id: start_node.id,
                    name: start_node.name,
                    type: start_node.type,
                    properties: start_node.properties
                },
                relationship_properties: {
                    relationship_name: relationship.relationship_name,
                    properties: relationship.properties
                },
                end_node: {
                    id: end_node.id,
                    name: end_node.name,
                    type: end_node.type,
                    properties: end_node.properties
                }
            } AS triplet
            SKIP $offset LIMIT $limit
        """

        try:
            results = await self.query(query, {"offset": offset, "limit": limit})
        except Exception as e:
            logger.error(f"Failed to execute triplet query: {str(e)}")
            logger.error(f"Query: {query}")
            logger.error(f"Parameters: offset={offset}, limit={limit}")
            raise

        triplets = []
        for idx, row in enumerate(results):
            try:
                if not row or len(row) == 0:
                    logger.warning(f"Skipping empty row at index {idx} in triplet batch")
                    continue

                if not isinstance(row[0], dict):
                    logger.warning(
                        f"Skipping invalid row at index {idx}: expected dict, got {type(row[0])}"
                    )
                    continue

                triplet = row[0]

                if "start_node" not in triplet:
                    logger.warning(f"Skipping triplet at index {idx}: missing 'start_node' key")
                    continue

                if not isinstance(triplet["start_node"], dict):
                    logger.warning(f"Skipping triplet at index {idx}: 'start_node' is not a dict")
                    continue

                triplet["start_node"] = self._parse_node_properties(triplet["start_node"].copy())

                if "relationship_properties" not in triplet:
                    logger.warning(
                        f"Skipping triplet at index {idx}: missing 'relationship_properties' key"
                    )
                    continue

                if not isinstance(triplet["relationship_properties"], dict):
                    logger.warning(
                        f"Skipping triplet at index {idx}: 'relationship_properties' is not a dict"
                    )
                    continue

                rel_props = triplet["relationship_properties"].copy()
                relationship_name = rel_props.get("relationship_name") or ""

                # Flatten the serialized properties JSON into the relationship dict.
                if rel_props.get("properties"):
                    try:
                        parsed_props = json.loads(rel_props["properties"])
                        if isinstance(parsed_props, dict):
                            rel_props.update(parsed_props)
                            del rel_props["properties"]
                        else:
                            logger.warning(
                                f"Parsed relationship properties is not a dict for triplet at index {idx}"
                            )
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.warning(
                            f"Failed to parse relationship properties JSON for triplet at index {idx}: {e}"
                        )

                rel_props["relationship_name"] = relationship_name
                triplet["relationship_properties"] = rel_props

                if "end_node" not in triplet:
                    logger.warning(f"Skipping triplet at index {idx}: missing 'end_node' key")
                    continue

                if not isinstance(triplet["end_node"], dict):
                    logger.warning(f"Skipping triplet at index {idx}: 'end_node' is not a dict")
                    continue

                triplet["end_node"] = self._parse_node_properties(triplet["end_node"].copy())

                triplets.append(triplet)

            except Exception as e:
                logger.error(f"Error processing triplet at index {idx}: {e}", exc_info=True)
                continue

        return triplets

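    # Example (hedged): callers can page through every triplet in the graph by
    # advancing the offset until an empty batch comes back. Illustrative only;
    # `graph` stands for an initialized adapter instance, and because the query
    # has no ORDER BY, ordering across batches may not be stable:
    #
    #     offset, batch_size = 0, 500
    #     while True:
    #         batch = await graph.get_triplets_batch(offset, batch_size)
    #         if not batch:
    #             break
    #         for triplet in batch:
    #             ...  # e.g. embed triplet["start_node"], the relationship,
    #                  # and triplet["end_node"] as one combined text
    #         offset += batch_size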