From c0d60eef25bb67c120fdceef93fa6a4d8635f816 Mon Sep 17 00:00:00 2001 From: vasilije Date: Sat, 2 Aug 2025 16:58:18 +0200 Subject: [PATCH] added linting and formatting --- .../databases/graph/get_graph_engine.py | 13 +- .../databases/graph/neptune_driver/adapter.py | 229 +++++++++--------- .../graph/neptune_driver/exceptions.py | 30 ++- .../graph/neptune_driver/neptune_utils.py | 81 ++++--- .../NeptuneAnalyticsAdapter.py | 125 +++++----- .../databases/vector/create_vector_engine.py | 10 +- cognee/tests/test_neptune_analytics_graph.py | 98 ++++---- cognee/tests/test_neptune_analytics_hybrid.py | 53 ++-- cognee/tests/test_neptune_analytics_vector.py | 31 ++- .../neptune_analytics_example.py | 7 +- .../python/weighted_graph_visualization.html | 4 +- 11 files changed, 355 insertions(+), 326 deletions(-) diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 4255c062e..662c96eed 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -149,7 +149,9 @@ def create_graph_engine( from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL): - raise ValueError(f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}") + raise ValueError( + f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}" + ) graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "") @@ -174,10 +176,15 @@ def create_graph_engine( if not graph_database_url: raise EnvironmentError("Missing Neptune endpoint.") - from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL + from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import ( + NeptuneAnalyticsAdapter, + NEPTUNE_ANALYTICS_ENDPOINT_URL, + ) if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL): - raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}'") + raise ValueError( + f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}'" + ) graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "") diff --git a/cognee/infrastructure/databases/graph/neptune_driver/adapter.py b/cognee/infrastructure/databases/graph/neptune_driver/adapter.py index 12a963ca6..26fcd9c51 100644 --- a/cognee/infrastructure/databases/graph/neptune_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neptune_driver/adapter.py @@ -29,6 +29,7 @@ logger = get_logger("NeptuneGraphDB") try: from langchain_aws import NeptuneAnalyticsGraph + LANGCHAIN_AWS_AVAILABLE = True except ImportError: logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.") @@ -36,11 +37,13 @@ except ImportError: NEPTUNE_ENDPOINT_URL = "neptune-graph://" + class NeptuneGraphDB(GraphDBInterface): """ Adapter for interacting with Amazon Neptune Analytics graph store. This class provides methods for querying, adding, deleting nodes and edges using the langchain_aws library. 
""" + _GRAPH_NODE_LABEL = "COGNEE_NODE" def __init__( @@ -61,28 +64,30 @@ class NeptuneGraphDB(GraphDBInterface): - aws_access_key_id (Optional[str]): AWS access key ID - aws_secret_access_key (Optional[str]): AWS secret access key - aws_session_token (Optional[str]): AWS session token for temporary credentials - + Raises: ------- - NeptuneAnalyticsConfigurationError: If configuration parameters are invalid """ # validate import if not LANGCHAIN_AWS_AVAILABLE: - raise ImportError("langchain_aws is not available. Please install it to use Neptune Analytics.") + raise ImportError( + "langchain_aws is not available. Please install it to use Neptune Analytics." + ) # Validate configuration if not validate_graph_id(graph_id): - raise NeptuneAnalyticsConfigurationError(message=f"Invalid graph ID: \"{graph_id}\"") - + raise NeptuneAnalyticsConfigurationError(message=f'Invalid graph ID: "{graph_id}"') + if region and not validate_aws_region(region): - raise NeptuneAnalyticsConfigurationError(message=f"Invalid AWS region: \"{region}\"") - + raise NeptuneAnalyticsConfigurationError(message=f'Invalid AWS region: "{region}"') + self.graph_id = graph_id self.region = region self.aws_access_key_id = aws_access_key_id self.aws_secret_access_key = aws_secret_access_key self.aws_session_token = aws_session_token - + # Build configuration self.config = build_neptune_config( graph_id=self.graph_id, @@ -91,15 +96,17 @@ class NeptuneGraphDB(GraphDBInterface): aws_secret_access_key=self.aws_secret_access_key, aws_session_token=self.aws_session_token, ) - + # Initialize Neptune Analytics client using langchain_aws self._client: NeptuneAnalyticsGraph = self._initialize_client() - logger.info(f"Initialized Neptune Analytics adapter for graph: \"{graph_id}\" in region: \"{self.region}\"") + logger.info( + f'Initialized Neptune Analytics adapter for graph: "{graph_id}" in region: "{self.region}"' + ) def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]: """ Initialize the Neptune Analytics client using langchain_aws. 
- + Returns: -------- - Optional[Any]: The Neptune Analytics client or None if not available @@ -108,9 +115,7 @@ class NeptuneGraphDB(GraphDBInterface): # Initialize the Neptune Analytics Graph client client_config = { "graph_identifier": self.graph_id, - "config": Config( - user_agent_appid='Cognee' - ) + "config": Config(user_agent_appid="Cognee"), } # Add AWS credentials if provided if self.region: @@ -121,13 +126,15 @@ class NeptuneGraphDB(GraphDBInterface): client_config["aws_secret_access_key"] = self.aws_secret_access_key if self.aws_session_token: client_config["aws_session_token"] = self.aws_session_token - + client = NeptuneAnalyticsGraph(**client_config) logger.info("Successfully initialized Neptune Analytics client") return client - + except Exception as e: - raise NeptuneAnalyticsConfigurationError(message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}") from e + raise NeptuneAnalyticsConfigurationError( + message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}" + ) from e @staticmethod def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]: @@ -175,7 +182,7 @@ class NeptuneGraphDB(GraphDBInterface): params = {} logger.debug(f"executing na query:\nquery={query}\n") result = self._client.query(query, params) - + # Convert the result to list format expected by the interface if isinstance(result, list): return result @@ -183,7 +190,7 @@ class NeptuneGraphDB(GraphDBInterface): return [result] else: return [{"result": result}] - + except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Neptune Analytics query failed: {error_msg}") @@ -212,10 +219,11 @@ class NeptuneGraphDB(GraphDBInterface): "node_id": str(node.id), "properties": serialized_properties, } - + result = await self.query(query, params) logger.debug(f"Successfully added/updated node: {node.id}") - + logger.debug(f"Node upsert query returned: {str(result)}") + except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Failed to add node {node.id}: {error_msg}") @@ -256,7 +264,7 @@ class NeptuneGraphDB(GraphDBInterface): } result = await self.query(query, params) - processed_count = result[0].get('nodes_processed', 0) if result else 0 + processed_count = result[0].get("nodes_processed", 0) if result else 0 logger.debug(f"Successfully processed {processed_count} nodes in bulk operation") except Exception as e: @@ -268,7 +276,9 @@ class NeptuneGraphDB(GraphDBInterface): try: await self.add_node(node) except Exception as node_error: - logger.error(f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}") + logger.error( + f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}" + ) continue async def delete_node(self, node_id: str) -> None: @@ -287,13 +297,11 @@ class NeptuneGraphDB(GraphDBInterface): DETACH DELETE n """ - params = { - "node_id": node_id - } - + params = {"node_id": node_id} + await self.query(query, params) logger.debug(f"Successfully deleted node: {node_id}") - + except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Failed to delete node {node_id}: {error_msg}") @@ -333,7 +341,9 @@ class NeptuneGraphDB(GraphDBInterface): try: await self.delete_node(node_id) except Exception as node_error: - logger.error(f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}") + logger.error( + f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}" + ) continue async def get_node(self, node_id: str) -> 
Optional[NodeData]: @@ -355,10 +365,10 @@ class NeptuneGraphDB(GraphDBInterface): WHERE id(n) = $node_id RETURN n """ - params = {'node_id': node_id} + params = {"node_id": node_id} result = await self.query(query, params) - + if result and len(result) == 1: # Extract node properties from the result logger.debug(f"Successfully retrieved node: {node_id}") @@ -369,7 +379,7 @@ class NeptuneGraphDB(GraphDBInterface): elif len(result) > 1: logger.debug(f"Only one node expected, multiple returned: {node_id}") return None - + except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Failed to get node {node_id}: {error_msg}") @@ -406,7 +416,9 @@ class NeptuneGraphDB(GraphDBInterface): # Extract node data from results nodes = [record["n"] for record in result] - logger.debug(f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested") + logger.debug( + f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested" + ) return nodes except Exception as e: @@ -421,11 +433,12 @@ class NeptuneGraphDB(GraphDBInterface): if node_data: nodes.append(node_data) except Exception as node_error: - logger.error(f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}") + logger.error( + f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}" + ) continue return nodes - async def extract_node(self, node_id: str): """ Retrieve a single node based on its ID. @@ -512,8 +525,10 @@ class NeptuneGraphDB(GraphDBInterface): "properties": serialized_properties, } await self.query(query, params) - logger.debug(f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}") - + logger.debug( + f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}" + ) + except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}") @@ -557,20 +572,24 @@ class NeptuneGraphDB(GraphDBInterface): """ # Prepare edges data for bulk operation - params = {"edges": - [ + params = { + "edges": [ { "from_node": str(edge[0]), "to_node": str(edge[1]), "relationship_name": relationship_name, - "properties": self._serialize_properties(edge[3] if len(edge) > 3 and edge[3] else {}), + "properties": self._serialize_properties( + edge[3] if len(edge) > 3 and edge[3] else {} + ), } for edge in edges_for_relationship ] } results[relationship_name] = await self.query(query, params) except Exception as e: - logger.error(f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}") + logger.error( + f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}" + ) logger.info("Falling back to individual edge creation") for edge in edges_by_relationship: try: @@ -578,15 +597,16 @@ class NeptuneGraphDB(GraphDBInterface): properties = edge[3] if len(edge) > 3 else {} await self.add_edge(source_id, target_id, relationship_name, properties) except Exception as edge_error: - logger.error(f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}") + logger.error( + f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}" + ) continue processed_count = 0 for result in results.values(): - processed_count += result[0].get('edges_processed', 0) if result else 0 + processed_count += result[0].get("edges_processed", 0) if result else 0 logger.debug(f"Successfully processed {processed_count} edges in bulk operation") - async def delete_graph(self) -> None: """ 
Delete all nodes and edges from the graph database. @@ -599,14 +619,12 @@ class NeptuneGraphDB(GraphDBInterface): # Build openCypher query to delete the graph query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n" await self.query(query) - logger.debug(f"Successfully deleted all edges and nodes from the graph") except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Failed to delete graph: {error_msg}") raise Exception(f"Failed to delete graph: {error_msg}") from e - async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]: """ Retrieve all nodes and edges within the graph. @@ -633,13 +651,7 @@ class NeptuneGraphDB(GraphDBInterface): edges_result = await self.query(edges_query) # Format nodes as (node_id, properties) tuples - nodes = [ - ( - result["node_id"], - result["properties"] - ) - for result in nodes_result - ] + nodes = [(result["node_id"], result["properties"]) for result in nodes_result] # Format edges as (source_id, target_id, relationship_name, properties) tuples edges = [ @@ -647,7 +659,7 @@ class NeptuneGraphDB(GraphDBInterface): result["source_id"], result["target_id"], result["relationship_name"], - result["properties"] + result["properties"], ) for result in edges_result ] @@ -679,9 +691,11 @@ class NeptuneGraphDB(GraphDBInterface): "num_nodes": num_nodes, "num_edges": num_edges, "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None, - "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1)) if num_nodes != 0 else None, + "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1)) + if num_nodes != 0 + else None, "num_connected_components": num_cluster, - "sizes_of_connected_components": list_clsuter_size + "sizes_of_connected_components": list_clsuter_size, } optional_metrics = { @@ -692,7 +706,7 @@ class NeptuneGraphDB(GraphDBInterface): } if include_optional: - optional_metrics['num_selfloops'] = await self._count_self_loops() + optional_metrics["num_selfloops"] = await self._count_self_loops() # Unsupported due to long-running queries when computing the shortest path for each node in the graph: # optional_metrics['diameter'] # optional_metrics['avg_shortest_path_length'] @@ -728,17 +742,19 @@ class NeptuneGraphDB(GraphDBInterface): "source_id": source_id, "target_id": target_id, } - + result = await self.query(query, params) - + if result and len(result) > 0: - edge_exists = result.pop().get('edge_exists', False) - logger.debug(f"Edge existence check for " - f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}") + edge_exists = result.pop().get("edge_exists", False) + logger.debug( + f"Edge existence check for " + f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}" + ) return edge_exists else: return False - + except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Failed to check edge existence {source_id} -> {target_id}: {error_msg}") @@ -778,7 +794,7 @@ class NeptuneGraphDB(GraphDBInterface): results = await self.query(query, params) logger.debug(f"Found {len(results)} existing edges out of {len(edges)} checked") return [result["edge_exists"] for result in results] - + except Exception as e: error_msg = format_neptune_error(e) logger.error(f"Failed to check edges existence: {error_msg}") @@ -863,10 +879,7 @@ class NeptuneGraphDB(GraphDBInterface): RETURN predecessor """ - results = await self.query( - query, - {"node_id": node_id} - ) + results = await self.query(query, {"node_id": node_id}) return [result["predecessor"] for result in results] @@ -893,14 
+906,10 @@ class NeptuneGraphDB(GraphDBInterface): RETURN successor """ - results = await self.query( - query, - {"node_id": node_id} - ) + results = await self.query(query, {"node_id": node_id}) return [result["successor"] for result in results] - async def get_neighbors(self, node_id: str) -> List[NodeData]: """ Get all neighboring nodes connected to the specified node. @@ -926,11 +935,7 @@ class NeptuneGraphDB(GraphDBInterface): # Format neighbors as NodeData objects neighbors = [ - { - "id": neighbor["neighbor_id"], - **neighbor["properties"] - } - for neighbor in result + {"id": neighbor["neighbor_id"], **neighbor["properties"]} for neighbor in result ] logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}") @@ -942,9 +947,7 @@ class NeptuneGraphDB(GraphDBInterface): raise Exception(f"Failed to get neighbors: {error_msg}") from e async def get_nodeset_subgraph( - self, - node_type: Type[Any], - node_name: List[str] + self, node_type: Type[Any], node_name: List[str] ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: """ Fetch a subgraph consisting of a specific set of nodes and their relationships. @@ -987,10 +990,7 @@ class NeptuneGraphDB(GraphDBInterface): }}] AS rawRels """ - params = { - "names": node_name, - "type": node_type.__name__ - } + params = {"names": node_name, "type": node_type.__name__} result = await self.query(query, params) @@ -1002,18 +1002,14 @@ class NeptuneGraphDB(GraphDBInterface): raw_rels = result[0]["rawRels"] # Format nodes as (node_id, properties) tuples - nodes = [ - (n["id"], n["properties"]) - for n in raw_nodes - ] + nodes = [(n["id"], n["properties"]) for n in raw_nodes] # Format edges as (source_id, target_id, relationship_name, properties) tuples - edges = [ - (r["source_id"], r["target_id"], r["type"], r["properties"]) - for r in raw_rels - ] + edges = [(r["source_id"], r["target_id"], r["type"], r["properties"]) for r in raw_rels] - logger.debug(f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}") + logger.debug( + f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}" + ) return (nodes, edges) except Exception as e: @@ -1055,18 +1051,12 @@ class NeptuneGraphDB(GraphDBInterface): # Return as (source_node, relationship, target_node) connections.append( ( - { - "id": record["source_id"], - **record["source_props"] - }, + {"id": record["source_id"], **record["source_props"]}, { "relationship_name": record["relationship_name"], - **record["relationship_props"] + **record["relationship_props"], }, - { - "id": record["target_id"], - **record["target_props"] - } + {"id": record["target_id"], **record["target_props"]}, ) ) @@ -1078,10 +1068,7 @@ class NeptuneGraphDB(GraphDBInterface): logger.error(f"Failed to get connections for node {node_id}: {error_msg}") raise Exception(f"Failed to get connections: {error_msg}") from e - - async def remove_connection_to_predecessors_of( - self, node_ids: list[str], edge_label: str - ): + async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str): """ Remove connections (edges) to all predecessors of specified nodes based on edge label. 
@@ -1101,10 +1088,7 @@ class NeptuneGraphDB(GraphDBInterface): params = {"node_ids": node_ids} await self.query(query, params) - - async def remove_connection_to_successors_of( - self, node_ids: list[str], edge_label: str - ): + async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str): """ Remove connections (edges) to all successors of specified nodes based on edge label. @@ -1124,7 +1108,6 @@ class NeptuneGraphDB(GraphDBInterface): params = {"node_ids": node_ids} await self.query(query, params) - async def get_node_labels_string(self): """ Fetch all node labels from the database and return them as a formatted string. @@ -1138,7 +1121,9 @@ class NeptuneGraphDB(GraphDBInterface): ------- ValueError: If no node labels are found in the database. """ - node_labels_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels " + node_labels_query = ( + "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels " + ) node_labels_result = await self.query(node_labels_query) node_labels = node_labels_result[0]["labels"] if node_labels_result else [] @@ -1156,7 +1141,9 @@ class NeptuneGraphDB(GraphDBInterface): A formatted string of relationship types. """ - relationship_types_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships " + relationship_types_query = ( + "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships " + ) relationship_types_result = await self.query(relationship_types_query) relationship_types = ( relationship_types_result[0]["relationships"] if relationship_types_result else [] @@ -1172,7 +1159,6 @@ class NeptuneGraphDB(GraphDBInterface): ) return relationship_types_undirected_str - async def drop_graph(self, graph_name="myGraph"): """ Drop an existing graph from the database based on its name. @@ -1208,7 +1194,6 @@ class NeptuneGraphDB(GraphDBInterface): """ pass - async def project_entire_graph(self, graph_name="myGraph"): """ Project all node labels and relationship types into an in-memory graph using GDS. @@ -1280,7 +1265,6 @@ class NeptuneGraphDB(GraphDBInterface): return (nodes, edges) - async def get_degree_one_nodes(self, node_type: str): """ Fetch nodes of a specified type that have exactly one connection. 
@@ -1361,8 +1345,6 @@ class NeptuneGraphDB(GraphDBInterface): result = await self.query(query, {"content_hash": content_hash}) return result[0] if result else None - - async def _get_model_independent_graph_data(self): """ Retrieve the basic graph data without considering the model specifics, returning nodes @@ -1380,8 +1362,8 @@ class NeptuneGraphDB(GraphDBInterface): RETURN nodeCount AS numVertices, count(r) AS numEdges """ query_response = await self.query(query_string) - num_nodes = query_response[0].get('numVertices') - num_edges = query_response[0].get('numEdges') + num_nodes = query_response[0].get("numVertices") + num_edges = query_response[0].get("numEdges") return (num_nodes, num_edges) @@ -1437,4 +1419,9 @@ class NeptuneGraphDB(GraphDBInterface): @staticmethod def _convert_relationship_to_edge(relationship: dict) -> EdgeData: - return relationship["source_id"], relationship["target_id"], relationship["relationship_name"], relationship["properties"] + return ( + relationship["source_id"], + relationship["target_id"], + relationship["relationship_name"], + relationship["properties"], + ) diff --git a/cognee/infrastructure/databases/graph/neptune_driver/exceptions.py b/cognee/infrastructure/databases/graph/neptune_driver/exceptions.py index c5b5ec972..57d54d74d 100644 --- a/cognee/infrastructure/databases/graph/neptune_driver/exceptions.py +++ b/cognee/infrastructure/databases/graph/neptune_driver/exceptions.py @@ -2,106 +2,114 @@ This module defines custom exceptions for Neptune Analytics operations. """ + from cognee.exceptions import CogneeApiError from fastapi import status class NeptuneAnalyticsError(CogneeApiError): """Base exception for Neptune Analytics operations.""" + def __init__( self, message: str = "Neptune Analytics error.", name: str = "NeptuneAnalyticsError", - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ): super().__init__(message, name, status_code) - class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError): """Exception raised when connection to Neptune Analytics fails.""" + def __init__( self, message: str = "Unable to connect to Neptune Analytics. Please check the endpoint and network connectivity.", name: str = "NeptuneAnalyticsConnectionError", - status_code=status.HTTP_404_NOT_FOUND + status_code=status.HTTP_404_NOT_FOUND, ): super().__init__(message, name, status_code) class NeptuneAnalyticsQueryError(NeptuneAnalyticsError): """Exception raised when a query execution fails.""" + def __init__( self, message: str = "The query execution failed due to invalid syntax or semantic issues.", name: str = "NeptuneAnalyticsQueryError", - status_code=status.HTTP_400_BAD_REQUEST + status_code=status.HTTP_400_BAD_REQUEST, ): super().__init__(message, name, status_code) class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError): """Exception raised when authentication with Neptune Analytics fails.""" + def __init__( self, message: str = "Authentication with Neptune Analytics failed. Please verify your credentials.", name: str = "NeptuneAnalyticsAuthenticationError", - status_code=status.HTTP_401_UNAUTHORIZED + status_code=status.HTTP_401_UNAUTHORIZED, ): super().__init__(message, name, status_code) class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError): """Exception raised when Neptune Analytics configuration is invalid.""" + def __init__( self, message: str = "Neptune Analytics configuration is invalid or incomplete. 
Please review your setup.", name: str = "NeptuneAnalyticsConfigurationError", - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ): super().__init__(message, name, status_code) class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError): """Exception raised when a Neptune Analytics operation times out.""" + def __init__( self, message: str = "The operation timed out while communicating with Neptune Analytics.", name: str = "NeptuneAnalyticsTimeoutError", - status_code=status.HTTP_504_GATEWAY_TIMEOUT + status_code=status.HTTP_504_GATEWAY_TIMEOUT, ): super().__init__(message, name, status_code) class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError): """Exception raised when requests are throttled by Neptune Analytics.""" + def __init__( self, message: str = "Request was throttled by Neptune Analytics due to exceeding rate limits.", name: str = "NeptuneAnalyticsThrottlingError", - status_code=status.HTTP_429_TOO_MANY_REQUESTS + status_code=status.HTTP_429_TOO_MANY_REQUESTS, ): super().__init__(message, name, status_code) class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError): """Exception raised when a Neptune Analytics resource is not found.""" + def __init__( self, message: str = "The requested Neptune Analytics resource could not be found.", name: str = "NeptuneAnalyticsResourceNotFoundError", - status_code=status.HTTP_404_NOT_FOUND + status_code=status.HTTP_404_NOT_FOUND, ): super().__init__(message, name, status_code) class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError): """Exception raised when invalid parameters are provided to Neptune Analytics.""" + def __init__( self, message: str = "One or more parameters provided to Neptune Analytics are invalid or missing.", name: str = "NeptuneAnalyticsInvalidParameterError", - status_code=status.HTTP_400_BAD_REQUEST + status_code=status.HTTP_400_BAD_REQUEST, ): super().__init__(message, name, status_code) - diff --git a/cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py b/cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py index b70f2b1fa..0f71bd4e9 100644 --- a/cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +++ b/cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py @@ -16,40 +16,42 @@ logger = get_logger("NeptuneUtils") def parse_neptune_url(url: str) -> Tuple[str, str]: """ Parse a Neptune Analytics URL to extract graph ID and region. - + Expected format: neptune-graph://?region= or neptune-graph:// (defaults to us-east-1) - + Parameters: ----------- - url (str): The Neptune Analytics URL to parse - + Returns: -------- - Tuple[str, str]: A tuple containing (graph_id, region) - + Raises: ------- - ValueError: If the URL format is invalid """ try: parsed = urlparse(url) - + if parsed.scheme != "neptune-graph": raise ValueError(f"Invalid scheme: {parsed.scheme}. 
Expected 'neptune-graph'") - - graph_id = parsed.hostname or parsed.path.lstrip('/') + + graph_id = parsed.hostname or parsed.path.lstrip("/") if not graph_id: raise ValueError("Graph ID not found in URL") - + # Extract region from query parameters region = "us-east-1" # default region if parsed.query: - query_params = dict(param.split('=') for param in parsed.query.split('&') if '=' in param) - region = query_params.get('region', region) - + query_params = dict( + param.split("=") for param in parsed.query.split("&") if "=" in param + ) + region = query_params.get("region", region) + return graph_id, region - + except Exception as e: raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}") @@ -57,43 +59,43 @@ def parse_neptune_url(url: str) -> Tuple[str, str]: def validate_graph_id(graph_id: str) -> bool: """ Validate a Neptune Analytics graph ID format. - + Graph IDs should follow AWS naming conventions. - + Parameters: ----------- - graph_id (str): The graph ID to validate - + Returns: -------- - bool: True if the graph ID is valid, False otherwise """ if not graph_id: return False - + # Neptune Analytics graph IDs should be alphanumeric with hyphens # and between 1-63 characters - pattern = r'^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$' + pattern = r"^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$" return bool(re.match(pattern, graph_id)) def validate_aws_region(region: str) -> bool: """ Validate an AWS region format. - + Parameters: ----------- - region (str): The AWS region to validate - + Returns: -------- - bool: True if the region format is valid, False otherwise """ if not region: return False - + # AWS regions follow the pattern: us-east-1, eu-west-1, etc. - pattern = r'^[a-z]{2,3}-[a-z]+-\d+$' + pattern = r"^[a-z]{2,3}-[a-z]+-\d+$" return bool(re.match(pattern, region)) @@ -103,11 +105,11 @@ def build_neptune_config( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, - **kwargs + **kwargs, ) -> Dict[str, Any]: """ Build a configuration dictionary for Neptune Analytics connection. - + Parameters: ----------- - graph_id (str): The Neptune Analytics graph identifier @@ -116,11 +118,11 @@ def build_neptune_config( - aws_secret_access_key (Optional[str]): AWS secret access key - aws_session_token (Optional[str]): AWS session token for temporary credentials - **kwargs: Additional configuration parameters - + Returns: -------- - Dict[str, Any]: Configuration dictionary for Neptune Analytics - + Raises: ------- - ValueError: If required parameters are invalid @@ -129,35 +131,35 @@ def build_neptune_config( "graph_id": graph_id, "service_name": "neptune-graph", } - + # Add AWS credentials if provided if region: config["region"] = region if aws_access_key_id: config["aws_access_key_id"] = aws_access_key_id - + if aws_secret_access_key: config["aws_secret_access_key"] = aws_secret_access_key - + if aws_session_token: config["aws_session_token"] = aws_session_token - + # Add any additional configuration config.update(kwargs) - + return config def get_neptune_endpoint_url(graph_id: str, region: str) -> str: """ Construct the Neptune Analytics endpoint URL for a given graph and region. 
- + Parameters: ----------- - graph_id (str): The Neptune Analytics graph identifier - region (str): AWS region where the graph is located - + Returns: -------- - str: The Neptune Analytics endpoint URL @@ -168,17 +170,17 @@ def get_neptune_endpoint_url(graph_id: str, region: str) -> str: def format_neptune_error(error: Exception) -> str: """ Format Neptune Analytics specific errors for better readability. - + Parameters: ----------- - error (Exception): The exception to format - + Returns: -------- - str: Formatted error message """ error_msg = str(error) - + # Common Neptune Analytics error patterns and their user-friendly messages error_mappings = { "AccessDenied": "Access denied. Please check your AWS credentials and permissions.", @@ -187,17 +189,18 @@ def format_neptune_error(error: Exception) -> str: "ThrottlingException": "Request was throttled. Please retry with exponential backoff.", "InternalServerError": "Internal server error occurred. Please try again later.", } - + for error_type, friendly_msg in error_mappings.items(): if error_type in error_msg: return f"{friendly_msg} Original error: {error_msg}" - + return error_msg + def get_default_query_timeout() -> int: """ Get the default query timeout for Neptune Analytics operations. - + Returns: -------- - int: Default timeout in seconds @@ -208,7 +211,7 @@ def get_default_query_timeout() -> int: def get_default_connection_config() -> Dict[str, Any]: """ Get default connection configuration for Neptune Analytics. - + Returns: -------- - Dict[str, Any]: Default connection configuration diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py index b48bae773..a04e6f09e 100644 --- a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +++ b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py @@ -17,6 +17,7 @@ from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredRes logger = get_logger("NeptuneAnalyticsAdapter") + class IndexSchema(DataPoint): """ Represents a schema for an index data point containing an ID and text. @@ -27,16 +28,19 @@ class IndexSchema(DataPoint): - metadata: A dictionary with default index fields for the schema, currently configured to include 'text'. """ + id: str text: str metadata: dict = {"index_fields": ["text"]} + NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://" + class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): """ Hybrid adapter that combines Neptune Analytics Vector and Graph functionality. - + This adapter extends NeptuneGraphDB and implements VectorDBInterface to provide a unified interface for working with Neptune Analytics as both a vector store and a graph database. @@ -48,14 +52,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): _TOPK_UPPER_BOUND = 10 def __init__( - self, - graph_id: str, - embedding_engine: Optional[EmbeddingEngine] = None, - region: Optional[str] = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - ): + self, + graph_id: str, + embedding_engine: Optional[EmbeddingEngine] = None, + region: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): """ Initialize the Neptune Analytics hybrid adapter. 
@@ -74,12 +78,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): region=region, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token + aws_session_token=aws_session_token, ) - + # Add vector-specific attributes self.embedding_engine = embedding_engine - logger.info(f"Initialized Neptune Analytics hybrid adapter for graph: \"{graph_id}\" in region: \"{self.region}\"") + logger.info( + f'Initialized Neptune Analytics hybrid adapter for graph: "{graph_id}" in region: "{self.region}"' + ) # VectorDBInterface methods implementation @@ -161,7 +167,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): # Fetch embeddings texts = [DataPoint.get_embeddable_data(t) for t in data_points] - data_vectors = (await self.embedding_engine.embed_text(texts)) + data_vectors = await self.embedding_engine.embed_text(texts) for index, data_point in enumerate(data_points): node_id = data_point.id @@ -172,23 +178,24 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): properties = self._serialize_properties(data_point.model_dump()) properties[self._COLLECTION_PREFIX] = collection_name params = dict( - node_id = str(node_id), - properties = properties, - embedding = data_vector, - collection_name = collection_name + node_id=str(node_id), + properties=properties, + embedding=data_vector, + collection_name=collection_name, ) # Compose the query and send query_string = ( - f"MERGE (n " - f":{self._VECTOR_NODE_LABEL} " - f" {{`~id`: $node_id}}) " - f"ON CREATE SET n = $properties, n.updated_at = timestamp() " - f"ON MATCH SET n += $properties, n.updated_at = timestamp() " - f"WITH n, $embedding AS embedding " - f"CALL neptune.algo.vectors.upsert(n, embedding) " - f"YIELD success " - f"RETURN success ") + f"MERGE (n " + f":{self._VECTOR_NODE_LABEL} " + f" {{`~id`: $node_id}}) " + f"ON CREATE SET n = $properties, n.updated_at = timestamp() " + f"ON MATCH SET n += $properties, n.updated_at = timestamp() " + f"WITH n, $embedding AS embedding " + f"CALL neptune.algo.vectors.upsert(n, embedding) " + f"YIELD success " + f"RETURN success " + ) try: self._client.query(query_string, params) @@ -208,10 +215,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): """ # Do the fetch for each node params = dict(node_ids=data_point_ids, collection_name=collection_name) - query_string = (f"MATCH( n :{self._VECTOR_NODE_LABEL}) " - f"WHERE id(n) in $node_ids AND " - f"n.{self._COLLECTION_PREFIX} = $collection_name " - f"RETURN n as payload ") + query_string = ( + f"MATCH( n :{self._VECTOR_NODE_LABEL}) " + f"WHERE id(n) in $node_ids AND " + f"n.{self._COLLECTION_PREFIX} = $collection_name " + f"RETURN n as payload " + ) try: result = self._client.query(query_string, params) @@ -259,7 +268,8 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND: logger.warning( "Provided limit (%s) is invalid (zero, negative, or exceeds maximum). 
" - "Defaulting to limit=10.", limit + "Defaulting to limit=10.", + limit, ) limit = self._TOPK_UPPER_BOUND @@ -272,7 +282,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): elif query_vector: embedding = query_vector else: - data_vectors = (await self.embedding_engine.embed_text([query_text])) + data_vectors = await self.embedding_engine.embed_text([query_text]) embedding = data_vectors[0] # Compose the parameters map @@ -305,9 +315,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): try: query_response = self._client.query(query_string, params) - return [self._get_scored_result( - item = item, with_score = True - ) for item in query_response] + return [self._get_scored_result(item=item, with_score=True) for item in query_response] except Exception as e: self._na_exception_handler(e, query_string) @@ -332,11 +340,13 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): self._validate_embedding_engine() # Convert text to embedding array in batch - data_vectors = (await self.embedding_engine.embed_text(query_texts)) - return await asyncio.gather(*[ - self.search(collection_name, None, vector, limit, with_vectors) - for vector in data_vectors - ]) + data_vectors = await self.embedding_engine.embed_text(query_texts) + return await asyncio.gather( + *[ + self.search(collection_name, None, vector, limit, with_vectors) + for vector in data_vectors + ] + ) async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): """ @@ -350,10 +360,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): - data_point_ids (list[str]): A list of IDs of the data points to delete. """ params = dict(node_ids=data_point_ids, collection_name=collection_name) - query_string = (f"MATCH (n :{self._VECTOR_NODE_LABEL}) " - f"WHERE id(n) IN $node_ids " - f"AND n.{self._COLLECTION_PREFIX} = $collection_name " - f"DETACH DELETE n") + query_string = ( + f"MATCH (n :{self._VECTOR_NODE_LABEL}) " + f"WHERE id(n) IN $node_ids " + f"AND n.{self._COLLECTION_PREFIX} = $collection_name " + f"DETACH DELETE n" + ) try: self._client.query(query_string, params) except Exception as e: @@ -370,7 +382,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): await self.create_collection(f"{index_name}_{index_property_name}") async def index_data_points( - self, index_name: str, index_property_name: str, data_points: list[DataPoint] + self, index_name: str, index_property_name: str, data_points: list[DataPoint] ): """ Indexes a list of data points into Neptune Analytics by creating them as nodes. @@ -402,29 +414,28 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): Remove obsolete or unnecessary data from the database. """ # Run actual truncate - self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) " - f"DETACH DELETE n") + self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n") pass @staticmethod - def _get_scored_result(item: dict, with_vector: bool = False, with_score: bool = False) -> ScoredResult: + def _get_scored_result( + item: dict, with_vector: bool = False, with_score: bool = False + ) -> ScoredResult: """ Util method to simplify the object creation of ScoredResult base on incoming NX payload response. 
""" return ScoredResult( - id=item.get('payload').get('~id'), - payload=item.get('payload').get('~properties'), - score=item.get('score') if with_score else 0, - vector=item.get('embedding') if with_vector else None + id=item.get("payload").get("~id"), + payload=item.get("payload").get("~properties"), + score=item.get("score") if with_score else 0, + vector=item.get("embedding") if with_vector else None, ) def _na_exception_handler(self, ex, query_string: str): """ Generic exception handler for NA langchain. """ - logger.error( - "Neptune Analytics query failed: %s | Query: [%s]", ex, query_string - ) + logger.error("Neptune Analytics query failed: %s | Query: [%s]", ex, query_string) raise ex def _validate_embedding_engine(self): @@ -433,4 +444,6 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): :raises: ValueError if this object does not have a valid embedding_engine """ if self.embedding_engine is None: - raise ValueError("Neptune Analytics requires an embedder defined to make vector operations") + raise ValueError( + "Neptune Analytics requires an embedder defined to make vector operations" + ) diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 7c335e6f7..1604b6a59 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -125,9 +125,15 @@ def create_vector_engine( if not vector_db_url: raise EnvironmentError("Missing Neptune endpoint.") - from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL + from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import ( + NeptuneAnalyticsAdapter, + NEPTUNE_ANALYTICS_ENDPOINT_URL, + ) + if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL): - raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}'") + raise ValueError( + f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}'" + ) graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "") diff --git a/cognee/tests/test_neptune_analytics_graph.py b/cognee/tests/test_neptune_analytics_graph.py index 396dca6e2..c74f3f657 100644 --- a/cognee/tests/test_neptune_analytics_graph.py +++ b/cognee/tests/test_neptune_analytics_graph.py @@ -8,7 +8,7 @@ from cognee.modules.data.processing.document_types import TextDocument # Set up Amazon credentials in .env file and get the values from environment variables load_dotenv() -graph_id = os.getenv('GRAPH_ID', "") +graph_id = os.getenv("GRAPH_ID", "") na_adapter = NeptuneGraphDB(graph_id) @@ -24,33 +24,33 @@ def setup(): # stored in Amazon S3. document = TextDocument( - name='text_test.txt', - raw_data_location='git/cognee/examples/database_examples/data_storage/data/text_test.txt', - external_metadata='{}', - mime_type='text/plain' + name="text_test.txt", + raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt", + external_metadata="{}", + mime_type="text/plain", ) document_chunk = DocumentChunk( text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. 
To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", chunk_size=187, chunk_index=0, - cut_type='paragraph_end', + cut_type="paragraph_end", is_part_of=document, ) - graph_database = EntityType(name='graph database', description='graph database') + graph_database = EntityType(name="graph database", description="graph database") neptune_analytics_entity = Entity( - name='neptune analytics', - description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.', + name="neptune analytics", + description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.", ) neptune_database_entity = Entity( - name='amazon neptune database', - description='A popular managed graph database that complements Neptune Analytics.', + name="amazon neptune database", + description="A popular managed graph database that complements Neptune Analytics.", ) - storage = EntityType(name='storage', description='storage') + storage = EntityType(name="storage", description="storage") storage_entity = Entity( - name='amazon s3', - description='A storage service provided by Amazon Web Services that allows storing graph data.', + name="amazon s3", + description="A storage service provided by Amazon Web Services that allows storing graph data.", ) nodes_data = [ @@ -67,37 +67,37 @@ def setup(): ( str(document_chunk.id), str(storage_entity.id), - 'contains', + "contains", ), ( str(storage_entity.id), str(storage.id), - 'is_a', + "is_a", ), ( str(document_chunk.id), str(neptune_database_entity.id), - 'contains', + "contains", ), ( str(neptune_database_entity.id), str(graph_database.id), - 'is_a', + "is_a", ), ( str(document_chunk.id), str(document.id), - 'is_part_of', + "is_part_of", ), ( str(document_chunk.id), str(neptune_analytics_entity.id), - 'contains', + "contains", ), ( str(neptune_analytics_entity.id), str(graph_database.id), - 'is_a', + "is_a", ), ] @@ -155,42 +155,44 @@ async def pipeline_method(): print("------NEIGHBORING NODES-------") center_node = nodes[2] neighbors = await na_adapter.get_neighbors(str(center_node.id)) - print(f"found {len(neighbors)} neighbors for node \"{center_node.name}\"") + print(f'found {len(neighbors)} neighbors for node "{center_node.name}"') for neighbor in neighbors: print(neighbor) print("------NEIGHBORING EDGES-------") center_node = nodes[2] neighbouring_edges = await na_adapter.get_edges(str(center_node.id)) - print(f"found {len(neighbouring_edges)} edges neighbouring node \"{center_node.name}\"") + print(f'found {len(neighbouring_edges)} edges neighbouring node "{center_node.name}"') for edge in neighbouring_edges: print(edge) print("------GET CONNECTIONS (SOURCE NODE)-------") document_chunk_node = nodes[0] connections = await na_adapter.get_connections(str(document_chunk_node.id)) - print(f"found {len(connections)} connections for node \"{document_chunk_node.type}\"") + print(f'found {len(connections)} connections for node "{document_chunk_node.type}"') for connection in connections: src, relationship, tgt = connection src = src.get("name", src.get("type", "unknown")) relationship = relationship["relationship_name"] tgt = tgt.get("name", tgt.get("type", "unknown")) - print(f"\"{src}\"-[{relationship}]->\"{tgt}\"") + print(f'"{src}"-[{relationship}]->"{tgt}"') print("------GET CONNECTIONS (TARGET NODE)-------") connections = await 
na_adapter.get_connections(str(center_node.id)) - print(f"found {len(connections)} connections for node \"{center_node.name}\"") + print(f'found {len(connections)} connections for node "{center_node.name}"') for connection in connections: src, relationship, tgt = connection src = src.get("name", src.get("type", "unknown")) relationship = relationship["relationship_name"] tgt = tgt.get("name", tgt.get("type", "unknown")) - print(f"\"{src}\"-[{relationship}]->\"{tgt}\"") + print(f'"{src}"-[{relationship}]->"{tgt}"') print("------SUBGRAPH-------") node_names = ["neptune analytics", "amazon neptune database"] subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names) - print(f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}") + print( + f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}" + ) for subgraph_node in subgraph_nodes: print(subgraph_node) for subgraph_edge in subgraph_edges: @@ -199,17 +201,17 @@ async def pipeline_method(): print("------STAT-------") stat = await na_adapter.get_graph_metrics(include_optional=True) assert type(stat) is dict - assert stat['num_nodes'] == 7 - assert stat['num_edges'] == 7 - assert stat['mean_degree'] == 2.0 - assert round(stat['edge_density'], 3) == 0.167 - assert stat['num_connected_components'] == [7] - assert stat['sizes_of_connected_components'] == 1 - assert stat['num_selfloops'] == 0 + assert stat["num_nodes"] == 7 + assert stat["num_edges"] == 7 + assert stat["mean_degree"] == 2.0 + assert round(stat["edge_density"], 3) == 0.167 + assert stat["num_connected_components"] == [7] + assert stat["sizes_of_connected_components"] == 1 + assert stat["num_selfloops"] == 0 # Unsupported optional metrics - assert stat['diameter'] == -1 - assert stat['avg_shortest_path_length'] == -1 - assert stat['avg_clustering'] == -1 + assert stat["diameter"] == -1 + assert stat["avg_shortest_path_length"] == -1 + assert stat["avg_clustering"] == -1 print("------DELETE-------") # delete all nodes and edges: @@ -253,7 +255,9 @@ async def misc_methods(): print(edge_labels) print("------Get Filtered Graph-------") - filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data([{'name': ['text_test.txt']}]) + filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data( + [{"name": ["text_test.txt"]}] + ) print(filtered_nodes, filtered_edges) print("------Get Degree one nodes-------") @@ -261,15 +265,13 @@ async def misc_methods(): print(degree_one_nodes) print("------Get Doc sub-graph-------") - doc_sub_graph = await na_adapter.get_document_subgraph('test.txt') + doc_sub_graph = await na_adapter.get_document_subgraph("test.txt") print(doc_sub_graph) print("------Fetch and Remove connections (Predecessors)-------") # Fetch test edge (src_id, dest_id, relationship) = edges[0] - nodes_predecessors = await na_adapter.get_predecessors( - node_id=dest_id, edge_label=relationship - ) + nodes_predecessors = await na_adapter.get_predecessors(node_id=dest_id, edge_label=relationship) assert len(nodes_predecessors) > 0 await na_adapter.remove_connection_to_predecessors_of( @@ -281,25 +283,19 @@ async def misc_methods(): # Return empty after relationship being deleted. 
assert len(nodes_predecessors_after) == 0 - print("------Fetch and Remove connections (Successors)-------") _, edges_suc = await na_adapter.get_graph_data() (src_id, dest_id, relationship, _) = edges_suc[0] - nodes_successors = await na_adapter.get_successors( - node_id=src_id, edge_label=relationship - ) + nodes_successors = await na_adapter.get_successors(node_id=src_id, edge_label=relationship) assert len(nodes_successors) > 0 - await na_adapter.remove_connection_to_successors_of( - node_ids=[dest_id], edge_label=relationship - ) + await na_adapter.remove_connection_to_successors_of(node_ids=[dest_id], edge_label=relationship) nodes_successors_after = await na_adapter.get_successors( node_id=src_id, edge_label=relationship ) assert len(nodes_successors_after) == 0 - # no-op await na_adapter.project_entire_graph() await na_adapter.drop_graph() diff --git a/cognee/tests/test_neptune_analytics_hybrid.py b/cognee/tests/test_neptune_analytics_hybrid.py index 352d20fba..5999acace 100644 --- a/cognee/tests/test_neptune_analytics_hybrid.py +++ b/cognee/tests/test_neptune_analytics_hybrid.py @@ -8,11 +8,13 @@ from cognee.modules.engine.models import Entity, EntityType from cognee.modules.data.processing.document_types import TextDocument from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from cognee.shared.logging_utils import get_logger -from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter +from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import ( + NeptuneAnalyticsAdapter, +) # Set up Amazon credentials in .env file and get the values from environment variables load_dotenv() -graph_id = os.getenv('GRAPH_ID', "") +graph_id = os.getenv("GRAPH_ID", "") # get the default embedder embedding_engine = get_embedding_engine() @@ -23,6 +25,7 @@ collection = "test_collection" logger = get_logger("test_neptune_analytics_hybrid") + def setup_data(): # Define nodes data before the main function # These nodes were defined using openAI from the following prompt: @@ -34,35 +37,35 @@ def setup_data(): # stored in Amazon S3. document = TextDocument( - name='text.txt', - raw_data_location='git/cognee/examples/database_examples/data_storage/data/text.txt', - external_metadata='{}', - mime_type='text/plain' + name="text.txt", + raw_data_location="git/cognee/examples/database_examples/data_storage/data/text.txt", + external_metadata="{}", + mime_type="text/plain", ) document_chunk = DocumentChunk( text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. 
You can also load graph data that's \n stored in Amazon S3.\n ", chunk_size=187, chunk_index=0, - cut_type='paragraph_end', + cut_type="paragraph_end", is_part_of=document, ) - graph_database = EntityType(name='graph database', description='graph database') + graph_database = EntityType(name="graph database", description="graph database") neptune_analytics_entity = Entity( - name='neptune analytics', - description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.', + name="neptune analytics", + description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.", ) neptune_database_entity = Entity( - name='amazon neptune database', - description='A popular managed graph database that complements Neptune Analytics.', + name="amazon neptune database", + description="A popular managed graph database that complements Neptune Analytics.", ) - storage = EntityType(name='storage', description='storage') + storage = EntityType(name="storage", description="storage") storage_entity = Entity( - name='amazon s3', - description='A storage service provided by Amazon Web Services that allows storing graph data.', + name="amazon s3", + description="A storage service provided by Amazon Web Services that allows storing graph data.", ) - + nodes_data = [ document, document_chunk, @@ -77,41 +80,42 @@ def setup_data(): ( str(document_chunk.id), str(storage_entity.id), - 'contains', + "contains", ), ( str(storage_entity.id), str(storage.id), - 'is_a', + "is_a", ), ( str(document_chunk.id), str(neptune_database_entity.id), - 'contains', + "contains", ), ( str(neptune_database_entity.id), str(graph_database.id), - 'is_a', + "is_a", ), ( str(document_chunk.id), str(document.id), - 'is_part_of', + "is_part_of", ), ( str(document_chunk.id), str(neptune_analytics_entity.id), - 'contains', + "contains", ), ( str(neptune_analytics_entity.id), str(graph_database.id), - 'is_a', + "is_a", ), ] return nodes_data, edges_data + async def test_add_graph_then_vector_data(): logger.info("------test_add_graph_then_vector_data-------") (nodes, edges) = setup_data() @@ -134,6 +138,7 @@ async def test_add_graph_then_vector_data(): assert len(edges) == 0 logger.info("------PASSED-------") + async def test_add_vector_then_node_data(): logger.info("------test_add_vector_then_node_data-------") (nodes, edges) = setup_data() @@ -156,6 +161,7 @@ async def test_add_vector_then_node_data(): assert len(edges) == 0 logger.info("------PASSED-------") + def main(): """ Example script uses neptune analytics for the graph and vector (hybrid) store with small sample data @@ -165,5 +171,6 @@ def main(): asyncio.run(test_add_graph_then_vector_data()) asyncio.run(test_add_vector_then_node_data()) + if __name__ == "__main__": main() diff --git a/cognee/tests/test_neptune_analytics_vector.py b/cognee/tests/test_neptune_analytics_vector.py index e8b5790a5..eececacdd 100644 --- a/cognee/tests/test_neptune_analytics_vector.py +++ b/cognee/tests/test_neptune_analytics_vector.py @@ -8,13 +8,16 @@ from cognee.modules.users.methods import get_default_user from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, IndexSchema +from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import ( + 
NeptuneAnalyticsAdapter, + IndexSchema, +) logger = get_logger() async def main(): - graph_id = os.getenv('GRAPH_ID', "") + graph_id = os.getenv("GRAPH_ID", "") cognee.config.set_vector_db_provider("neptune_analytics") cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}") data_directory_path = str( @@ -87,6 +90,7 @@ async def main(): await cognee.prune.prune_system(metadata=True) + async def vector_backend_api_test(): cognee.config.set_vector_db_provider("neptune_analytics") @@ -101,7 +105,7 @@ async def vector_backend_api_test(): get_vector_engine() # Return a valid engine object with valid URL. - graph_id = os.getenv('GRAPH_ID', "") + graph_id = os.getenv("GRAPH_ID", "") cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}") engine = get_vector_engine() assert isinstance(engine, NeptuneAnalyticsAdapter) @@ -133,28 +137,22 @@ async def vector_backend_api_test(): query_text=TEST_TEXT, query_vector=None, limit=10, - with_vector=True) - assert (len(result_search) == 2) + with_vector=True, + ) + assert len(result_search) == 2 # # Retrieve data-points result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2]) - assert any( - str(r.id) == TEST_UUID and r.payload['text'] == TEST_TEXT - for r in result - ) - assert any( - str(r.id) == TEST_UUID_2 and r.payload['text'] == TEST_TEXT_2 - for r in result - ) + assert any(str(r.id) == TEST_UUID and r.payload["text"] == TEST_TEXT for r in result) + assert any(str(r.id) == TEST_UUID_2 and r.payload["text"] == TEST_TEXT_2 for r in result) # Search multiple result_search_batch = await engine.batch_search( collection_name=TEST_COLLECTION_NAME, query_texts=[TEST_TEXT, TEST_TEXT_2], limit=10, - with_vectors=False + with_vectors=False, ) - assert (len(result_search_batch) == 2 and - all(len(batch) == 2 for batch in result_search_batch)) + assert len(result_search_batch) == 2 and all(len(batch) == 2 for batch in result_search_batch) # Delete datapoint from vector store await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2]) @@ -163,6 +161,7 @@ async def vector_backend_api_test(): result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID]) assert result_deleted == [] + if __name__ == "__main__": import asyncio diff --git a/examples/database_examples/neptune_analytics_example.py b/examples/database_examples/neptune_analytics_example.py index acc2baedb..5d36e2803 100644 --- a/examples/database_examples/neptune_analytics_example.py +++ b/examples/database_examples/neptune_analytics_example.py @@ -9,6 +9,7 @@ from dotenv import load_dotenv load_dotenv() + async def main(): """ Example script demonstrating how to use Cognee with Amazon Neptune Analytics @@ -22,7 +23,7 @@ async def main(): """ # Set up Amazon credentials in .env file and get the values from environment variables - graph_endpoint_url = "neptune-graph://" + os.getenv('GRAPH_ID', "") + graph_endpoint_url = "neptune-graph://" + os.getenv("GRAPH_ID", "") # Configure Neptune Analytics as the graph & vector database provider cognee.config.set_graph_db_config( @@ -77,7 +78,9 @@ async def main(): # Now let's perform some searches # 1. 
Search for insights related to "Neptune Analytics" - insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neptune Analytics") + insights_results = await cognee.search( + query_type=SearchType.INSIGHTS, query_text="Neptune Analytics" + ) print("\n========Insights about Neptune Analytics========:") for result in insights_results: print(f"- {result}") diff --git a/examples/python/weighted_graph_visualization.html b/examples/python/weighted_graph_visualization.html index 2e7f67e31..12424fbeb 100644 --- a/examples/python/weighted_graph_visualization.html +++ b/examples/python/weighted_graph_visualization.html @@ -37,8 +37,8 @@
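Reviewer note, not part of the patch: the diff above is lint/format-only (Black-style reflow and quote normalization), so no behavior should change. A quick way to exercise the reformatted graph adapter end to end is a small smoke script in the style of the tests touched here. This is only a sketch under stated assumptions: it assumes GRAPH_ID names a provisioned Neptune Analytics graph, AWS credentials are resolvable from the environment, and the langchain_aws dependency is installed; the entity names and the "connects_to" relationship are illustrative and do not appear in this patch.

import asyncio
import os

from cognee.infrastructure.databases.graph.neptune_driver.adapter import NeptuneGraphDB
from cognee.modules.engine.models import Entity


async def smoke_test():
    # Assumes GRAPH_ID points at a reachable Neptune Analytics graph and that
    # AWS credentials come from the environment (.env, profile, or role).
    adapter = NeptuneGraphDB(os.getenv("GRAPH_ID", ""))

    # Two illustrative nodes; Entity is a DataPoint, which add_nodes accepts.
    source = Entity(name="node a", description="illustrative node")
    target = Entity(name="node b", description="illustrative node")
    await adapter.add_nodes([source, target])

    # One edge between them, exercising the reformatted add_edge path.
    await adapter.add_edge(str(source.id), str(target.id), "connects_to", {"weight": 1})

    # get_neighbors returns dicts shaped {"id": ..., **properties} per the adapter.
    neighbors = await adapter.get_neighbors(str(source.id))
    assert any(neighbor["id"] == str(target.id) for neighbor in neighbors)

    nodes, edges = await adapter.get_graph_data()
    print(f"graph now holds {len(nodes)} nodes and {len(edges)} edges")

    # delete_graph removes every COGNEE_NODE vertex, so only run this
    # against a disposable graph.
    await adapter.delete_graph()


if __name__ == "__main__":
    asyncio.run(smoke_test())

Running the script before and after applying this patch should produce identical output, which is a cheap confirmation that the reformatting introduced no functional drift.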