diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 4607b93ca..4255c062e 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -135,6 +135,56 @@ def create_graph_engine( graph_database_password=graph_database_password or None, ) + elif graph_database_provider == "neptune": + try: + from langchain_aws import NeptuneAnalyticsGraph + except ImportError: + raise ImportError( + "langchain_aws is not installed. Please install it with 'pip install langchain_aws'" + ) + + if not graph_database_url: + raise EnvironmentError("Missing Neptune endpoint.") + + from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL + + if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL): + raise ValueError(f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}") + + graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "") + + return NeptuneGraphDB( + graph_id=graph_identifier, + ) + + elif graph_database_provider == "neptune_analytics": + """ + Creates a graph DB from config + We want to use a hybrid (graph & vector) DB and we should update this + to make a single instance of the hybrid configuration (with embedder) + instead of creating the hybrid object twice. + """ + try: + from langchain_aws import NeptuneAnalyticsGraph + except ImportError: + raise ImportError( + "langchain_aws is not installed. Please install it with 'pip install langchain_aws'" + ) + + if not graph_database_url: + raise EnvironmentError("Missing Neptune endpoint.") + + from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL + + if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL): + raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}'") + + graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "") + + return NeptuneAnalyticsAdapter( + graph_id=graph_identifier, + ) + from .networkx.adapter import NetworkXAdapter graph_client = NetworkXAdapter(filename=graph_file_path) diff --git a/cognee/infrastructure/databases/graph/neptune_driver/__init__.py b/cognee/infrastructure/databases/graph/neptune_driver/__init__.py new file mode 100644 index 000000000..184363988 --- /dev/null +++ b/cognee/infrastructure/databases/graph/neptune_driver/__init__.py @@ -0,0 +1,15 @@ +"""Neptune Analytics Driver Module + +This module provides the Neptune Analytics adapter and utilities for interacting +with Amazon Neptune Analytics graph databases. +""" + +from .adapter import NeptuneGraphDB +from . import neptune_utils +from . 
import exceptions + +__all__ = [ + "NeptuneGraphDB", + "neptune_utils", + "exceptions", +] diff --git a/cognee/infrastructure/databases/graph/neptune_driver/adapter.py b/cognee/infrastructure/databases/graph/neptune_driver/adapter.py new file mode 100644 index 000000000..362c01584 --- /dev/null +++ b/cognee/infrastructure/databases/graph/neptune_driver/adapter.py @@ -0,0 +1,1445 @@ +"""Neptune Analytics Adapter for Graph Database""" + +import json +from typing import Optional, Any, List, Dict, Type, Tuple +from uuid import UUID +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.databases.graph.graph_db_interface import ( + GraphDBInterface, + record_graph_changes, + NodeData, + EdgeData, + Node, +) +from cognee.modules.storage.utils import JSONEncoder +from cognee.infrastructure.engine import DataPoint + +from .exceptions import ( + NeptuneAnalyticsConfigurationError, +) +from .neptune_utils import ( + validate_graph_id, + validate_aws_region, + build_neptune_config, + format_neptune_error, +) + +logger = get_logger("NeptuneGraphDB") + +try: + from langchain_aws import NeptuneAnalyticsGraph + LANGCHAIN_AWS_AVAILABLE = True +except ImportError: + logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.") + LANGCHAIN_AWS_AVAILABLE = False + +NEPTUNE_ENDPOINT_URL = "neptune-graph://" + +class NeptuneGraphDB(GraphDBInterface): + """ + Adapter for interacting with Amazon Neptune Analytics graph store. + This class provides methods for querying, adding, deleting nodes and edges using the aws_langchain library. + """ + _GRAPH_NODE_LABEL = "COGNEE_NODE" + + def __init__( + self, + graph_id: str, + region: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): + """ + Initialize the Neptune Analytics adapter. + + Parameters: + ----------- + - graph_id (str): The Neptune Analytics graph identifier + - region (Optional[str]): AWS region where the graph is located (default: us-east-1) + - aws_access_key_id (Optional[str]): AWS access key ID + - aws_secret_access_key (Optional[str]): AWS secret access key + - aws_session_token (Optional[str]): AWS session token for temporary credentials + + Raises: + ------- + - NeptuneAnalyticsConfigurationError: If configuration parameters are invalid + """ + # validate import + if not LANGCHAIN_AWS_AVAILABLE: + raise ImportError("langchain_aws is not available. 
Please install it to use Neptune Analytics.") + + # Validate configuration + if not validate_graph_id(graph_id): + raise NeptuneAnalyticsConfigurationError(f"Invalid graph ID: \"{graph_id}\"") + + if region and not validate_aws_region(region): + raise NeptuneAnalyticsConfigurationError(f"Invalid AWS region: \"{region}\"") + + self.graph_id = graph_id + self.region = region + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + + # Build configuration + self.config = build_neptune_config( + graph_id=self.graph_id, + region=self.region, + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + ) + + # Initialize Neptune Analytics client using langchain_aws + self._client: NeptuneAnalyticsGraph = self._initialize_client() + logger.info(f"Initialized Neptune Analytics adapter for graph: \"{graph_id}\" in region: \"{self.region}\"") + + def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]: + """ + Initialize the Neptune Analytics client using langchain_aws. + + Returns: + -------- + - Optional[Any]: The Neptune Analytics client or None if not available + """ + try: + # Initialize the Neptune Analytics Graph client + client_config = { + "graph_identifier": self.graph_id, + } + # Add AWS credentials if provided + if self.region: + client_config["region_name"] = self.region + if self.aws_access_key_id: + client_config["aws_access_key_id"] = self.aws_access_key_id + if self.aws_secret_access_key: + client_config["aws_secret_access_key"] = self.aws_secret_access_key + if self.aws_session_token: + client_config["aws_session_token"] = self.aws_session_token + + client = NeptuneAnalyticsGraph(**client_config) + logger.info("Successfully initialized Neptune Analytics client") + return client + + except Exception as e: + raise NeptuneAnalyticsConfigurationError(f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}") + + @staticmethod + def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]: + """ + Serialize properties for Neptune Analytics storage. + Parameters: + ----------- + - properties (Dict[str, Any]): Properties to serialize. + Returns: + -------- + - Dict[str, Any]: Serialized properties. + """ + serialized_properties = {} + + for property_key, property_value in properties.items(): + if isinstance(property_value, UUID): + serialized_properties[property_key] = str(property_value) + continue + + if isinstance(property_value, dict) or isinstance(property_value, list): + serialized_properties[property_key] = json.dumps(property_value, cls=JSONEncoder) + continue + + serialized_properties[property_key] = property_value + + return serialized_properties + + async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Any]: + """ + Execute a query against the Neptune Analytics database and return the results. + + Parameters: + ----------- + - query (str): The query string to execute against the database. + - params (Optional[Dict[str, Any]]): A dictionary of parameters to be used in the query. + + Returns: + -------- + - List[Any]: A list of results from the query execution. 
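+
+        Example (illustrative only; assumes an initialized adapter instance named `adapter`
+        and placeholder parameter values):
+
+            rows = await adapter.query(
+                "MATCH (n) WHERE id(n) = $node_id RETURN n",
+                {"node_id": "some-node-id"},
+            )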
+ """ + try: + # Execute the query using the Neptune Analytics client + # The langchain_aws NeptuneAnalyticsGraph supports openCypher queries + if params is None: + params = {} + logger.debug(f"executing na query:\nquery={query}\n") + result = self._client.query(query, params) + + # Convert the result to list format expected by the interface + if isinstance(result, list): + return result + elif isinstance(result, dict): + return [result] + else: + return [{"result": result}] + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Neptune Analytics query failed: {error_msg}") + raise Exception(f"Query execution failed: {error_msg}") + + async def add_node(self, node: DataPoint) -> None: + """ + Add a single node with specified properties to the graph. + + Parameters: + ----------- + - node_id (str): Unique identifier for the node being added. + - properties (Dict[str, Any]): A dictionary of properties associated with the node. + """ + try: + # Prepare node properties with the ID and graph type + serialized_properties = self._serialize_properties(node.model_dump()) + + query = f""" + MERGE (n:{self._GRAPH_NODE_LABEL} {{`~id`: $node_id}}) + ON CREATE SET n = $properties, n.updated_at = timestamp() + ON MATCH SET n += $properties, n.updated_at = timestamp() + RETURN n + """ + + params = { + "node_id": str(node.id), + "properties": serialized_properties, + } + + result = await self.query(query, params) + logger.debug(f"Successfully added/updated node: {node.id}") + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to add node {node.id}: {error_msg}") + raise Exception(f"Failed to add node: {error_msg}") + + @record_graph_changes + async def add_nodes(self, nodes: List[DataPoint]) -> None: + """ + Add multiple nodes to the graph in a single operation. + + Parameters: + ----------- + - nodes (List[DataPoint]): A list of DataPoint objects to be added to the graph. + """ + if not nodes: + logger.debug("No nodes to add") + return + + try: + # Build bulk node creation query using UNWIND + query = f""" + UNWIND $nodes AS node + MERGE (n:{self._GRAPH_NODE_LABEL} {{`~id`: node.node_id}}) + ON CREATE SET n = node.properties, n.updated_at = timestamp() + ON MATCH SET n += node.properties, n.updated_at = timestamp() + RETURN count(n) AS nodes_processed + """ + + # Prepare node data for bulk operation + params = { + "nodes": [ + { + "node_id": str(node.id), + "properties": self._serialize_properties(node.model_dump()), + } + for node in nodes + ] + } + result = await self.query(query, params) + + processed_count = result[0].get('nodes_processed', 0) if result else 0 + logger.debug(f"Successfully processed {processed_count} nodes in bulk operation") + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to add nodes in bulk: {error_msg}") + # Fallback to individual node creation + logger.info("Falling back to individual node creation") + for node in nodes: + try: + await self.add_node(node) + except Exception as node_error: + logger.error(f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}") + continue + + async def delete_node(self, node_id: str) -> None: + """ + Delete a specified node from the graph by its ID. + + Parameters: + ----------- + - node_id (str): Unique identifier for the node to delete. 
+ """ + try: + # Build openCypher query to delete the node and all its relationships + query = f""" + MATCH (n:{self._GRAPH_NODE_LABEL}) + WHERE id(n) = $node_id + DETACH DELETE n + """ + + params = { + "node_id": node_id + } + + await self.query(query, params) + logger.debug(f"Successfully deleted node: {node_id}") + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to delete node {node_id}: {error_msg}") + raise Exception(f"Failed to delete node: {error_msg}") + + async def delete_nodes(self, node_ids: List[str]) -> None: + """ + Delete multiple nodes from the graph by their identifiers. + + Parameters: + ----------- + - node_ids (List[str]): A list of unique identifiers for the nodes to delete. + """ + if not node_ids: + logger.debug("No nodes to delete") + return + + try: + # Build bulk node deletion query using UNWIND + query = f""" + UNWIND $node_ids AS node_id + MATCH (n:{self._GRAPH_NODE_LABEL}) + WHERE id(n) = node_id + DETACH DELETE n + """ + + params = {"node_ids": node_ids} + await self.query(query, params) + logger.debug(f"Successfully deleted {len(node_ids)} nodes in bulk operation") + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to delete nodes in bulk: {error_msg}") + # Fallback to individual node deletion + logger.info("Falling back to individual node deletion") + for node_id in node_ids: + try: + await self.delete_node(node_id) + except Exception as node_error: + logger.error(f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}") + continue + + async def get_node(self, node_id: str) -> Optional[NodeData]: + """ + Retrieve a single node from the graph using its ID. + + Parameters: + ----------- + - node_id (str): Unique identifier of the node to retrieve. + + Returns: + -------- + - Optional[NodeData]: The node data if found, None otherwise. + """ + try: + # Build openCypher query to retrieve the node + query = f""" + MATCH (n:{self._GRAPH_NODE_LABEL}) + WHERE id(n) = $node_id + RETURN n + """ + params = {'node_id': node_id} + + result = await self.query(query, params) + + if result and len(result) == 1: + # Extract node properties from the result + logger.debug(f"Successfully retrieved node: {node_id}") + return result[0]["n"] + else: + if not result: + logger.debug(f"Node not found: {node_id}") + elif len(result) > 1: + logger.debug(f"Only one node expected, multiple returned: {node_id}") + return None + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to get node {node_id}: {error_msg}") + raise Exception(f"Failed to get node: {error_msg}") + + async def get_nodes(self, node_ids: List[str]) -> List[NodeData]: + """ + Retrieve multiple nodes from the graph using their IDs. + + Parameters: + ----------- + - node_ids (List[str]): A list of unique identifiers for the nodes to retrieve. + + Returns: + -------- + - List[NodeData]: A list of node data for the found nodes. 
+ """ + if not node_ids: + logger.debug("No node IDs provided") + return [] + + try: + # Build bulk node-retrieval OpenCypher query using UNWIND + query = f""" + UNWIND $node_ids AS node_id + MATCH (n:{self._GRAPH_NODE_LABEL}) + WHERE id(n) = node_id + RETURN n + """ + + params = {"node_ids": node_ids} + result = await self.query(query, params) + + # Extract node data from results + nodes = [record["n"] for record in result] + + logger.debug(f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested") + return nodes + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to get nodes in bulk: {error_msg}") + # Fallback to individual node retrieval + logger.info("Falling back to individual node retrieval") + nodes = [] + for node_id in node_ids: + try: + node_data = await self.get_node(node_id) + if node_data: + nodes.append(node_data) + except Exception as node_error: + logger.error(f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}") + continue + return nodes + + + async def extract_node(self, node_id: str): + """ + Retrieve a single node based on its ID. + + Parameters: + ----------- + + - node_id (str): The ID of the node to retrieve. + + Returns: + -------- + + - Optional[Dict[str, Any]]: The requested node as a dictionary, or None if it does + not exist. + """ + results = await self.extract_nodes([node_id]) + + return results[0] if len(results) > 0 else None + + async def extract_nodes(self, node_ids: List[str]): + """ + Retrieve multiple nodes from the database by their IDs. + + Parameters: + ----------- + + - node_ids (List[str]): A list of IDs for the nodes to retrieve. + + Returns: + -------- + + A list of nodes represented as dictionaries. + """ + query = """ + UNWIND $node_ids AS id + MATCH (node) WHERE id(node) = id + RETURN node""" + + params = {"node_ids": node_ids} + + results = await self.query(query, params) + + return [result["node"] for result in results] + + async def add_edge( + self, + source_id: str, + target_id: str, + relationship_name: str, + properties: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Create a new edge between two nodes in the graph. + + Parameters: + ----------- + - source_id (str): The unique identifier of the source node. + - target_id (str): The unique identifier of the target node. + - relationship_name (str): The name of the relationship to be established by the edge. + - properties (Optional[Dict[str, Any]]): Optional dictionary of properties associated with the edge. 
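+
+        Example (illustrative only; node IDs and properties are placeholders):
+
+            await adapter.add_edge(
+                source_id="node-a",
+                target_id="node-b",
+                relationship_name="relates_to",
+                properties={"weight": 1},
+            )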
+ """ + try: + # Build openCypher query to create the edge + # First ensure both nodes exist, then create the relationship + + # Prepare edge properties + edge_props = properties or {} + serialized_properties = self._serialize_properties(edge_props) + + query = f""" + MATCH (source:{self._GRAPH_NODE_LABEL}) + WHERE id(source) = $source_id + MATCH (target:{self._GRAPH_NODE_LABEL}) + WHERE id(target) = $target_id + MERGE (source)-[r:{relationship_name}]->(target) + ON CREATE SET r = $properties, r.updated_at = timestamp() + ON MATCH SET r = $properties, r.updated_at = timestamp() + RETURN r + """ + + params = { + "source_id": source_id, + "target_id": target_id, + "properties": serialized_properties, + } + await self.query(query, params) + logger.debug(f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}") + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}") + raise Exception(f"Failed to add edge: {error_msg}") + + @record_graph_changes + async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None: + """ + Add multiple edges to the graph in a single operation. + + Parameters: + ----------- + - edges (List[EdgeData]): A list of EdgeData objects representing edges to be added. + """ + if not edges: + logger.debug("No edges to add") + return + + edges_by_relationship: dict[str, list] = {} + for edge in edges: + relationship_name = edge[2] + if edges_by_relationship.get(relationship_name, None): + edges_by_relationship[relationship_name].append(edge) + else: + edges_by_relationship[relationship_name] = [edge] + + results = {} + for relationship_name, edges_by_relationship in edges_by_relationship.items(): + try: + # Create the bulk-edge OpenCypher query using UNWIND + query = f""" + UNWIND $edges AS edge + MATCH (source:{self._GRAPH_NODE_LABEL}) + WHERE id(source) = edge.from_node + MATCH (target:{self._GRAPH_NODE_LABEL}) + WHERE id(target) = edge.to_node + MERGE (source)-[r:{relationship_name}]->(target) + ON CREATE SET r = edge.properties, r.updated_at = timestamp() + ON MATCH SET r = edge.properties, r.updated_at = timestamp() + RETURN count(*) AS edges_processed + """ + + # Prepare edges data for bulk operation + params = {"edges": + [ + { + "from_node": str(edge[0]), + "to_node": str(edge[1]), + "relationship_name": relationship_name, + "properties": self._serialize_properties(edge[3] if len(edge) > 3 and edge[3] else {}), + } + for edge in edges_by_relationship + ] + } + results[relationship_name] = await self.query(query, params) + except Exception as e: + logger.error(f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}") + logger.info("Falling back to individual edge creation") + for edge in edges_by_relationship: + try: + source_id, target_id, relationship_name = edge[0], edge[1], edge[2] + properties = edge[3] if len(edge) > 3 else {} + await self.add_edge(source_id, target_id, relationship_name, properties) + except Exception as edge_error: + logger.error(f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}") + continue + + processed_count = 0 + for result in results.values(): + processed_count += result[0].get('edges_processed', 0) if result else 0 + logger.debug(f"Successfully processed {processed_count} edges in bulk operation") + + + async def delete_graph(self) -> None: + """ + Delete all nodes and edges from the graph database. 
+ + Returns: + -------- + The result of the query execution, typically indicating success or failure. + """ + try: + # Build openCypher query to delete the graph + query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n" + await self.query(query) + logger.debug(f"Successfully deleted all edges and nodes from the graph") + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to delete graph: {error_msg}") + raise Exception(f"Failed to delete graph: {error_msg}") + + + async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]: + """ + Retrieve all nodes and edges within the graph. + + Returns: + -------- + - Tuple[List[Node], List[EdgeData]]: A tuple containing all nodes and edges in the graph. + """ + try: + # Query to get all nodes + nodes_query = f""" + MATCH (n:{self._GRAPH_NODE_LABEL}) + RETURN id(n) AS node_id, properties(n) AS properties + """ + + # Query to get all edges + edges_query = f""" + MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL}) + RETURN id(source) AS source_id, id(target) AS target_id, type(r) AS relationship_name, properties(r) AS properties + """ + + # Execute both queries + nodes_result = await self.query(nodes_query) + edges_result = await self.query(edges_query) + + # Format nodes as (node_id, properties) tuples + nodes = [ + ( + result["node_id"], + result["properties"] + ) + for result in nodes_result + ] + + # Format edges as (source_id, target_id, relationship_name, properties) tuples + edges = [ + ( + result["source_id"], + result["target_id"], + result["relationship_name"], + result["properties"] + ) + for result in edges_result + ] + + logger.debug(f"Retrieved {len(nodes)} nodes and {len(edges)} edges from graph") + return (nodes, edges) + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to get graph data: {error_msg}") + raise Exception(f"Failed to get graph data: {error_msg}") + + async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]: + """ + Fetch metrics and statistics of the graph, possibly including optional details. + + Parameters: + ----------- + - include_optional (bool): Flag indicating whether to include optional metrics or not. + + Returns: + -------- + - Dict[str, Any]: A dictionary containing graph metrics and statistics. 
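+
+        Example of the returned dictionary (keys as produced below; values are illustrative):
+
+            {
+                "num_nodes": 42,
+                "num_edges": 57,
+                "mean_degree": 2.71,
+                "edge_density": 0.033,
+                "num_connected_components": 3,
+                "sizes_of_connected_components": [30, 8, 4],
+                "num_selfloops": -1,
+                "diameter": -1,
+                "avg_shortest_path_length": -1,
+                "avg_clustering": -1,
+            }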
+        """
+        num_nodes, num_edges = await self._get_model_independent_graph_data()
+        list_cluster_size, num_cluster = await self._get_connected_components_stat()
+
+        mandatory_metrics = {
+            "num_nodes": num_nodes,
+            "num_edges": num_edges,
+            "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
+            "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else None,
+            "num_connected_components": num_cluster,
+            "sizes_of_connected_components": list_cluster_size
+        }
+
+        optional_metrics = {
+            "num_selfloops": -1,
+            "diameter": -1,
+            "avg_shortest_path_length": -1,
+            "avg_clustering": -1,
+        }
+
+        if include_optional:
+            optional_metrics['num_selfloops'] = await self._count_self_loops()
+            # Unsupported due to long-running queries when computing the shortest path for each node in the graph:
+            # optional_metrics['diameter']
+            # optional_metrics['avg_shortest_path_length']
+            #
+            # Unsupported due to incompatible algorithm: localClusteringCoefficient
+            # optional_metrics['avg_clustering']
+
+        return mandatory_metrics | optional_metrics
+
+    async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
+        """
+        Verify if an edge exists between two specified nodes.
+
+        Parameters:
+        -----------
+        - source_id (str): Unique identifier of the source node.
+        - target_id (str): Unique identifier of the target node.
+        - relationship_name (str): Name of the relationship to verify.
+
+        Returns:
+        --------
+        - bool: True if the edge exists, False otherwise.
+        """
+        try:
+            # Build openCypher query to check if the edge exists
+            query = f"""
+                MATCH (source:{self._GRAPH_NODE_LABEL})-[r:{relationship_name}]->(target:{self._GRAPH_NODE_LABEL})
+                WHERE id(source) = $source_id AND id(target) = $target_id
+                RETURN COUNT(r) > 0 AS edge_exists
+            """
+
+            params = {
+                "source_id": source_id,
+                "target_id": target_id,
+            }
+
+            result = await self.query(query, params)
+
+            if result and len(result) > 0:
+                edge_exists = result.pop().get('edge_exists', False)
+                logger.debug(
+                    f"Edge existence check for "
+                    f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}"
+                )
+                return edge_exists
+            else:
+                return False
+
+        except Exception as e:
+            error_msg = format_neptune_error(e)
+            logger.error(f"Failed to check edge existence {source_id} -> {target_id}: {error_msg}")
+            return False
+
+    async def has_edges(self, edges: List[EdgeData]) -> List[bool]:
+        """
+        Determine the existence of multiple edges in the graph.
+
+        Parameters:
+        -----------
+        - edges (List[EdgeData]): A list of EdgeData objects to check for existence in the graph.
+
+        Returns:
+        --------
+        - List[bool]: A list of booleans, one per input edge, indicating whether that edge exists in the graph.
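+
+        Example (illustrative only; each edge is a (source_id, target_id, relationship_name, properties) tuple):
+
+            flags = await adapter.has_edges([
+                ("node-a", "node-b", "relates_to", {}),
+            ])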
+ """ + query = f""" + UNWIND $edges AS edge + MATCH (a:{self._GRAPH_NODE_LABEL})-[r]->(b:{self._GRAPH_NODE_LABEL}) + WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name + RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists + """ + + try: + params = { + "edges": [ + { + "from_node": str(edge[0]), + "to_node": str(edge[1]), + "relationship_name": edge[2], + } + for edge in edges + ], + } + + results = await self.query(query, params) + logger.debug(f"Found {len(results)} existing edges out of {len(edges)} checked") + return [result["edge_exists"] for result in results] + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to check edges existence: {error_msg}") + return [] + + async def get_edges(self, node_id: str) -> List[EdgeData]: + """ + Retrieve all edges that are connected to the specified node. + + Parameters: + ----------- + - node_id (str): Unique identifier of the node whose edges are to be retrieved. + + Returns: + -------- + - List[EdgeData]: A list of EdgeData objects representing edges connected to the node. + """ + try: + # Query to get all edges connected to the node (both incoming and outgoing) + query = f""" + MATCH (n:{self._GRAPH_NODE_LABEL})-[r]-(m:{self._GRAPH_NODE_LABEL}) + WHERE id(n) = $node_id + RETURN + id(n) AS source_id, + id(m) AS target_id, + type(r) AS relationship_name, + properties(r) AS properties + """ + + params = {"node_id": node_id} + result = await self.query(query, params) + + # Format edges as EdgeData tuples: (source_id, target_id, relationship_name, properties) + edges = [self._convert_relationship_to_edge(record) for record in result] + + logger.debug(f"Retrieved {len(edges)} edges for node: {node_id}") + return edges + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to get edges for node {node_id}: {error_msg}") + raise Exception(f"Failed to get edges: {error_msg}") + + async def get_disconnected_nodes(self) -> list[str]: + """ + Find and return nodes that are not connected to any other nodes in the graph. + + Returns: + -------- + + - list[str]: A list of IDs of disconnected nodes. + """ + query = f""" + MATCH(n :{self._GRAPH_NODE_LABEL}) + WHERE NOT (n)--() + RETURN COLLECT(ID(n)) as ids + """ + + results = await self.query(query) + return results[0]["ids"] if len(results) > 0 else [] + + async def get_predecessors(self, node_id: str, edge_label: str = "") -> list[str]: + """ + Retrieve the predecessor nodes of a specified node based on an optional edge label. + + Parameters: + ----------- + + - node_id (str): The ID of the node whose predecessors are to be retrieved. + - edge_label (str): Optional edge label to filter predecessors. (default None) + + Returns: + -------- + + - list[str]: A list of predecessor node IDs. + """ + + edge_label = f" :{edge_label}" if edge_label is not None else "" + query = f""" + MATCH (node)<-[r{edge_label}]-(predecessor) + WHERE node.id = $node_id + RETURN predecessor + """ + + results = await self.query( + query, + {"node_id": node_id} + ) + + return [result["predecessor"] for result in results] + + async def get_successors(self, node_id: str, edge_label: str = "") -> list[str]: + """ + Retrieve the successor nodes of a specified node based on an optional edge label. + + Parameters: + ----------- + + - node_id (str): The ID of the node whose successors are to be retrieved. 
+ - edge_label (str): Optional edge label to filter successors. (default None) + + Returns: + -------- + + - list[str]: A list of successor node IDs. + """ + + edge_label = f" :{edge_label}" if edge_label is not None else "" + query = f""" + MATCH (node)-[r {edge_label}]->(successor) + WHERE node.id = $node_id + RETURN successor + """ + + results = await self.query( + query, + {"node_id": node_id} + ) + + return [result["successor"] for result in results] + + + async def get_neighbors(self, node_id: str) -> List[NodeData]: + """ + Get all neighboring nodes connected to the specified node. + + Parameters: + ----------- + - node_id (str): Unique identifier of the node for which to retrieve neighbors. + + Returns: + -------- + - List[NodeData]: A list of NodeData objects representing neighboring nodes. + """ + try: + # Query to get all neighboring nodes (both incoming and outgoing connections) + query = f""" + MATCH (n:{self._GRAPH_NODE_LABEL})-[r]-(neighbor:{self._GRAPH_NODE_LABEL}) + WHERE id(n) = $node_id + RETURN DISTINCT id(neighbor) AS neighbor_id, properties(neighbor) AS properties + """ + + params = {"node_id": node_id} + result = await self.query(query, params) + + # Format neighbors as NodeData objects + neighbors = [ + { + "id": neighbor["neighbor_id"], + **neighbor["properties"] + } + for neighbor in result + ] + + logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}") + return neighbors + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to get neighbors for node {node_id}: {error_msg}") + raise Exception(f"Failed to get neighbors: {error_msg}") + + async def get_nodeset_subgraph( + self, + node_type: Type[Any], + node_name: List[str] + ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: + """ + Fetch a subgraph consisting of a specific set of nodes and their relationships. + + Parameters: + ----------- + - node_type (Type[Any]): The type of nodes to include in the subgraph. + - node_name (List[str]): A list of names of the nodes to include in the subgraph. + + Returns: + -------- + - Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: A tuple containing nodes and edges of the subgraph. 
+ """ + try: + # Query to get nodes by name and their connected subgraph + query = f""" + UNWIND $names AS wantedName + MATCH (n:{self._GRAPH_NODE_LABEL}) + WHERE n.name = wantedName AND n.type = $type + WITH collect(DISTINCT n) AS primary + UNWIND primary AS p + OPTIONAL MATCH (p)-[r]-(nbr:{self._GRAPH_NODE_LABEL}) + WITH primary, collect(DISTINCT nbr) AS nbrs, collect(DISTINCT r) AS rels + WITH primary + nbrs AS nodelist, rels + UNWIND nodelist AS node + WITH collect(DISTINCT node) AS nodes, rels + MATCH (a:{self._GRAPH_NODE_LABEL})-[r]-(b:{self._GRAPH_NODE_LABEL}) + WHERE a IN nodes AND b IN nodes + WITH nodes, collect(DISTINCT r) AS all_rels + RETURN + [n IN nodes | {{ + id: id(n), + properties: properties(n) + }}] AS rawNodes, + [r IN all_rels | {{ + source_id: id(startNode(r)), + target_id: id(endNode(r)), + type: type(r), + properties: properties(r) + }}] AS rawRels + """ + + params = { + "names": node_name, + "type": node_type.__name__ + } + + result = await self.query(query, params) + + if not result: + logger.debug(f"No subgraph found for node type {node_type} with names {node_name}") + return ([], []) + + raw_nodes = result[0]["rawNodes"] + raw_rels = result[0]["rawRels"] + + # Format nodes as (node_id, properties) tuples + nodes = [ + (n["id"], n["properties"]) + for n in raw_nodes + ] + + # Format edges as (source_id, target_id, relationship_name, properties) tuples + edges = [ + (r["source_id"], r["target_id"], r["type"], r["properties"]) + for r in raw_rels + ] + + logger.debug(f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}") + return (nodes, edges) + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to get nodeset subgraph for type {node_type}: {error_msg}") + raise Exception(f"Failed to get nodeset subgraph: {error_msg}") + + async def get_connections(self, node_id: UUID) -> list: + """ + Get all nodes connected to a specified node and their relationship details. + + Parameters: + ----------- + - node_id (str): Unique identifier of the node for which to retrieve connections. + + Returns: + -------- + - List[Tuple[NodeData, Dict[str, Any], NodeData]]: A list of tuples containing connected nodes and relationship details. 
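+
+        Example of a single returned connection (illustrative placeholder values):
+
+            (
+                {"id": "source-node-id", "name": "Source"},
+                {"relationship_name": "relates_to", "updated_at": 1700000000},
+                {"id": "target-node-id", "name": "Target"},
+            )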
+ """ + try: + # Query to get all connections (both incoming and outgoing) + query = f""" + MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL}) + WHERE id(source) = $node_id OR id(target) = $node_id + RETURN + id(source) AS source_id, + properties(source) AS source_props, + id(target) AS target_id, + properties(target) AS target_props, + type(r) AS relationship_name, + properties(r) AS relationship_props + """ + + params = {"node_id": str(node_id)} + result = await self.query(query, params) + + connections = [] + for record in result: + # Return as (source_node, relationship, target_node) + connections.append( + ( + { + "id": record["source_id"], + **record["source_props"] + }, + { + "relationship_name": record["relationship_name"], + **record["relationship_props"] + }, + { + "id": record["target_id"], + **record["target_props"] + } + ) + ) + + logger.debug(f"Retrieved {len(connections)} connections for node: {node_id}") + return connections + + except Exception as e: + error_msg = format_neptune_error(e) + logger.error(f"Failed to get connections for node {node_id}: {error_msg}") + raise Exception(f"Failed to get connections: {error_msg}") + + + async def remove_connection_to_predecessors_of( + self, node_ids: list[str], edge_label: str + ): + """ + Remove connections (edges) to all predecessors of specified nodes based on edge label. + + Parameters: + ----------- + + - node_ids (list[str]): A list of IDs of nodes from which connections are to be + removed. + - edge_label (str): The label of the edges to remove. + + """ + query = f""" + UNWIND $node_ids AS node_id + MATCH ({{`~id`: node_id}})-[r:{edge_label}]->(predecessor) + DELETE r; + """ + params = {"node_ids": node_ids} + await self.query(query, params) + + + async def remove_connection_to_successors_of( + self, node_ids: list[str], edge_label: str + ): + """ + Remove connections (edges) to all successors of specified nodes based on edge label. + + Parameters: + ----------- + + - node_ids (list[str]): A list of IDs of nodes from which connections are to be + removed. + - edge_label (str): The label of the edges to remove. + + """ + query = f""" + UNWIND $node_ids AS node_id + MATCH ({{`~id`: node_id}})<-[r:{edge_label}]-(successor) + DELETE r; + """ + params = {"node_ids": node_ids} + await self.query(query, params) + + + async def get_node_labels_string(self): + """ + Fetch all node labels from the database and return them as a formatted string. + + Returns: + -------- + + A formatted string of node labels. + + Raises: + ------- + ValueError: If no node labels are found in the database. + """ + node_labels_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels " + node_labels_result = await self.query(node_labels_query) + node_labels = node_labels_result[0]["labels"] if node_labels_result else [] + + if not node_labels: + raise ValueError("No node labels found in the database") + + return str(node_labels) + + async def get_relationship_labels_string(self): + """ + Fetch all relationship types from the database and return them as a formatted string. + + Returns: + -------- + + A formatted string of relationship types. 
+        """
+        relationship_types_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
+        relationship_types_result = await self.query(relationship_types_query)
+        relationship_types = (
+            relationship_types_result[0]["relationships"] if relationship_types_result else []
+        )
+
+        if not relationship_types:
+            raise ValueError("No relationship types found in the database.")
+
+        relationship_types_undirected_str = (
+            "{"
+            + ", ".join(f"{rel}" + ": {orientation: 'UNDIRECTED'}" for rel in relationship_types)
+            + "}"
+        )
+        return relationship_types_undirected_str
+
+    async def project_entire_graph(self, graph_name="myGraph"):
+        """
+        Project all node labels and relationship types into an in-memory graph using GDS.
+
+        Note: This method is currently a placeholder because GDS (Graph Data Science)
+        projection is not supported in Neptune Analytics.
+        """
+        pass
+
+    async def drop_graph(self, graph_name="myGraph"):
+        """
+        Drop an existing graph from the database based on its name.
+
+        Note: This method is currently a placeholder because GDS (Graph Data Science)
+        projection is not supported in Neptune Analytics.
+
+        Parameters:
+        -----------
+
+        - graph_name: The name of the graph to drop. (default 'myGraph')
+        """
+        pass
+
+    async def graph_exists(self, graph_name="myGraph"):
+        """
+        Check if a graph with a given name exists in the database.
+
+        Note: This method is currently a placeholder because GDS (Graph Data Science)
+        projection is not supported in Neptune Analytics.
+
+        Parameters:
+        -----------
+
+        - graph_name: The name of the graph to check for existence. (default 'myGraph')
+
+        Returns:
+        --------
+
+        True if the graph exists, otherwise False. This placeholder currently returns None.
+        """
+        pass
+
+    async def get_filtered_graph_data(self, attribute_filters: list[dict[str, list]]):
+        """
+        Fetch nodes and edges filtered by specific attribute criteria.
+
+        Parameters:
+        -----------
+
+        - attribute_filters: A list of dictionaries representing attributes and associated
+          values for filtering.
+
+        Returns:
+        --------
+
+        A tuple containing filtered nodes and edges based on the specified criteria.
+ """ + where_clauses_n = [] + where_clauses_m = [] + for attribute, values in attribute_filters[0].items(): + values_str = ", ".join( + f"'{value}'" if isinstance(value, str) else str(value) for value in values + ) + where_clauses_n.append(f"n.{attribute} IN [{values_str}]") + where_clauses_m.append(f"m.{attribute} IN [{values_str}]") + + node_where_clauses_n_str = " AND ".join(where_clauses_n) + node_where_clauses_m_str = " AND ".join(where_clauses_m) + edge_where_clause = f"{node_where_clauses_n_str} AND {node_where_clauses_m_str}" + + query_nodes = f""" + MATCH (n :{self._GRAPH_NODE_LABEL}) + WHERE {node_where_clauses_n_str} + RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties + """ + result_nodes = await self.query(query_nodes) + + nodes = [ + ( + record["id"], + record["properties"], + ) + for record in result_nodes + ] + + query_edges = f""" + MATCH (n :{self._GRAPH_NODE_LABEL})-[r]->(m :{self._GRAPH_NODE_LABEL}) + WHERE {edge_where_clause} + RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties + """ + result_edges = await self.query(query_edges) + + edges = [ + ( + record["source"], + record["target"], + record["type"], + record["properties"], + ) + for record in result_edges + ] + + return (nodes, edges) + + + async def get_degree_one_nodes(self, node_type: str): + """ + Fetch nodes of a specified type that have exactly one connection. + + Parameters: + ----------- + + - node_type (str): The type of nodes to retrieve, must be 'Entity' or 'EntityType'. + + Returns: + -------- + + A list of nodes with exactly one connection of the specified type. + """ + if not node_type or node_type not in ["Entity", "EntityType"]: + raise ValueError("node_type must be either 'Entity' or 'EntityType'") + + query = f""" + MATCH (n :{self._GRAPH_NODE_LABEL}) + WHERE size((n)--()) = 1 + AND n.type = $node_type + RETURN n + """ + result = await self.query(query, {"node_type": node_type}) + return [record["n"] for record in result] if result else [] + + async def get_document_subgraph(self, content_hash: str): + """ + Retrieve a subgraph related to a document identified by its content hash, including + related entities and chunks. + + Parameters: + ----------- + + - content_hash (str): The hash identifying the document whose subgraph should be + retrieved. + + Returns: + -------- + + The subgraph data as a dictionary, or None if not found. 
+ """ + query = f""" + + MATCH (doc) + WHERE (doc:{self._GRAPH_NODE_LABEL}) + AND doc.type in ['TextDocument', 'PdfDocument'] + AND doc.name = 'text_' + $content_hash + + OPTIONAL MATCH (doc)<-[:is_part_of]-(chunk {{type: 'DocumentChunk'}}) + + // Alternative to WHERE NOT EXISTS + OPTIONAL MATCH (chunk)-[:contains]->(entity {{type: 'Entity'}}) + OPTIONAL MATCH (entity)<-[:contains]-(otherChunk {{type: 'DocumentChunk'}})-[:is_part_of]->(otherDoc) + WHERE otherDoc.type in ['TextDocument', 'PdfDocument'] + AND otherDoc.id <> doc.id + OPTIONAL MATCH (chunk)<-[:made_from]-(made_node {{type: 'TextSummary'}}) + + OPTIONAL MATCH (chunk)<-[:made_from]-(made_node {{type: 'TextSummary'}}) + + // Alternative to WHERE NOT EXISTS + OPTIONAL MATCH (entity)-[:is_a]->(type {{type: 'EntityType'}}) + OPTIONAL MATCH (type)<-[:is_a]-(otherEntity {{type: 'Entity'}})<-[:contains]-(otherChunk {{type: 'DocumentChunk'}})-[:is_part_of]->(otherDoc) + WHERE otherDoc.type in ['TextDocument', 'PdfDocument'] + AND otherDoc.id <> doc.id + + // Alternative to WHERE NOT EXISTS + WITH doc, entity, chunk, made_node, type, otherDoc + WHERE otherDoc IS NULL + + RETURN + collect(DISTINCT doc) as document, + collect(DISTINCT chunk) as chunks, + collect(DISTINCT entity) as orphan_entities, + collect(DISTINCT made_node) as made_from_nodes, + collect(DISTINCT type) as orphan_types + """ + result = await self.query(query, {"content_hash": content_hash}) + return result[0] if result else None + + + + async def _get_model_independent_graph_data(self): + """ + Retrieve the basic graph data without considering the model specifics, returning nodes + and edges. + + Returns: + -------- + + A tuple of nodes and edges data. + """ + query_string = f""" + MATCH (n :{self._GRAPH_NODE_LABEL}) + WITH count(n) AS nodeCount + MATCH (a :{self._GRAPH_NODE_LABEL})-[r]->(b :{self._GRAPH_NODE_LABEL}) + RETURN nodeCount AS numVertices, count(r) AS numEdges + """ + query_response = await self.query(query_string) + num_nodes = query_response[0].get('numVertices') + num_edges = query_response[0].get('numEdges') + + return (num_nodes, num_edges) + + async def _get_connected_components_stat(self): + """ + Retrieve statistics about connected components in the graph. + + This method analyzes the graph to find all connected components + and returns both the sizes of each component and the total number of components. + + + Returns: + -------- + tuple[list[int], int] + A tuple containing: + - A list of sizes for each connected component (descending order). + - The total number of connected components. + Returns ([], 0) if no connected components are found. + """ + query = f""" + MATCH(n :{self._GRAPH_NODE_LABEL}) + CALL neptune.algo.wcc(n,{{}}) + YIELD node, component + RETURN component, count(*) AS size + ORDER BY size DESC + """ + + result = await self.query(query) + size_connected_components = [record["size"] for record in result] if result else [] + num_connected_components = len(result) + + return (size_connected_components, num_connected_components) + + async def _count_self_loops(self): + """ + Count the number of self-loop relationships in the Neptune Anlaytics graph backend. + + This function executes a OpenCypher query to find and count all edge relationships that + begin and end at the same node (self-loops). It returns the count of such relationships + or 0 if no results are found. + + Returns: + -------- + + The count of self-loop relationships found in the database, or 0 if none were found. 
+ """ + query = f""" + MATCH (n :{self._GRAPH_NODE_LABEL})-[r]->(n :{self._GRAPH_NODE_LABEL}) + RETURN count(r) AS adapter_loop_count; + """ + result = await self.query(query) + return result[0]["adapter_loop_count"] if result else 0 + + @staticmethod + def _convert_relationship_to_edge(relationship: dict) -> EdgeData: + return relationship["source_id"], relationship["target_id"], relationship["relationship_name"], relationship["properties"] diff --git a/cognee/infrastructure/databases/graph/neptune_driver/exceptions.py b/cognee/infrastructure/databases/graph/neptune_driver/exceptions.py new file mode 100644 index 000000000..984f2c9f0 --- /dev/null +++ b/cognee/infrastructure/databases/graph/neptune_driver/exceptions.py @@ -0,0 +1,49 @@ +"""Neptune Analytics Exceptions + +This module defines custom exceptions for Neptune Analytics operations. +""" + + +class NeptuneAnalyticsError(Exception): + """Base exception for Neptune Analytics operations.""" + pass + + +class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError): + """Exception raised when connection to Neptune Analytics fails.""" + pass + + +class NeptuneAnalyticsQueryError(NeptuneAnalyticsError): + """Exception raised when a query execution fails.""" + pass + + +class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError): + """Exception raised when authentication with Neptune Analytics fails.""" + pass + + +class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError): + """Exception raised when Neptune Analytics configuration is invalid.""" + pass + + +class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError): + """Exception raised when a Neptune Analytics operation times out.""" + pass + + +class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError): + """Exception raised when requests are throttled by Neptune Analytics.""" + pass + + +class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError): + """Exception raised when a Neptune Analytics resource is not found.""" + pass + + +class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError): + """Exception raised when invalid parameters are provided to Neptune Analytics.""" + pass diff --git a/cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py b/cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py new file mode 100644 index 000000000..b70f2b1fa --- /dev/null +++ b/cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py @@ -0,0 +1,221 @@ +"""Neptune Utilities + +This module provides utility functions for Neptune Analytics operations including +connection management, URL parsing, and Neptune-specific configurations. +""" + +import re +from typing import Optional, Dict, Any, Tuple +from urllib.parse import urlparse + +from cognee.shared.logging_utils import get_logger + +logger = get_logger("NeptuneUtils") + + +def parse_neptune_url(url: str) -> Tuple[str, str]: + """ + Parse a Neptune Analytics URL to extract graph ID and region. + + Expected format: neptune-graph://?region= + or neptune-graph:// (defaults to us-east-1) + + Parameters: + ----------- + - url (str): The Neptune Analytics URL to parse + + Returns: + -------- + - Tuple[str, str]: A tuple containing (graph_id, region) + + Raises: + ------- + - ValueError: If the URL format is invalid + """ + try: + parsed = urlparse(url) + + if parsed.scheme != "neptune-graph": + raise ValueError(f"Invalid scheme: {parsed.scheme}. 
Expected 'neptune-graph'") + + graph_id = parsed.hostname or parsed.path.lstrip('/') + if not graph_id: + raise ValueError("Graph ID not found in URL") + + # Extract region from query parameters + region = "us-east-1" # default region + if parsed.query: + query_params = dict(param.split('=') for param in parsed.query.split('&') if '=' in param) + region = query_params.get('region', region) + + return graph_id, region + + except Exception as e: + raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}") + + +def validate_graph_id(graph_id: str) -> bool: + """ + Validate a Neptune Analytics graph ID format. + + Graph IDs should follow AWS naming conventions. + + Parameters: + ----------- + - graph_id (str): The graph ID to validate + + Returns: + -------- + - bool: True if the graph ID is valid, False otherwise + """ + if not graph_id: + return False + + # Neptune Analytics graph IDs should be alphanumeric with hyphens + # and between 1-63 characters + pattern = r'^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$' + return bool(re.match(pattern, graph_id)) + + +def validate_aws_region(region: str) -> bool: + """ + Validate an AWS region format. + + Parameters: + ----------- + - region (str): The AWS region to validate + + Returns: + -------- + - bool: True if the region format is valid, False otherwise + """ + if not region: + return False + + # AWS regions follow the pattern: us-east-1, eu-west-1, etc. + pattern = r'^[a-z]{2,3}-[a-z]+-\d+$' + return bool(re.match(pattern, region)) + + +def build_neptune_config( + graph_id: str, + region: Optional[str], + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + **kwargs +) -> Dict[str, Any]: + """ + Build a configuration dictionary for Neptune Analytics connection. + + Parameters: + ----------- + - graph_id (str): The Neptune Analytics graph identifier + - region (Optional[str]): AWS region where the graph is located + - aws_access_key_id (Optional[str]): AWS access key ID + - aws_secret_access_key (Optional[str]): AWS secret access key + - aws_session_token (Optional[str]): AWS session token for temporary credentials + - **kwargs: Additional configuration parameters + + Returns: + -------- + - Dict[str, Any]: Configuration dictionary for Neptune Analytics + + Raises: + ------- + - ValueError: If required parameters are invalid + """ + config = { + "graph_id": graph_id, + "service_name": "neptune-graph", + } + + # Add AWS credentials if provided + if region: + config["region"] = region + + if aws_access_key_id: + config["aws_access_key_id"] = aws_access_key_id + + if aws_secret_access_key: + config["aws_secret_access_key"] = aws_secret_access_key + + if aws_session_token: + config["aws_session_token"] = aws_session_token + + # Add any additional configuration + config.update(kwargs) + + return config + + +def get_neptune_endpoint_url(graph_id: str, region: str) -> str: + """ + Construct the Neptune Analytics endpoint URL for a given graph and region. + + Parameters: + ----------- + - graph_id (str): The Neptune Analytics graph identifier + - region (str): AWS region where the graph is located + + Returns: + -------- + - str: The Neptune Analytics endpoint URL + """ + return f"https://neptune-graph.{region}.amazonaws.com/graphs/{graph_id}" + + +def format_neptune_error(error: Exception) -> str: + """ + Format Neptune Analytics specific errors for better readability. 
+ + Parameters: + ----------- + - error (Exception): The exception to format + + Returns: + -------- + - str: Formatted error message + """ + error_msg = str(error) + + # Common Neptune Analytics error patterns and their user-friendly messages + error_mappings = { + "AccessDenied": "Access denied. Please check your AWS credentials and permissions.", + "GraphNotFound": "Graph not found. Please verify the graph ID and region.", + "InvalidParameter": "Invalid parameter provided. Please check your request parameters.", + "ThrottlingException": "Request was throttled. Please retry with exponential backoff.", + "InternalServerError": "Internal server error occurred. Please try again later.", + } + + for error_type, friendly_msg in error_mappings.items(): + if error_type in error_msg: + return f"{friendly_msg} Original error: {error_msg}" + + return error_msg + +def get_default_query_timeout() -> int: + """ + Get the default query timeout for Neptune Analytics operations. + + Returns: + -------- + - int: Default timeout in seconds + """ + return 300 # 5 minutes + + +def get_default_connection_config() -> Dict[str, Any]: + """ + Get default connection configuration for Neptune Analytics. + + Returns: + -------- + - Dict[str, Any]: Default connection configuration + """ + return { + "query_timeout": get_default_query_timeout(), + "max_retries": 3, + "retry_delay": 1.0, + "preferred_query_language": "openCypher", + } diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py new file mode 100644 index 000000000..b48bae773 --- /dev/null +++ b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py @@ -0,0 +1,436 @@ +"""Neptune Analytics Hybrid Adapter combining Vector and Graph functionality""" + +import asyncio +import json +from typing import List, Optional, Any, Dict, Type, Tuple +from uuid import UUID + +from cognee.exceptions import InvalidValueError +from cognee.infrastructure.databases.graph.neptune_driver.adapter import NeptuneGraphDB +from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface +from cognee.infrastructure.engine import DataPoint +from cognee.modules.storage.utils import JSONEncoder +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine +from cognee.infrastructure.databases.vector.models.PayloadSchema import PayloadSchema +from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult + +logger = get_logger("NeptuneAnalyticsAdapter") + +class IndexSchema(DataPoint): + """ + Represents a schema for an index data point containing an ID and text. + + Attributes: + - id: A string representing the unique identifier for the data point. + - text: A string representing the content of the data point. + - metadata: A dictionary with default index fields for the schema, currently configured + to include 'text'. + """ + id: str + text: str + metadata: dict = {"index_fields": ["text"]} + +NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://" + +class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): + """ + Hybrid adapter that combines Neptune Analytics Vector and Graph functionality. + + This adapter extends NeptuneGraphDB and implements VectorDBInterface to provide + a unified interface for working with Neptune Analytics as both a vector store + and a graph database. 
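+
+    Example (illustrative only; the graph identifier, region, and embedding engine are placeholders):
+
+        adapter = NeptuneAnalyticsAdapter(
+            graph_id="g-xxxxxxxxxx",
+            embedding_engine=my_embedding_engine,  # any EmbeddingEngine implementation
+            region="us-east-1",
+        )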
+ """ + + _VECTOR_NODE_LABEL = "COGNEE_NODE" + _COLLECTION_PREFIX = "VECTOR_COLLECTION" + _TOPK_LOWER_BOUND = 0 + _TOPK_UPPER_BOUND = 10 + + def __init__( + self, + graph_id: str, + embedding_engine: Optional[EmbeddingEngine] = None, + region: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): + """ + Initialize the Neptune Analytics hybrid adapter. + + Parameters: + ----------- + - graph_id (str): The Neptune Analytics graph identifier + - embedding_engine(Optional[EmbeddingEngine]): The embedding engine instance to translate text to vector. + - region (Optional[str]): AWS region where the graph is located (default: us-east-1) + - aws_access_key_id (Optional[str]): AWS access key ID + - aws_secret_access_key (Optional[str]): AWS secret access key + - aws_session_token (Optional[str]): AWS session token for temporary credentials + """ + # Initialize the graph database functionality + super().__init__( + graph_id=graph_id, + region=region, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token + ) + + # Add vector-specific attributes + self.embedding_engine = embedding_engine + logger.info(f"Initialized Neptune Analytics hybrid adapter for graph: \"{graph_id}\" in region: \"{self.region}\"") + + # VectorDBInterface methods implementation + + async def get_connection(self): + """ + This method is part of the default implementation but not defined in the interface. + No operation is performed and None will be returned here, + because the concept of connection is not applicable in this context. + """ + return None + + async def embed_data(self, data: list[str]) -> list[list[float]]: + """ + Embeds the provided textual data into vector representation. + + Uses the embedding engine to convert the list of strings into a list of float vectors. + + Parameters: + ----------- + - data (list[str]): A list of strings representing the data to be embedded. + + Returns: + -------- + - list[list[float]]: A list of embedded vectors corresponding to the input data. + """ + self._validate_embedding_engine() + return await self.embedding_engine.embed_text(data) + + async def has_collection(self, collection_name: str) -> bool: + """ + Neptune Analytics stores vector on a node level, + so create_collection() implements interface for compliance but performs no operations when called. + + Parameters: + ----------- + - collection_name (str): The name of the collection to check for existence. + Returns: + -------- + - bool: Always return True. + """ + return True + + async def create_collection( + self, + collection_name: str, + payload_schema: Optional[PayloadSchema] = None, + ): + """ + Neptune Analytics stores vector on a node level, so create_collection() implements interface for compliance but performs no operations when called. + As the result, create_collection() will be no-op. + + Parameters: + ----------- + - collection_name (str): The name of the new collection to create. + - payload_schema (Optional[PayloadSchema]): An optional schema for the payloads + within this collection. (default None) + """ + pass + + async def get_collection(self, collection_name: str): + """ + This method is part of the default implementation but not defined in the interface. + No operation is performed here because the concept of collection is not applicable in NeptuneAnalytics vector store. 
+ """ + return None + + async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): + """ + Insert new data points into the specified collection, by first inserting the node itself on the graph, + then execute neptune.algo.vectors.upsert() to insert the corresponded embedding. + + Parameters: + ----------- + - collection_name (str): The name of the collection where data points will be added. + - data_points (List[DataPoint]): A list of data points to be added to the + collection. + """ + self._validate_embedding_engine() + + # Fetch embeddings + texts = [DataPoint.get_embeddable_data(t) for t in data_points] + data_vectors = (await self.embedding_engine.embed_text(texts)) + + for index, data_point in enumerate(data_points): + node_id = data_point.id + # Fetch embedding from list instead + data_vector = data_vectors[index] + + # Fetch properties + properties = self._serialize_properties(data_point.model_dump()) + properties[self._COLLECTION_PREFIX] = collection_name + params = dict( + node_id = str(node_id), + properties = properties, + embedding = data_vector, + collection_name = collection_name + ) + + # Compose the query and send + query_string = ( + f"MERGE (n " + f":{self._VECTOR_NODE_LABEL} " + f" {{`~id`: $node_id}}) " + f"ON CREATE SET n = $properties, n.updated_at = timestamp() " + f"ON MATCH SET n += $properties, n.updated_at = timestamp() " + f"WITH n, $embedding AS embedding " + f"CALL neptune.algo.vectors.upsert(n, embedding) " + f"YIELD success " + f"RETURN success ") + + try: + self._client.query(query_string, params) + except Exception as e: + self._na_exception_handler(e, query_string) + pass + + async def retrieve(self, collection_name: str, data_point_ids: list[str]): + """ + Retrieve data points from a collection using their IDs. + + Parameters: + ----------- + - collection_name (str): The name of the collection from which to retrieve data + points. + - data_point_ids (list[str]): A list of IDs of the data points to retrieve. + """ + # Do the fetch for each node + params = dict(node_ids=data_point_ids, collection_name=collection_name) + query_string = (f"MATCH( n :{self._VECTOR_NODE_LABEL}) " + f"WHERE id(n) in $node_ids AND " + f"n.{self._COLLECTION_PREFIX} = $collection_name " + f"RETURN n as payload ") + + try: + result = self._client.query(query_string, params) + return [self._get_scored_result(item) for item in result] + except Exception as e: + self._na_exception_handler(e, query_string) + + async def search( + self, + collection_name: str, + query_text: Optional[str] = None, + query_vector: Optional[List[float]] = None, + limit: int = None, + with_vector: bool = False, + ): + """ + Perform a search in the specified collection using either a text query or a vector + query. + + Parameters: + ----------- + - collection_name (str): The name of the collection in which to perform the search. + - query_text (Optional[str]): An optional text query to search for in the + collection. + - query_vector (Optional[List[float]]): An optional vector representation for + searching the collection. + - limit (int): The maximum number of results to return from the search. + - with_vector (bool): Whether to return the vector representations with search + results, this is not supported for Neptune Analytics backend at the moment. + + Returns: + -------- + A list of scored results that match the query. + """ + self._validate_embedding_engine() + + if with_vector: + logger.warning( + "with_vector=True will include embedding vectors in the result. 
" + "This may trigger a resource-intensive query and increase response time. " + "Use this option only when vector data is required." + ) + + # In the case of excessive limit, or zero / negative value, limit will be set to 10. + if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND: + logger.warning( + "Provided limit (%s) is invalid (zero, negative, or exceeds maximum). " + "Defaulting to limit=10.", limit + ) + limit = self._TOPK_UPPER_BOUND + + if query_vector and query_text: + raise InvalidValueError( + message="The search function accepts either text or embedding as input, but not both." + ) + elif query_text is None and query_vector is None: + raise InvalidValueError(message="One of query_text or query_vector must be provided!") + elif query_vector: + embedding = query_vector + else: + data_vectors = (await self.embedding_engine.embed_text([query_text])) + embedding = data_vectors[0] + + # Compose the parameters map + params = dict(embedding=embedding, param_topk=limit) + # Compose the query + query_string = f""" + CALL neptune.algo.vectors.topKByEmbeddingWithFiltering({{ + topK: {limit}, + embedding: {embedding}, + nodeFilter: {{ equals: {{property: '{self._COLLECTION_PREFIX}', value: '{collection_name}'}} }} + }} + ) + YIELD node, score + """ + + if with_vector: + query_string += """ + WITH node, score, id(node) as node_id + MATCH (n) + WHERE id(n) = id(node) + CALL neptune.algo.vectors.get(n) + YIELD embedding + RETURN node as payload, score, embedding + """ + + else: + query_string += """ + RETURN node as payload, score + """ + + try: + query_response = self._client.query(query_string, params) + return [self._get_scored_result( + item = item, with_score = True + ) for item in query_response] + except Exception as e: + self._na_exception_handler(e, query_string) + + async def batch_search( + self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False + ): + """ + Perform a batch search using multiple text queries against a collection. + + Parameters: + ----------- + - collection_name (str): The name of the collection to conduct the batch search in. + - query_texts (List[str]): A list of text queries to use for the search. + - limit (int): The maximum number of results to return for each query. + - with_vectors (bool): Whether to include vector representations with search + results. (default False) + + Returns: + -------- + A list of search result sets, one for each query input. + """ + self._validate_embedding_engine() + + # Convert text to embedding array in batch + data_vectors = (await self.embedding_engine.embed_text(query_texts)) + return await asyncio.gather(*[ + self.search(collection_name, None, vector, limit, with_vectors) + for vector in data_vectors + ]) + + async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + """ + Delete specified data points from a collection, by executing an OpenCypher query, + with matching [vector_label, collection_label, node_id] combination. + + Parameters: + ----------- + - collection_name (str): The name of the collection from which to delete data + points. + - data_point_ids (list[str]): A list of IDs of the data points to delete. 
+ """ + params = dict(node_ids=data_point_ids, collection_name=collection_name) + query_string = (f"MATCH (n :{self._VECTOR_NODE_LABEL}) " + f"WHERE id(n) IN $node_ids " + f"AND n.{self._COLLECTION_PREFIX} = $collection_name " + f"DETACH DELETE n") + try: + self._client.query(query_string, params) + except Exception as e: + self._na_exception_handler(e, query_string) + pass + + async def create_vector_index(self, index_name: str, index_property_name: str): + """ + Neptune Analytics stores vectors at the node level, + so create_vector_index() implements the interface for compliance but performs no operation when called. + As a result, create_vector_index() invokes create_collection(), which is also a no-op. + This ensures the logic flow remains consistent, even if the concept of collections is introduced in a future release. + """ + await self.create_collection(f"{index_name}_{index_property_name}") + + async def index_data_points( + self, index_name: str, index_property_name: str, data_points: list[DataPoint] + ): + """ + Indexes a list of data points into Neptune Analytics by creating them as nodes. + + This method constructs a unique collection name by combining the `index_name` and + `index_property_name`, then delegates to `create_data_points()` to store the data. + + Args: + index_name (str): The base name of the index. + index_property_name (str): The property name to append to the index name for uniqueness. + data_points (list[DataPoint]): A list of `DataPoint` instances to be indexed. + + Returns: + None + """ + await self.create_data_points( + f"{index_name}_{index_property_name}", + [ + IndexSchema( + id=str(data_point.id), + text=getattr(data_point, data_point.metadata["index_fields"][0]), + ) + for data_point in data_points + ], + ) + + async def prune(self): + """ + Remove obsolete or unnecessary data from the database. + """ + # Run actual truncate + self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) " + f"DETACH DELETE n") + pass + + @staticmethod + def _get_scored_result(item: dict, with_vector: bool = False, with_score: bool = False) -> ScoredResult: + """ + Util method to simplify the object creation of ScoredResult base on incoming NX payload response. + """ + return ScoredResult( + id=item.get('payload').get('~id'), + payload=item.get('payload').get('~properties'), + score=item.get('score') if with_score else 0, + vector=item.get('embedding') if with_vector else None + ) + + def _na_exception_handler(self, ex, query_string: str): + """ + Generic exception handler for NA langchain. 
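+        Logs the failing query together with the error, then re-raises the original exception.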
+ """ + logger.error( + "Neptune Analytics query failed: %s | Query: [%s]", ex, query_string + ) + raise ex + + def _validate_embedding_engine(self): + """ + Validates if the embedding_engine is defined + :raises: ValueError if this object does not have a valid embedding_engine + """ + if self.embedding_engine is None: + raise ValueError("Neptune Analytics requires an embedder defined to make vector operations") diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/__init__.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 8bbceaf7f..7c335e6f7 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -114,6 +114,28 @@ def create_vector_engine( embedding_engine=embedding_engine, ) + elif vector_db_provider == "neptune_analytics": + try: + from langchain_aws import NeptuneAnalyticsGraph + except ImportError: + raise ImportError( + "langchain_aws is not installed. Please install it with 'pip install langchain_aws'" + ) + + if not vector_db_url: + raise EnvironmentError("Missing Neptune endpoint.") + + from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL + if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL): + raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}'") + + graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "") + + return NeptuneAnalyticsAdapter( + graph_id=graph_identifier, + embedding_engine=embedding_engine, + ) + else: from .lancedb.LanceDBAdapter import LanceDBAdapter diff --git a/cognee/tests/test_neptune_analytics_graph.py b/cognee/tests/test_neptune_analytics_graph.py new file mode 100644 index 000000000..b32019798 --- /dev/null +++ b/cognee/tests/test_neptune_analytics_graph.py @@ -0,0 +1,313 @@ +import os +from dotenv import load_dotenv +import asyncio +from cognee.infrastructure.databases.graph.neptune_driver import NeptuneGraphDB +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.data.processing.document_types import TextDocument + +# Set up Amazon credentials in .env file and get the values from environment variables +load_dotenv() +graph_id = os.getenv('GRAPH_ID', "") + +na_adapter = NeptuneGraphDB(graph_id) + + +def setup(): + # Define nodes data before the main function + # These nodes were defined using openAI from the following prompt: + + # Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads + # that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It + # complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load + # the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's + # stored in Amazon S3. 
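+    # Build a small document -> chunk -> entity graph from the text above and return the
+    # node objects plus (source_id, target_id, relationship_name) edge tuples used by the tests below.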
+ + document = TextDocument( + name='text_test.txt', + raw_data_location='git/cognee/examples/database_examples/data_storage/data/text_test.txt', + external_metadata='{}', + mime_type='text/plain' + ) + document_chunk = DocumentChunk( + text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", + chunk_size=187, + chunk_index=0, + cut_type='paragraph_end', + is_part_of=document, + ) + + graph_database = EntityType(name='graph database', description='graph database') + neptune_analytics_entity = Entity( + name='neptune analytics', + description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.', + ) + neptune_database_entity = Entity( + name='amazon neptune database', + description='A popular managed graph database that complements Neptune Analytics.', + ) + + storage = EntityType(name='storage', description='storage') + storage_entity = Entity( + name='amazon s3', + description='A storage service provided by Amazon Web Services that allows storing graph data.', + ) + + nodes_data = [ + document, + document_chunk, + graph_database, + neptune_analytics_entity, + neptune_database_entity, + storage, + storage_entity, + ] + + edges_data = [ + ( + str(document_chunk.id), + str(storage_entity.id), + 'contains', + ), + ( + str(storage_entity.id), + str(storage.id), + 'is_a', + ), + ( + str(document_chunk.id), + str(neptune_database_entity.id), + 'contains', + ), + ( + str(neptune_database_entity.id), + str(graph_database.id), + 'is_a', + ), + ( + str(document_chunk.id), + str(document.id), + 'is_part_of', + ), + ( + str(document_chunk.id), + str(neptune_analytics_entity.id), + 'contains', + ), + ( + str(neptune_analytics_entity.id), + str(graph_database.id), + 'is_a', + ), + ] + + return nodes_data, edges_data + + +async def pipeline_method(): + """ + Example script using the neptune analytics with small sample data + + This example demonstrates how to add nodes to Neptune Analytics + """ + + print("------TRUNCATE GRAPH-------") + await na_adapter.delete_graph() + + print("------SETUP DATA-------") + nodes, edges = setup() + + print("------ADD NODES-------") + await na_adapter.add_node(nodes[0]) + await na_adapter.add_nodes(nodes[1:]) + + print("------GET NODES FROM DATA-------") + node_ids = [str(node.id) for node in nodes] + db_nodes = await na_adapter.get_nodes(node_ids) + + print("------RESULTS:-------") + for n in db_nodes: + print(n) + + print("------ADD EDGES-------") + await na_adapter.add_edge(edges[0][0], edges[0][1], edges[0][2]) + await na_adapter.add_edges(edges[1:]) + + print("------HAS EDGES-------") + has_edge = await na_adapter.has_edge( + edges[0][0], + edges[0][1], + edges[0][2], + ) + if has_edge: + print(f"found edge ({edges[0][0]})-[{edges[0][2]}]->({edges[0][1]})") + + has_edges = await na_adapter.has_edges(edges) + if len(has_edges) > 0: + print(f"found edges: {len(has_edges)} (expected: {len(edges)})") + else: + print(f"no edges found (expected: {len(edges)})") + + print("------GET GRAPH-------") + all_nodes, all_edges = await na_adapter.get_graph_data() + print(f"found {len(all_nodes)} nodes and found 
{len(all_edges)} edges") + + print("------NEIGHBORING NODES-------") + center_node = nodes[2] + neighbors = await na_adapter.get_neighbors(str(center_node.id)) + print(f"found {len(neighbors)} neighbors for node \"{center_node.name}\"") + for neighbor in neighbors: + print(neighbor) + + print("------NEIGHBORING EDGES-------") + center_node = nodes[2] + neighbouring_edges = await na_adapter.get_edges(str(center_node.id)) + print(f"found {len(neighbouring_edges)} edges neighbouring node \"{center_node.name}\"") + for edge in neighbouring_edges: + print(edge) + + print("------GET CONNECTIONS (SOURCE NODE)-------") + document_chunk_node = nodes[0] + connections = await na_adapter.get_connections(str(document_chunk_node.id)) + print(f"found {len(connections)} connections for node \"{document_chunk_node.type}\"") + for connection in connections: + src, relationship, tgt = connection + src = src.get("name", src.get("type", "unknown")) + relationship = relationship["relationship_name"] + tgt = tgt.get("name", tgt.get("type", "unknown")) + print(f"\"{src}\"-[{relationship}]->\"{tgt}\"") + + print("------GET CONNECTIONS (TARGET NODE)-------") + connections = await na_adapter.get_connections(str(center_node.id)) + print(f"found {len(connections)} connections for node \"{center_node.name}\"") + for connection in connections: + src, relationship, tgt = connection + src = src.get("name", src.get("type", "unknown")) + relationship = relationship["relationship_name"] + tgt = tgt.get("name", tgt.get("type", "unknown")) + print(f"\"{src}\"-[{relationship}]->\"{tgt}\"") + + print("------SUBGRAPH-------") + node_names = ["neptune analytics", "amazon neptune database"] + subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names) + print(f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}") + for subgraph_node in subgraph_nodes: + print(subgraph_node) + for subgraph_edge in subgraph_edges: + print(subgraph_edge) + + print("------STAT-------") + stat = await na_adapter.get_graph_metrics(include_optional=True) + assert type(stat) is dict + assert stat['num_nodes'] == 7 + assert stat['num_edges'] == 7 + assert stat['mean_degree'] == 2.0 + assert round(stat['edge_density'], 3) == 0.167 + assert stat['num_connected_components'] == [7] + assert stat['sizes_of_connected_components'] == 1 + assert stat['num_selfloops'] == 0 + # Unsupported optional metrics + assert stat['diameter'] == -1 + assert stat['avg_shortest_path_length'] == -1 + assert stat['avg_clustering'] == -1 + + print("------DELETE-------") + # delete all nodes and edges: + await na_adapter.delete_graph() + + # delete all nodes by node id + # node_ids = [str(node.id) for node in nodes] + # await na_adapter.delete_nodes(node_ids) + + has_edges = await na_adapter.has_edges(edges) + if len(has_edges) == 0: + print(f"Delete successful") + else: + print(f"Delete failed") + + +async def misc_methods(): + print("------TRUNCATE GRAPH-------") + await na_adapter.delete_graph() + + print("------SETUP TEST ENV-------") + nodes, edges = setup() + await na_adapter.add_nodes(nodes) + await na_adapter.add_edges(edges) + + print("------GET GRAPH-------") + all_nodes, all_edges = await na_adapter.get_graph_data() + print(f"found {len(all_nodes)} nodes and found {len(all_edges)} edges") + + print("------GET DISCONNECTED-------") + nodes_disconnected = await na_adapter.get_disconnected_nodes() + print(nodes_disconnected) + assert len(nodes_disconnected) == 0 + + print("------Get Labels 
(Node)-------") + node_labels = await na_adapter.get_node_labels_string() + print(node_labels) + + print("------Get Labels (Edge)-------") + edge_labels = await na_adapter.get_relationship_labels_string() + print(edge_labels) + + print("------Get Filtered Graph-------") + filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data([{'name': ['text_test.txt']}]) + print(filtered_nodes, filtered_edges) + + print("------Get Degree one nodes-------") + degree_one_nodes = await na_adapter.get_degree_one_nodes("EntityType") + print(degree_one_nodes) + + print("------Get Doc sub-graph-------") + doc_sub_graph = await na_adapter.get_document_subgraph('test.txt') + print(doc_sub_graph) + + print("------Fetch and Remove connections (Predecessors)-------") + # Fetch test edge + (src_id, dest_id, relationship) = edges[0] + nodes_predecessors = await na_adapter.get_predecessors( + node_id=dest_id, edge_label=relationship + ) + assert len(nodes_predecessors) > 0 + + await na_adapter.remove_connection_to_predecessors_of( + node_ids=[src_id], edge_label=relationship + ) + nodes_predecessors_after = await na_adapter.get_predecessors( + node_id=dest_id, edge_label=relationship + ) + # Return empty after relationship being deleted. + assert len(nodes_predecessors_after) == 0 + + + print("------Fetch and Remove connections (Successors)-------") + _, edges_suc = await na_adapter.get_graph_data() + (src_id, dest_id, relationship, _) = edges_suc[0] + + nodes_successors = await na_adapter.get_successors( + node_id=src_id, edge_label=relationship + ) + assert len(nodes_successors) > 0 + + await na_adapter.remove_connection_to_successors_of( + node_ids=[dest_id], edge_label=relationship + ) + nodes_successors_after = await na_adapter.get_successors( + node_id=src_id, edge_label=relationship + ) + assert len(nodes_successors_after) == 0 + + + # no-op + await na_adapter.project_entire_graph() + await na_adapter.drop_graph() + await na_adapter.graph_exists() + + pass + + +if __name__ == "__main__": + asyncio.run(pipeline_method()) + asyncio.run(misc_methods()) diff --git a/cognee/tests/test_neptune_analytics_hybrid.py b/cognee/tests/test_neptune_analytics_hybrid.py new file mode 100644 index 000000000..352d20fba --- /dev/null +++ b/cognee/tests/test_neptune_analytics_hybrid.py @@ -0,0 +1,169 @@ +import os +from dotenv import load_dotenv +import asyncio +import pytest + +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.data.processing.document_types import TextDocument +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter + +# Set up Amazon credentials in .env file and get the values from environment variables +load_dotenv() +graph_id = os.getenv('GRAPH_ID', "") + +# get the default embedder +embedding_engine = get_embedding_engine() +na_graph = NeptuneAnalyticsAdapter(graph_id) +na_vector = NeptuneAnalyticsAdapter(graph_id, embedding_engine) + +collection = "test_collection" + +logger = get_logger("test_neptune_analytics_hybrid") + +def setup_data(): + # Define nodes data before the main function + # These nodes were defined using openAI from the following prompt: + # + # Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads + # that require fast iteration for data, analytical 
and algorithmic processing, or vector search on graph data. It + # complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load + # the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's + # stored in Amazon S3. + + document = TextDocument( + name='text.txt', + raw_data_location='git/cognee/examples/database_examples/data_storage/data/text.txt', + external_metadata='{}', + mime_type='text/plain' + ) + document_chunk = DocumentChunk( + text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", + chunk_size=187, + chunk_index=0, + cut_type='paragraph_end', + is_part_of=document, + ) + + graph_database = EntityType(name='graph database', description='graph database') + neptune_analytics_entity = Entity( + name='neptune analytics', + description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.', + ) + neptune_database_entity = Entity( + name='amazon neptune database', + description='A popular managed graph database that complements Neptune Analytics.', + ) + + storage = EntityType(name='storage', description='storage') + storage_entity = Entity( + name='amazon s3', + description='A storage service provided by Amazon Web Services that allows storing graph data.', + ) + + nodes_data = [ + document, + document_chunk, + graph_database, + neptune_analytics_entity, + neptune_database_entity, + storage, + storage_entity, + ] + + edges_data = [ + ( + str(document_chunk.id), + str(storage_entity.id), + 'contains', + ), + ( + str(storage_entity.id), + str(storage.id), + 'is_a', + ), + ( + str(document_chunk.id), + str(neptune_database_entity.id), + 'contains', + ), + ( + str(neptune_database_entity.id), + str(graph_database.id), + 'is_a', + ), + ( + str(document_chunk.id), + str(document.id), + 'is_part_of', + ), + ( + str(document_chunk.id), + str(neptune_analytics_entity.id), + 'contains', + ), + ( + str(neptune_analytics_entity.id), + str(graph_database.id), + 'is_a', + ), + ] + return nodes_data, edges_data + +async def test_add_graph_then_vector_data(): + logger.info("------test_add_graph_then_vector_data-------") + (nodes, edges) = setup_data() + await na_graph.add_nodes(nodes) + await na_graph.add_edges(edges) + await na_vector.create_data_points(collection, nodes) + + node_ids = [str(node.id) for node in nodes] + retrieved_data_points = await na_vector.retrieve(collection, node_ids) + retrieved_nodes = await na_graph.get_nodes(node_ids) + + assert len(retrieved_data_points) == len(retrieved_nodes) == len(node_ids) + + # delete all nodes and edges and vectors: + await na_graph.delete_graph() + await na_vector.prune() + + (nodes, edges) = await na_graph.get_graph_data() + assert len(nodes) == 0 + assert len(edges) == 0 + logger.info("------PASSED-------") + +async def test_add_vector_then_node_data(): + logger.info("------test_add_vector_then_node_data-------") + (nodes, edges) = setup_data() + await na_vector.create_data_points(collection, nodes) + await na_graph.add_nodes(nodes) + await na_graph.add_edges(edges) + + 
node_ids = [str(node.id) for node in nodes] + retrieved_data_points = await na_vector.retrieve(collection, node_ids) + retrieved_nodes = await na_graph.get_nodes(node_ids) + + assert len(retrieved_data_points) == len(retrieved_nodes) == len(node_ids) + + # delete all nodes and edges and vectors: + await na_vector.prune() + await na_graph.delete_graph() + + (nodes, edges) = await na_graph.get_graph_data() + assert len(nodes) == 0 + assert len(edges) == 0 + logger.info("------PASSED-------") + +def main(): + """ + Example script uses neptune analytics for the graph and vector (hybrid) store with small sample data + This example demonstrates how to add nodes and vectors to Neptune Analytics, and ensures that + the nodes do not conflict + """ + asyncio.run(test_add_graph_then_vector_data()) + asyncio.run(test_add_vector_then_node_data()) + +if __name__ == "__main__": + main() diff --git a/cognee/tests/test_neptune_analytics_vector.py b/cognee/tests/test_neptune_analytics_vector.py new file mode 100644 index 000000000..3dda125f3 --- /dev/null +++ b/cognee/tests/test_neptune_analytics_vector.py @@ -0,0 +1,172 @@ +import os +import pathlib +import cognee +import uuid +import pytest +from cognee.modules.search.operations import get_history +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import get_logger +from cognee.modules.search.types import SearchType +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, IndexSchema + +logger = get_logger() + + +async def main(): + graph_id = os.getenv('GRAPH_ID', "") + cognee.config.set_vector_db_provider("neptune_analytics") + cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}") + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_neptune") + ).resolve() + ) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_neptune") + ).resolve() + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + dataset_name = "cs_explanations" + + explanation_file_path = os.path.join( + pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt" + ) + await cognee.add([explanation_file_path], dataset_name) + + text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena. + At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states. + Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. 
Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible. + The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly. + Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate. + In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited. + """ + + await cognee.add([text], dataset_name) + + await cognee.cognify([dataset_name]) + + from cognee.infrastructure.databases.vector import get_vector_engine + + vector_engine = get_vector_engine() + random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node_name = random_node.payload["text"] + + search_results = await cognee.search( + query_type=SearchType.INSIGHTS, query_text=random_node_name + ) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted sentences are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted chunks are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search( + query_type=SearchType.SUMMARIES, query_text=random_node_name + ) + assert len(search_results) != 0, "Query related summaries don't exist." + print("\nExtracted summaries are:\n") + for result in search_results: + print(f"{result}\n") + + user = await get_default_user() + history = await get_history(user.id) + assert len(history) == 6, "Search history is not correct." 
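+
+    # Clean up: remove local data files and reset system metadata created during this run.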
+ + await cognee.prune.prune_data() + assert not os.path.isdir(data_directory_path), "Local data files are not deleted" + + await cognee.prune.prune_system(metadata=True) + +async def vector_backend_api_test(): + cognee.config.set_vector_db_provider("neptune_analytics") + + # When URL is absent + cognee.config.set_vector_db_url(None) + with pytest.raises(OSError): + get_vector_engine() + + # Assert invalid graph ID. + cognee.config.set_vector_db_url("invalid_url") + with pytest.raises(ValueError): + get_vector_engine() + + # Return a valid engine object with valid URL. + graph_id = os.getenv('GRAPH_ID', "") + cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}") + engine = get_vector_engine() + assert isinstance(engine, NeptuneAnalyticsAdapter) + + TEST_COLLECTION_NAME = "test" + # Data point - 1 + TEST_UUID = str(uuid.uuid4()) + TEST_TEXT = "Hello world" + datapoint = IndexSchema(id=TEST_UUID, text=TEST_TEXT) + # Data point - 2 + TEST_UUID_2 = str(uuid.uuid4()) + TEST_TEXT_2 = "Cognee" + datapoint_2 = IndexSchema(id=TEST_UUID_2, text=TEST_TEXT_2) + + # Prun all vector_db entries + await engine.prune() + + # Always return true + has_collection = await engine.has_collection(TEST_COLLECTION_NAME) + assert has_collection + # No-op + await engine.create_collection(TEST_COLLECTION_NAME, IndexSchema) + + # Save data-points + await engine.create_data_points(TEST_COLLECTION_NAME, [datapoint, datapoint_2]) + # Search single text + result_search = await engine.search( + collection_name=TEST_COLLECTION_NAME, + query_text=TEST_TEXT, + query_vector=None, + limit=10, + with_vector=True) + assert (len(result_search) == 2) + + # # Retrieve data-points + result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2]) + assert any( + str(r.id) == TEST_UUID and r.payload['text'] == TEST_TEXT + for r in result + ) + assert any( + str(r.id) == TEST_UUID_2 and r.payload['text'] == TEST_TEXT_2 + for r in result + ) + # Search multiple + result_search_batch = await engine.batch_search( + collection_name=TEST_COLLECTION_NAME, + query_texts=[TEST_TEXT, TEST_TEXT_2], + limit=10, + with_vectors=False + ) + assert (len(result_search_batch) == 2 and + all(len(batch) == 2 for batch in result_search_batch)) + + # Delete datapoint from vector store + await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2]) + + # Retrieve should return an empty list. + result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID]) + assert result_deleted == [] + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) + asyncio.run(vector_backend_api_test()) diff --git a/examples/database_examples/neptune_analytics_example.py b/examples/database_examples/neptune_analytics_example.py new file mode 100644 index 000000000..acc2baedb --- /dev/null +++ b/examples/database_examples/neptune_analytics_example.py @@ -0,0 +1,107 @@ +import base64 +import json +import os +import pathlib +import asyncio +import cognee +from cognee.modules.search.types import SearchType +from dotenv import load_dotenv + +load_dotenv() + +async def main(): + """ + Example script demonstrating how to use Cognee with Amazon Neptune Analytics + + This example: + 1. Configures Cognee to use Neptune Analytics as graph database + 2. Sets up data directories + 3. Adds sample data to Cognee + 4. Processes/cognifies the data + 5. 
Performs different types of searches + """ + + # Set up Amazon credentials in .env file and get the values from environment variables + graph_endpoint_url = "neptune-graph://" + os.getenv('GRAPH_ID', "") + + # Configure Neptune Analytics as the graph & vector database provider + cognee.config.set_graph_db_config( + { + "graph_database_provider": "neptune_analytics", # Specify Neptune Analytics as provider + "graph_database_url": graph_endpoint_url, # Neptune Analytics endpoint with the format neptune-graph:// + } + ) + cognee.config.set_vector_db_config( + { + "vector_db_provider": "neptune_analytics", # Specify Neptune Analytics as provider + "vector_db_url": graph_endpoint_url, # Neptune Analytics endpoint with the format neptune-graph:// + } + ) + + # Set up data directories for storing documents and system files + # You should adjust these paths to your needs + current_dir = pathlib.Path(__file__).parent + data_directory_path = str(current_dir / "data_storage") + cognee.config.data_root_directory(data_directory_path) + + cognee_directory_path = str(current_dir / "cognee_system") + cognee.config.system_root_directory(cognee_directory_path) + + # Clean any existing data (optional) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + # Create a dataset + dataset_name = "neptune_example" + + # Add sample text to the dataset + sample_text_1 = """Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune + Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. To analyze + graph data quickly and easily, Neptune Analytics stores large graph datasets in memory. It supports a library of + optimized graph analytic algorithms, low-latency graph queries, and vector search capabilities within graph + traversals. + """ + + sample_text_2 = """Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads + that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It + complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load + the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's + stored in Amazon S3. + """ + + # Add the sample text to the dataset + await cognee.add([sample_text_1, sample_text_2], dataset_name) + + # Process the added document to extract knowledge + await cognee.cognify([dataset_name]) + + # Now let's perform some searches + # 1. Search for insights related to "Neptune Analytics" + insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neptune Analytics") + print("\n========Insights about Neptune Analytics========:") + for result in insights_results: + print(f"- {result}") + + # 2. Search for text chunks related to "graph database" + chunks_results = await cognee.search( + query_type=SearchType.CHUNKS, query_text="graph database", datasets=[dataset_name] + ) + print("\n========Chunks about graph database========:") + for result in chunks_results: + print(f"- {result}") + + # 3. 
Get graph completion related to databases + graph_completion_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text="database" + ) + print("\n========Graph completion for databases========:") + for result in graph_completion_results: + print(f"- {result}") + + # Clean up (optional) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 33e45d88c..2ac359ad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ distributed = [ qdrant = ["qdrant-client>=1.14.2,<2"] neo4j = ["neo4j>=5.28.0,<6"] +neptune = ["langchain_aws>=0.2.22"] postgres = [ "psycopg2>=2.9.10,<3", "pgvector>=0.3.5,<0.4",
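
Note: both create_graph_engine and create_vector_engine introduced above rely on the same endpoint convention, a "neptune-graph://" prefix followed by the Neptune Analytics graph identifier. The sketch below restates that inline validation as a standalone helper for illustration only; the helper name parse_graph_identifier is hypothetical and does not appear in the PR, which performs the check inline in each factory.

    NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://"

    def parse_graph_identifier(endpoint: str) -> str:
        # Mirror the factories' inline checks: fail fast on a missing or malformed
        # endpoint, otherwise strip the prefix and return the bare graph identifier.
        if not endpoint:
            raise EnvironmentError("Missing Neptune endpoint.")
        if not endpoint.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
            raise ValueError(
                f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<graph-id>'"
            )
        return endpoint.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

    # Example: parse_graph_identifier("neptune-graph://g-abc1234") returns "g-abc1234"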