added linting and formatting

vasilije 2025-08-02 16:58:18 +02:00
parent ae8fc7f0c9
commit c0d60eef25
11 changed files with 355 additions and 326 deletions

View file

@@ -149,7 +149,9 @@ def create_graph_engine(
from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL
if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
raise ValueError(f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>")
raise ValueError(
f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
)
graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")
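
Both the graph and vector factories in this commit use the same validate-then-strip pattern: reject any URL that does not begin with the scheme prefix, then remove the prefix to obtain the graph identifier. A minimal standalone sketch of that pattern (the constant value comes from the adapter below; the helper name is hypothetical):

NEPTUNE_ENDPOINT_URL = "neptune-graph://"

def extract_graph_identifier(graph_database_url: str) -> str:
    # Reject URLs that do not use the expected scheme prefix.
    if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
        raise ValueError(
            f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
        )
    # Everything after the prefix is the graph identifier, e.g. "g-abc123".
    return graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")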
@@ -174,10 +176,15 @@ def create_graph_engine(
if not graph_database_url:
raise EnvironmentError("Missing Neptune endpoint.")
from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL
from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
NEPTUNE_ANALYTICS_ENDPOINT_URL,
)
if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")
raise ValueError(
f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
)
graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

View file

@@ -29,6 +29,7 @@ logger = get_logger("NeptuneGraphDB")
try:
from langchain_aws import NeptuneAnalyticsGraph
LANGCHAIN_AWS_AVAILABLE = True
except ImportError:
logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.")
@@ -36,11 +37,13 @@ except ImportError:
NEPTUNE_ENDPOINT_URL = "neptune-graph://"
class NeptuneGraphDB(GraphDBInterface):
"""
Adapter for interacting with Amazon Neptune Analytics graph store.
This class provides methods for querying, adding, and deleting nodes and edges using the langchain_aws library.
"""
_GRAPH_NODE_LABEL = "COGNEE_NODE"
def __init__(
@@ -61,28 +64,30 @@ class NeptuneGraphDB(GraphDBInterface):
- aws_access_key_id (Optional[str]): AWS access key ID
- aws_secret_access_key (Optional[str]): AWS secret access key
- aws_session_token (Optional[str]): AWS session token for temporary credentials
Raises:
-------
- NeptuneAnalyticsConfigurationError: If configuration parameters are invalid
"""
# validate import
if not LANGCHAIN_AWS_AVAILABLE:
raise ImportError("langchain_aws is not available. Please install it to use Neptune Analytics.")
raise ImportError(
"langchain_aws is not available. Please install it to use Neptune Analytics."
)
# Validate configuration
if not validate_graph_id(graph_id):
raise NeptuneAnalyticsConfigurationError(message=f"Invalid graph ID: \"{graph_id}\"")
raise NeptuneAnalyticsConfigurationError(message=f'Invalid graph ID: "{graph_id}"')
if region and not validate_aws_region(region):
raise NeptuneAnalyticsConfigurationError(message=f"Invalid AWS region: \"{region}\"")
raise NeptuneAnalyticsConfigurationError(message=f'Invalid AWS region: "{region}"')
self.graph_id = graph_id
self.region = region
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
# Build configuration
self.config = build_neptune_config(
graph_id=self.graph_id,
@@ -91,15 +96,17 @@ class NeptuneGraphDB(GraphDBInterface):
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
)
# Initialize Neptune Analytics client using langchain_aws
self._client: NeptuneAnalyticsGraph = self._initialize_client()
logger.info(f"Initialized Neptune Analytics adapter for graph: \"{graph_id}\" in region: \"{self.region}\"")
logger.info(
f'Initialized Neptune Analytics adapter for graph: "{graph_id}" in region: "{self.region}"'
)
def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]:
"""
Initialize the Neptune Analytics client using langchain_aws.
Returns:
--------
- Optional[Any]: The Neptune Analytics client or None if not available
@@ -108,9 +115,7 @@ class NeptuneGraphDB(GraphDBInterface):
# Initialize the Neptune Analytics Graph client
client_config = {
"graph_identifier": self.graph_id,
"config": Config(
user_agent_appid='Cognee'
)
"config": Config(user_agent_appid="Cognee"),
}
# Add AWS credentials if provided
if self.region:
@@ -121,13 +126,15 @@ class NeptuneGraphDB(GraphDBInterface):
client_config["aws_secret_access_key"] = self.aws_secret_access_key
if self.aws_session_token:
client_config["aws_session_token"] = self.aws_session_token
client = NeptuneAnalyticsGraph(**client_config)
logger.info("Successfully initialized Neptune Analytics client")
return client
except Exception as e:
raise NeptuneAnalyticsConfigurationError(message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}") from e
raise NeptuneAnalyticsConfigurationError(
message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}"
) from e
@staticmethod
def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]:
@@ -175,7 +182,7 @@ class NeptuneGraphDB(GraphDBInterface):
params = {}
logger.debug(f"executing na query:\nquery={query}\n")
result = self._client.query(query, params)
# Convert the result to list format expected by the interface
if isinstance(result, list):
return result
@@ -183,7 +190,7 @@ class NeptuneGraphDB(GraphDBInterface):
return [result]
else:
return [{"result": result}]
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Neptune Analytics query failed: {error_msg}")
@@ -212,10 +219,11 @@ class NeptuneGraphDB(GraphDBInterface):
"node_id": str(node.id),
"properties": serialized_properties,
}
result = await self.query(query, params)
logger.debug(f"Successfully added/updated node: {node.id}")
logger.debug(f"Successfully gotten: {str(result)}")
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to add node {node.id}: {error_msg}")
@@ -256,7 +264,7 @@ class NeptuneGraphDB(GraphDBInterface):
}
result = await self.query(query, params)
processed_count = result[0].get('nodes_processed', 0) if result else 0
processed_count = result[0].get("nodes_processed", 0) if result else 0
logger.debug(f"Successfully processed {processed_count} nodes in bulk operation")
except Exception as e:
@@ -268,7 +276,9 @@ class NeptuneGraphDB(GraphDBInterface):
try:
await self.add_node(node)
except Exception as node_error:
logger.error(f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}")
logger.error(
f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}"
)
continue
async def delete_node(self, node_id: str) -> None:
@@ -287,13 +297,11 @@ class NeptuneGraphDB(GraphDBInterface):
DETACH DELETE n
"""
params = {
"node_id": node_id
}
params = {"node_id": node_id}
await self.query(query, params)
logger.debug(f"Successfully deleted node: {node_id}")
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to delete node {node_id}: {error_msg}")
@@ -333,7 +341,9 @@ class NeptuneGraphDB(GraphDBInterface):
try:
await self.delete_node(node_id)
except Exception as node_error:
logger.error(f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}")
logger.error(
f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}"
)
continue
async def get_node(self, node_id: str) -> Optional[NodeData]:
@@ -355,10 +365,10 @@ class NeptuneGraphDB(GraphDBInterface):
WHERE id(n) = $node_id
RETURN n
"""
params = {'node_id': node_id}
params = {"node_id": node_id}
result = await self.query(query, params)
if result and len(result) == 1:
# Extract node properties from the result
logger.debug(f"Successfully retrieved node: {node_id}")
@@ -369,7 +379,7 @@ class NeptuneGraphDB(GraphDBInterface):
elif len(result) > 1:
logger.debug(f"Only one node expected, multiple returned: {node_id}")
return None
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to get node {node_id}: {error_msg}")
@@ -406,7 +416,9 @@ class NeptuneGraphDB(GraphDBInterface):
# Extract node data from results
nodes = [record["n"] for record in result]
logger.debug(f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested")
logger.debug(
f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested"
)
return nodes
except Exception as e:
@@ -421,11 +433,12 @@ class NeptuneGraphDB(GraphDBInterface):
if node_data:
nodes.append(node_data)
except Exception as node_error:
logger.error(f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}")
logger.error(
f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}"
)
continue
return nodes
async def extract_node(self, node_id: str):
"""
Retrieve a single node based on its ID.
@@ -512,8 +525,10 @@ class NeptuneGraphDB(GraphDBInterface):
"properties": serialized_properties,
}
await self.query(query, params)
logger.debug(f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}")
logger.debug(
f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}"
)
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}")
@@ -557,20 +572,24 @@ class NeptuneGraphDB(GraphDBInterface):
"""
# Prepare edges data for bulk operation
params = {"edges":
[
params = {
"edges": [
{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": relationship_name,
"properties": self._serialize_properties(edge[3] if len(edge) > 3 and edge[3] else {}),
"properties": self._serialize_properties(
edge[3] if len(edge) > 3 and edge[3] else {}
),
}
for edge in edges_for_relationship
]
}
results[relationship_name] = await self.query(query, params)
except Exception as e:
logger.error(f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}")
logger.error(
f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}"
)
logger.info("Falling back to individual edge creation")
for edge in edges_by_relationship:
try:
@@ -578,15 +597,16 @@ class NeptuneGraphDB(GraphDBInterface):
properties = edge[3] if len(edge) > 3 else {}
await self.add_edge(source_id, target_id, relationship_name, properties)
except Exception as edge_error:
logger.error(f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}")
logger.error(
f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}"
)
continue
processed_count = 0
for result in results.values():
processed_count += result[0].get('edges_processed', 0) if result else 0
processed_count += result[0].get("edges_processed", 0) if result else 0
logger.debug(f"Successfully processed {processed_count} edges in bulk operation")
async def delete_graph(self) -> None:
"""
Delete all nodes and edges from the graph database.
@@ -599,14 +619,12 @@ class NeptuneGraphDB(GraphDBInterface):
# Build openCypher query to delete the graph
query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n"
await self.query(query)
logger.debug(f"Successfully deleted all edges and nodes from the graph")
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to delete graph: {error_msg}")
raise Exception(f"Failed to delete graph: {error_msg}") from e
async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
"""
Retrieve all nodes and edges within the graph.
@@ -633,13 +651,7 @@ class NeptuneGraphDB(GraphDBInterface):
edges_result = await self.query(edges_query)
# Format nodes as (node_id, properties) tuples
nodes = [
(
result["node_id"],
result["properties"]
)
for result in nodes_result
]
nodes = [(result["node_id"], result["properties"]) for result in nodes_result]
# Format edges as (source_id, target_id, relationship_name, properties) tuples
edges = [
@@ -647,7 +659,7 @@ class NeptuneGraphDB(GraphDBInterface):
result["source_id"],
result["target_id"],
result["relationship_name"],
result["properties"]
result["properties"],
)
for result in edges_result
]
@@ -679,9 +691,11 @@ class NeptuneGraphDB(GraphDBInterface):
"num_nodes": num_nodes,
"num_edges": num_edges,
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
"edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1)) if num_nodes != 0 else None,
"edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1))
if num_nodes != 0
else None,
"num_connected_components": num_cluster,
"sizes_of_connected_components": list_clsuter_size
"sizes_of_connected_components": list_clsuter_size,
}
optional_metrics = {
@@ -692,7 +706,7 @@ class NeptuneGraphDB(GraphDBInterface):
}
if include_optional:
optional_metrics['num_selfloops'] = await self._count_self_loops()
optional_metrics["num_selfloops"] = await self._count_self_loops()
# Unsupported due to long-running queries when computing the shortest path for each node in the graph:
# optional_metrics['diameter']
# optional_metrics['avg_shortest_path_length']
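
The two formulas above can be sanity-checked against the example test later in this commit, which builds a graph with 7 nodes and 7 edges and asserts mean_degree == 2.0 and edge_density ≈ 0.167. A small sketch of the same arithmetic (the helper name is hypothetical; it mirrors the adapter's guards, including that only num_nodes == 0 is special-cased):

def basic_metrics(num_nodes: int, num_edges: int) -> dict:
    # Mean degree: 2E / N; density as computed here: E / (N * (N - 1)).
    return {
        "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
        "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1))
        if num_nodes != 0
        else None,
    }

assert basic_metrics(7, 7)["mean_degree"] == 2.0  # 2 * 7 / 7
assert round(basic_metrics(7, 7)["edge_density"], 3) == 0.167  # 7 / 42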
@@ -728,17 +742,19 @@ class NeptuneGraphDB(GraphDBInterface):
"source_id": source_id,
"target_id": target_id,
}
result = await self.query(query, params)
if result and len(result) > 0:
edge_exists = result.pop().get('edge_exists', False)
logger.debug(f"Edge existence check for "
f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}")
edge_exists = result.pop().get("edge_exists", False)
logger.debug(
f"Edge existence check for "
f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}"
)
return edge_exists
else:
return False
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to check edge existence {source_id} -> {target_id}: {error_msg}")
@@ -778,7 +794,7 @@ class NeptuneGraphDB(GraphDBInterface):
results = await self.query(query, params)
logger.debug(f"Found {len(results)} existing edges out of {len(edges)} checked")
return [result["edge_exists"] for result in results]
except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to check edges existence: {error_msg}")
@@ -863,10 +879,7 @@ class NeptuneGraphDB(GraphDBInterface):
RETURN predecessor
"""
results = await self.query(
query,
{"node_id": node_id}
)
results = await self.query(query, {"node_id": node_id})
return [result["predecessor"] for result in results]
@@ -893,14 +906,10 @@ class NeptuneGraphDB(GraphDBInterface):
RETURN successor
"""
results = await self.query(
query,
{"node_id": node_id}
)
results = await self.query(query, {"node_id": node_id})
return [result["successor"] for result in results]
async def get_neighbors(self, node_id: str) -> List[NodeData]:
"""
Get all neighboring nodes connected to the specified node.
@@ -926,11 +935,7 @@ class NeptuneGraphDB(GraphDBInterface):
# Format neighbors as NodeData objects
neighbors = [
{
"id": neighbor["neighbor_id"],
**neighbor["properties"]
}
for neighbor in result
{"id": neighbor["neighbor_id"], **neighbor["properties"]} for neighbor in result
]
logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}")
@@ -942,9 +947,7 @@ class NeptuneGraphDB(GraphDBInterface):
raise Exception(f"Failed to get neighbors: {error_msg}") from e
async def get_nodeset_subgraph(
self,
node_type: Type[Any],
node_name: List[str]
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
"""
Fetch a subgraph consisting of a specific set of nodes and their relationships.
@@ -987,10 +990,7 @@ class NeptuneGraphDB(GraphDBInterface):
}}] AS rawRels
"""
params = {
"names": node_name,
"type": node_type.__name__
}
params = {"names": node_name, "type": node_type.__name__}
result = await self.query(query, params)
@@ -1002,18 +1002,14 @@ class NeptuneGraphDB(GraphDBInterface):
raw_rels = result[0]["rawRels"]
# Format nodes as (node_id, properties) tuples
nodes = [
(n["id"], n["properties"])
for n in raw_nodes
]
nodes = [(n["id"], n["properties"]) for n in raw_nodes]
# Format edges as (source_id, target_id, relationship_name, properties) tuples
edges = [
(r["source_id"], r["target_id"], r["type"], r["properties"])
for r in raw_rels
]
edges = [(r["source_id"], r["target_id"], r["type"], r["properties"]) for r in raw_rels]
logger.debug(f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}")
logger.debug(
f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}"
)
return (nodes, edges)
except Exception as e:
@@ -1055,18 +1051,12 @@ class NeptuneGraphDB(GraphDBInterface):
# Return as (source_node, relationship, target_node)
connections.append(
(
{
"id": record["source_id"],
**record["source_props"]
},
{"id": record["source_id"], **record["source_props"]},
{
"relationship_name": record["relationship_name"],
**record["relationship_props"]
**record["relationship_props"],
},
{
"id": record["target_id"],
**record["target_props"]
}
{"id": record["target_id"], **record["target_props"]},
)
)
@@ -1078,10 +1068,7 @@ class NeptuneGraphDB(GraphDBInterface):
logger.error(f"Failed to get connections for node {node_id}: {error_msg}")
raise Exception(f"Failed to get connections: {error_msg}") from e
async def remove_connection_to_predecessors_of(
self, node_ids: list[str], edge_label: str
):
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str):
"""
Remove connections (edges) to all predecessors of specified nodes based on edge label.
@@ -1101,10 +1088,7 @@ class NeptuneGraphDB(GraphDBInterface):
params = {"node_ids": node_ids}
await self.query(query, params)
async def remove_connection_to_successors_of(
self, node_ids: list[str], edge_label: str
):
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str):
"""
Remove connections (edges) to all successors of specified nodes based on edge label.
@@ -1124,7 +1108,6 @@ class NeptuneGraphDB(GraphDBInterface):
params = {"node_ids": node_ids}
await self.query(query, params)
async def get_node_labels_string(self):
"""
Fetch all node labels from the database and return them as a formatted string.
@@ -1138,7 +1121,9 @@ class NeptuneGraphDB(GraphDBInterface):
-------
ValueError: If no node labels are found in the database.
"""
node_labels_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
node_labels_query = (
"CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
)
node_labels_result = await self.query(node_labels_query)
node_labels = node_labels_result[0]["labels"] if node_labels_result else []
@@ -1156,7 +1141,9 @@ class NeptuneGraphDB(GraphDBInterface):
A formatted string of relationship types.
"""
relationship_types_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
relationship_types_query = (
"CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
)
relationship_types_result = await self.query(relationship_types_query)
relationship_types = (
relationship_types_result[0]["relationships"] if relationship_types_result else []
@@ -1172,7 +1159,6 @@ class NeptuneGraphDB(GraphDBInterface):
)
return relationship_types_undirected_str
async def drop_graph(self, graph_name="myGraph"):
"""
Drop an existing graph from the database based on its name.
@@ -1208,7 +1194,6 @@ class NeptuneGraphDB(GraphDBInterface):
"""
pass
async def project_entire_graph(self, graph_name="myGraph"):
"""
Project all node labels and relationship types into an in-memory graph using GDS.
@@ -1280,7 +1265,6 @@ class NeptuneGraphDB(GraphDBInterface):
return (nodes, edges)
async def get_degree_one_nodes(self, node_type: str):
"""
Fetch nodes of a specified type that have exactly one connection.
@@ -1361,8 +1345,6 @@ class NeptuneGraphDB(GraphDBInterface):
result = await self.query(query, {"content_hash": content_hash})
return result[0] if result else None
async def _get_model_independent_graph_data(self):
"""
Retrieve the basic graph data without considering the model specifics, returning nodes
@@ -1380,8 +1362,8 @@ class NeptuneGraphDB(GraphDBInterface):
RETURN nodeCount AS numVertices, count(r) AS numEdges
"""
query_response = await self.query(query_string)
num_nodes = query_response[0].get('numVertices')
num_edges = query_response[0].get('numEdges')
num_nodes = query_response[0].get("numVertices")
num_edges = query_response[0].get("numEdges")
return (num_nodes, num_edges)
@@ -1437,4 +1419,9 @@ class NeptuneGraphDB(GraphDBInterface):
@staticmethod
def _convert_relationship_to_edge(relationship: dict) -> EdgeData:
return relationship["source_id"], relationship["target_id"], relationship["relationship_name"], relationship["properties"]
return (
relationship["source_id"],
relationship["target_id"],
relationship["relationship_name"],
relationship["properties"],
)

View file

@@ -2,106 +2,114 @@
This module defines custom exceptions for Neptune Analytics operations.
"""
from cognee.exceptions import CogneeApiError
from fastapi import status
class NeptuneAnalyticsError(CogneeApiError):
"""Base exception for Neptune Analytics operations."""
def __init__(
self,
message: str = "Neptune Analytics error.",
name: str = "NeptuneAnalyticsError",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError):
"""Exception raised when connection to Neptune Analytics fails."""
def __init__(
self,
message: str = "Unable to connect to Neptune Analytics. Please check the endpoint and network connectivity.",
name: str = "NeptuneAnalyticsConnectionError",
status_code=status.HTTP_404_NOT_FOUND
status_code=status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsQueryError(NeptuneAnalyticsError):
"""Exception raised when a query execution fails."""
def __init__(
self,
message: str = "The query execution failed due to invalid syntax or semantic issues.",
name: str = "NeptuneAnalyticsQueryError",
status_code=status.HTTP_400_BAD_REQUEST
status_code=status.HTTP_400_BAD_REQUEST,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError):
"""Exception raised when authentication with Neptune Analytics fails."""
def __init__(
self,
message: str = "Authentication with Neptune Analytics failed. Please verify your credentials.",
name: str = "NeptuneAnalyticsAuthenticationError",
status_code=status.HTTP_401_UNAUTHORIZED
status_code=status.HTTP_401_UNAUTHORIZED,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError):
"""Exception raised when Neptune Analytics configuration is invalid."""
def __init__(
self,
message: str = "Neptune Analytics configuration is invalid or incomplete. Please review your setup.",
name: str = "NeptuneAnalyticsConfigurationError",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError):
"""Exception raised when a Neptune Analytics operation times out."""
def __init__(
self,
message: str = "The operation timed out while communicating with Neptune Analytics.",
name: str = "NeptuneAnalyticsTimeoutError",
status_code=status.HTTP_504_GATEWAY_TIMEOUT
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError):
"""Exception raised when requests are throttled by Neptune Analytics."""
def __init__(
self,
message: str = "Request was throttled by Neptune Analytics due to exceeding rate limits.",
name: str = "NeptuneAnalyticsThrottlingError",
status_code=status.HTTP_429_TOO_MANY_REQUESTS
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError):
"""Exception raised when a Neptune Analytics resource is not found."""
def __init__(
self,
message: str = "The requested Neptune Analytics resource could not be found.",
name: str = "NeptuneAnalyticsResourceNotFoundError",
status_code=status.HTTP_404_NOT_FOUND
status_code=status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError):
"""Exception raised when invalid parameters are provided to Neptune Analytics."""
def __init__(
self,
message: str = "One or more parameters provided to Neptune Analytics are invalid or missing.",
name: str = "NeptuneAnalyticsInvalidParameterError",
status_code=status.HTTP_400_BAD_REQUEST
status_code=status.HTTP_400_BAD_REQUEST,
):
super().__init__(message, name, status_code)

View file

@@ -16,40 +16,42 @@ logger = get_logger("NeptuneUtils")
def parse_neptune_url(url: str) -> Tuple[str, str]:
"""
Parse a Neptune Analytics URL to extract graph ID and region.
Expected format: neptune-graph://<GRAPH_ID>?region=<REGION>
or neptune-graph://<GRAPH_ID> (defaults to us-east-1)
Parameters:
-----------
- url (str): The Neptune Analytics URL to parse
Returns:
--------
- Tuple[str, str]: A tuple containing (graph_id, region)
Raises:
-------
- ValueError: If the URL format is invalid
"""
try:
parsed = urlparse(url)
if parsed.scheme != "neptune-graph":
raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'")
graph_id = parsed.hostname or parsed.path.lstrip('/')
graph_id = parsed.hostname or parsed.path.lstrip("/")
if not graph_id:
raise ValueError("Graph ID not found in URL")
# Extract region from query parameters
region = "us-east-1" # default region
if parsed.query:
query_params = dict(param.split('=') for param in parsed.query.split('&') if '=' in param)
region = query_params.get('region', region)
query_params = dict(
param.split("=") for param in parsed.query.split("&") if "=" in param
)
region = query_params.get("region", region)
return graph_id, region
except Exception as e:
raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}")
@@ -57,43 +59,43 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
def validate_graph_id(graph_id: str) -> bool:
"""
Validate a Neptune Analytics graph ID format.
Graph IDs should follow AWS naming conventions.
Parameters:
-----------
- graph_id (str): The graph ID to validate
Returns:
--------
- bool: True if the graph ID is valid, False otherwise
"""
if not graph_id:
return False
# Neptune Analytics graph IDs should be alphanumeric with hyphens
# and between 1-63 characters
pattern = r'^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$'
pattern = r"^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$"
return bool(re.match(pattern, graph_id))
def validate_aws_region(region: str) -> bool:
"""
Validate an AWS region format.
Parameters:
-----------
- region (str): The AWS region to validate
Returns:
--------
- bool: True if the region format is valid, False otherwise
"""
if not region:
return False
# AWS regions follow the pattern: us-east-1, eu-west-1, etc.
pattern = r'^[a-z]{2,3}-[a-z]+-\d+$'
pattern = r"^[a-z]{2,3}-[a-z]+-\d+$"
return bool(re.match(pattern, region))
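
Both validators are anchored regex checks. A few examples of what they accept and reject (identifiers invented):

assert validate_graph_id("g-1234567890")  # alphanumeric plus hyphens, 1-63 chars
assert not validate_graph_id("-starts-with-hyphen")  # must start with an alphanumeric character
assert not validate_graph_id("")  # empty IDs are rejected up front
assert validate_aws_region("us-east-1")
assert not validate_aws_region("useast1")  # missing the <area>-<direction>-<number> shape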
@@ -103,11 +105,11 @@ def build_neptune_config(
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Build a configuration dictionary for Neptune Analytics connection.
Parameters:
-----------
- graph_id (str): The Neptune Analytics graph identifier
@@ -116,11 +118,11 @@ def build_neptune_config(
- aws_secret_access_key (Optional[str]): AWS secret access key
- aws_session_token (Optional[str]): AWS session token for temporary credentials
- **kwargs: Additional configuration parameters
Returns:
--------
- Dict[str, Any]: Configuration dictionary for Neptune Analytics
Raises:
-------
- ValueError: If required parameters are invalid
@@ -129,35 +131,35 @@ def build_neptune_config(
"graph_id": graph_id,
"service_name": "neptune-graph",
}
# Add AWS credentials if provided
if region:
config["region"] = region
if aws_access_key_id:
config["aws_access_key_id"] = aws_access_key_id
if aws_secret_access_key:
config["aws_secret_access_key"] = aws_secret_access_key
if aws_session_token:
config["aws_session_token"] = aws_session_token
# Add any additional configuration
config.update(kwargs)
return config
def get_neptune_endpoint_url(graph_id: str, region: str) -> str:
"""
Construct the Neptune Analytics endpoint URL for a given graph and region.
Parameters:
-----------
- graph_id (str): The Neptune Analytics graph identifier
- region (str): AWS region where the graph is located
Returns:
--------
- str: The Neptune Analytics endpoint URL
@@ -168,17 +170,17 @@ def get_neptune_endpoint_url(graph_id: str, region: str) -> str:
def format_neptune_error(error: Exception) -> str:
"""
Format Neptune Analytics specific errors for better readability.
Parameters:
-----------
- error (Exception): The exception to format
Returns:
--------
- str: Formatted error message
"""
error_msg = str(error)
# Common Neptune Analytics error patterns and their user-friendly messages
error_mappings = {
"AccessDenied": "Access denied. Please check your AWS credentials and permissions.",
@@ -187,17 +189,18 @@ def format_neptune_error(error: Exception) -> str:
"ThrottlingException": "Request was throttled. Please retry with exponential backoff.",
"InternalServerError": "Internal server error occurred. Please try again later.",
}
for error_type, friendly_msg in error_mappings.items():
if error_type in error_msg:
return f"{friendly_msg} Original error: {error_msg}"
return error_msg
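
Because the mapping is a plain substring scan over the exception text, any error containing a known AWS error code gets a friendlier prefix and everything else passes through unchanged, e.g. (error text invented):

assert format_neptune_error(
    Exception("ThrottlingException: rate exceeded")
).startswith("Request was throttled.")
assert format_neptune_error(Exception("some other failure")) == "some other failure"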
def get_default_query_timeout() -> int:
"""
Get the default query timeout for Neptune Analytics operations.
Returns:
--------
- int: Default timeout in seconds
@@ -208,7 +211,7 @@ def get_default_query_timeout() -> int:
def get_default_connection_config() -> Dict[str, Any]:
"""
Get default connection configuration for Neptune Analytics.
Returns:
--------
- Dict[str, Any]: Default connection configuration

View file

@@ -17,6 +17,7 @@ from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredRes
logger = get_logger("NeptuneAnalyticsAdapter")
class IndexSchema(DataPoint):
"""
Represents a schema for an index data point containing an ID and text.
@@ -27,16 +28,19 @@ class IndexSchema(DataPoint):
- metadata: A dictionary with default index fields for the schema, currently configured
to include 'text'.
"""
id: str
text: str
metadata: dict = {"index_fields": ["text"]}
NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://"
class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
"""
Hybrid adapter that combines Neptune Analytics Vector and Graph functionality.
This adapter extends NeptuneGraphDB and implements VectorDBInterface to provide
a unified interface for working with Neptune Analytics as both a vector store
and a graph database.
@@ -48,14 +52,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
_TOPK_UPPER_BOUND = 10
def __init__(
self,
graph_id: str,
embedding_engine: Optional[EmbeddingEngine] = None,
region: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
):
self,
graph_id: str,
embedding_engine: Optional[EmbeddingEngine] = None,
region: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
):
"""
Initialize the Neptune Analytics hybrid adapter.
@@ -74,12 +78,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
region=region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token
aws_session_token=aws_session_token,
)
# Add vector-specific attributes
self.embedding_engine = embedding_engine
logger.info(f"Initialized Neptune Analytics hybrid adapter for graph: \"{graph_id}\" in region: \"{self.region}\"")
logger.info(
f'Initialized Neptune Analytics hybrid adapter for graph: "{graph_id}" in region: "{self.region}"'
)
# VectorDBInterface methods implementation
@@ -161,7 +167,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
# Fetch embeddings
texts = [DataPoint.get_embeddable_data(t) for t in data_points]
data_vectors = (await self.embedding_engine.embed_text(texts))
data_vectors = await self.embedding_engine.embed_text(texts)
for index, data_point in enumerate(data_points):
node_id = data_point.id
@@ -172,23 +178,24 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
properties = self._serialize_properties(data_point.model_dump())
properties[self._COLLECTION_PREFIX] = collection_name
params = dict(
node_id = str(node_id),
properties = properties,
embedding = data_vector,
collection_name = collection_name
node_id=str(node_id),
properties=properties,
embedding=data_vector,
collection_name=collection_name,
)
# Compose the query and send
query_string = (
f"MERGE (n "
f":{self._VECTOR_NODE_LABEL} "
f" {{`~id`: $node_id}}) "
f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
f"WITH n, $embedding AS embedding "
f"CALL neptune.algo.vectors.upsert(n, embedding) "
f"YIELD success "
f"RETURN success ")
f"MERGE (n "
f":{self._VECTOR_NODE_LABEL} "
f" {{`~id`: $node_id}}) "
f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
f"WITH n, $embedding AS embedding "
f"CALL neptune.algo.vectors.upsert(n, embedding) "
f"YIELD success "
f"RETURN success "
)
try:
self._client.query(query_string, params)
@@ -208,10 +215,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
"""
# Do the fetch for each node
params = dict(node_ids=data_point_ids, collection_name=collection_name)
query_string = (f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
f"WHERE id(n) in $node_ids AND "
f"n.{self._COLLECTION_PREFIX} = $collection_name "
f"RETURN n as payload ")
query_string = (
f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
f"WHERE id(n) in $node_ids AND "
f"n.{self._COLLECTION_PREFIX} = $collection_name "
f"RETURN n as payload "
)
try:
result = self._client.query(query_string, params)
@@ -259,7 +268,8 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
logger.warning(
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
"Defaulting to limit=10.", limit
"Defaulting to limit=10.",
limit,
)
limit = self._TOPK_UPPER_BOUND
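
With _TOPK_UPPER_BOUND = 10 (and assuming _TOPK_LOWER_BOUND is 0, which this hunk does not show), the guard above behaves like this sketch (helper name hypothetical):

def clamp_limit(limit, lower=0, upper=10):
    # Falsy, non-positive, or out-of-range limits all fall back to the maximum.
    if not limit or limit <= lower or limit > upper:
        return upper
    return limit

assert clamp_limit(None) == 10
assert clamp_limit(25) == 10
assert clamp_limit(5) == 5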
@@ -272,7 +282,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
elif query_vector:
embedding = query_vector
else:
data_vectors = (await self.embedding_engine.embed_text([query_text]))
data_vectors = await self.embedding_engine.embed_text([query_text])
embedding = data_vectors[0]
# Compose the parameters map
@@ -305,9 +315,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
try:
query_response = self._client.query(query_string, params)
return [self._get_scored_result(
item = item, with_score = True
) for item in query_response]
return [self._get_scored_result(item=item, with_score=True) for item in query_response]
except Exception as e:
self._na_exception_handler(e, query_string)
@@ -332,11 +340,13 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
self._validate_embedding_engine()
# Convert text to embedding array in batch
data_vectors = (await self.embedding_engine.embed_text(query_texts))
return await asyncio.gather(*[
self.search(collection_name, None, vector, limit, with_vectors)
for vector in data_vectors
])
data_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(collection_name, None, vector, limit, with_vectors)
for vector in data_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
"""
@@ -350,10 +360,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
- data_point_ids (list[str]): A list of IDs of the data points to delete.
"""
params = dict(node_ids=data_point_ids, collection_name=collection_name)
query_string = (f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
f"WHERE id(n) IN $node_ids "
f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
f"DETACH DELETE n")
query_string = (
f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
f"WHERE id(n) IN $node_ids "
f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
f"DETACH DELETE n"
)
try:
self._client.query(query_string, params)
except Exception as e:
@@ -370,7 +382,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
"""
Indexes a list of data points into Neptune Analytics by creating them as nodes.
@@ -402,29 +414,28 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
Remove obsolete or unnecessary data from the database.
"""
# Run actual truncate
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
f"DETACH DELETE n")
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n")
pass
@staticmethod
def _get_scored_result(item: dict, with_vector: bool = False, with_score: bool = False) -> ScoredResult:
def _get_scored_result(
item: dict, with_vector: bool = False, with_score: bool = False
) -> ScoredResult:
"""
Utility method to simplify creating a ScoredResult object based on the incoming NA payload response.
"""
return ScoredResult(
id=item.get('payload').get('~id'),
payload=item.get('payload').get('~properties'),
score=item.get('score') if with_score else 0,
vector=item.get('embedding') if with_vector else None
id=item.get("payload").get("~id"),
payload=item.get("payload").get("~properties"),
score=item.get("score") if with_score else 0,
vector=item.get("embedding") if with_vector else None,
)
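
The item dict here follows the response shape the adapter reads elsewhere in this file: each row carries the matched node under payload with ~id and ~properties keys, plus optional score and embedding columns. A sketch of the expected input (values invented):

item = {
    "payload": {"~id": "node-1", "~properties": {"text": "hello"}},
    "score": 0.93,
}
result = NeptuneAnalyticsAdapter._get_scored_result(item, with_score=True)
# result.id == "node-1", result.payload == {"text": "hello"}, result.score == 0.93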
def _na_exception_handler(self, ex, query_string: str):
"""
Generic exception handler for NA langchain.
"""
logger.error(
"Neptune Analytics query failed: %s | Query: [%s]", ex, query_string
)
logger.error("Neptune Analytics query failed: %s | Query: [%s]", ex, query_string)
raise ex
def _validate_embedding_engine(self):
@@ -433,4 +444,6 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
:raises: ValueError if this object does not have a valid embedding_engine
"""
if self.embedding_engine is None:
raise ValueError("Neptune Analytics requires an embedder defined to make vector operations")
raise ValueError(
"Neptune Analytics requires an embedder defined to make vector operations"
)

View file

@@ -125,9 +125,15 @@ def create_vector_engine(
if not vector_db_url:
raise EnvironmentError("Missing Neptune endpoint.")
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
NEPTUNE_ANALYTICS_ENDPOINT_URL,
)
if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")
raise ValueError(
f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
)
graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

View file

@@ -8,7 +8,7 @@ from cognee.modules.data.processing.document_types import TextDocument
# Set up Amazon credentials in .env file and get the values from environment variables
load_dotenv()
graph_id = os.getenv('GRAPH_ID', "")
graph_id = os.getenv("GRAPH_ID", "")
na_adapter = NeptuneGraphDB(graph_id)
@@ -24,33 +24,33 @@ def setup():
# stored in Amazon S3.
document = TextDocument(
name='text_test.txt',
raw_data_location='git/cognee/examples/database_examples/data_storage/data/text_test.txt',
external_metadata='{}',
mime_type='text/plain'
name="text_test.txt",
raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt",
external_metadata="{}",
mime_type="text/plain",
)
document_chunk = DocumentChunk(
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
chunk_size=187,
chunk_index=0,
cut_type='paragraph_end',
cut_type="paragraph_end",
is_part_of=document,
)
graph_database = EntityType(name='graph database', description='graph database')
graph_database = EntityType(name="graph database", description="graph database")
neptune_analytics_entity = Entity(
name='neptune analytics',
description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
name="neptune analytics",
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
)
neptune_database_entity = Entity(
name='amazon neptune database',
description='A popular managed graph database that complements Neptune Analytics.',
name="amazon neptune database",
description="A popular managed graph database that complements Neptune Analytics.",
)
storage = EntityType(name='storage', description='storage')
storage = EntityType(name="storage", description="storage")
storage_entity = Entity(
name='amazon s3',
description='A storage service provided by Amazon Web Services that allows storing graph data.',
name="amazon s3",
description="A storage service provided by Amazon Web Services that allows storing graph data.",
)
nodes_data = [
@@ -67,37 +67,37 @@ def setup():
(
str(document_chunk.id),
str(storage_entity.id),
'contains',
"contains",
),
(
str(storage_entity.id),
str(storage.id),
'is_a',
"is_a",
),
(
str(document_chunk.id),
str(neptune_database_entity.id),
'contains',
"contains",
),
(
str(neptune_database_entity.id),
str(graph_database.id),
'is_a',
"is_a",
),
(
str(document_chunk.id),
str(document.id),
'is_part_of',
"is_part_of",
),
(
str(document_chunk.id),
str(neptune_analytics_entity.id),
'contains',
"contains",
),
(
str(neptune_analytics_entity.id),
str(graph_database.id),
'is_a',
"is_a",
),
]
@@ -155,42 +155,44 @@ async def pipeline_method():
print("------NEIGHBORING NODES-------")
center_node = nodes[2]
neighbors = await na_adapter.get_neighbors(str(center_node.id))
print(f"found {len(neighbors)} neighbors for node \"{center_node.name}\"")
print(f'found {len(neighbors)} neighbors for node "{center_node.name}"')
for neighbor in neighbors:
print(neighbor)
print("------NEIGHBORING EDGES-------")
center_node = nodes[2]
neighbouring_edges = await na_adapter.get_edges(str(center_node.id))
print(f"found {len(neighbouring_edges)} edges neighbouring node \"{center_node.name}\"")
print(f'found {len(neighbouring_edges)} edges neighbouring node "{center_node.name}"')
for edge in neighbouring_edges:
print(edge)
print("------GET CONNECTIONS (SOURCE NODE)-------")
document_chunk_node = nodes[0]
connections = await na_adapter.get_connections(str(document_chunk_node.id))
print(f"found {len(connections)} connections for node \"{document_chunk_node.type}\"")
print(f'found {len(connections)} connections for node "{document_chunk_node.type}"')
for connection in connections:
src, relationship, tgt = connection
src = src.get("name", src.get("type", "unknown"))
relationship = relationship["relationship_name"]
tgt = tgt.get("name", tgt.get("type", "unknown"))
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")
print(f'"{src}"-[{relationship}]->"{tgt}"')
print("------GET CONNECTIONS (TARGET NODE)-------")
connections = await na_adapter.get_connections(str(center_node.id))
print(f"found {len(connections)} connections for node \"{center_node.name}\"")
print(f'found {len(connections)} connections for node "{center_node.name}"')
for connection in connections:
src, relationship, tgt = connection
src = src.get("name", src.get("type", "unknown"))
relationship = relationship["relationship_name"]
tgt = tgt.get("name", tgt.get("type", "unknown"))
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")
print(f'"{src}"-[{relationship}]->"{tgt}"')
print("------SUBGRAPH-------")
node_names = ["neptune analytics", "amazon neptune database"]
subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names)
print(f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}")
print(
f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}"
)
for subgraph_node in subgraph_nodes:
print(subgraph_node)
for subgraph_edge in subgraph_edges:
@@ -199,17 +201,17 @@ async def pipeline_method():
print("------STAT-------")
stat = await na_adapter.get_graph_metrics(include_optional=True)
assert type(stat) is dict
assert stat['num_nodes'] == 7
assert stat['num_edges'] == 7
assert stat['mean_degree'] == 2.0
assert round(stat['edge_density'], 3) == 0.167
assert stat['num_connected_components'] == [7]
assert stat['sizes_of_connected_components'] == 1
assert stat['num_selfloops'] == 0
assert stat["num_nodes"] == 7
assert stat["num_edges"] == 7
assert stat["mean_degree"] == 2.0
assert round(stat["edge_density"], 3) == 0.167
assert stat["num_connected_components"] == [7]
assert stat["sizes_of_connected_components"] == 1
assert stat["num_selfloops"] == 0
# Unsupported optional metrics
assert stat['diameter'] == -1
assert stat['avg_shortest_path_length'] == -1
assert stat['avg_clustering'] == -1
assert stat["diameter"] == -1
assert stat["avg_shortest_path_length"] == -1
assert stat["avg_clustering"] == -1
print("------DELETE-------")
# delete all nodes and edges:
@@ -253,7 +255,9 @@ async def misc_methods():
print(edge_labels)
print("------Get Filtered Graph-------")
filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data([{'name': ['text_test.txt']}])
filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data(
[{"name": ["text_test.txt"]}]
)
print(filtered_nodes, filtered_edges)
print("------Get Degree one nodes-------")
@@ -261,15 +265,13 @@ async def misc_methods():
print(degree_one_nodes)
print("------Get Doc sub-graph-------")
doc_sub_graph = await na_adapter.get_document_subgraph('test.txt')
doc_sub_graph = await na_adapter.get_document_subgraph("test.txt")
print(doc_sub_graph)
print("------Fetch and Remove connections (Predecessors)-------")
# Fetch test edge
(src_id, dest_id, relationship) = edges[0]
nodes_predecessors = await na_adapter.get_predecessors(
node_id=dest_id, edge_label=relationship
)
nodes_predecessors = await na_adapter.get_predecessors(node_id=dest_id, edge_label=relationship)
assert len(nodes_predecessors) > 0
await na_adapter.remove_connection_to_predecessors_of(
@@ -281,25 +283,19 @@ async def misc_methods():
# Return empty after relationship being deleted.
assert len(nodes_predecessors_after) == 0
print("------Fetch and Remove connections (Successors)-------")
_, edges_suc = await na_adapter.get_graph_data()
(src_id, dest_id, relationship, _) = edges_suc[0]
nodes_successors = await na_adapter.get_successors(
node_id=src_id, edge_label=relationship
)
nodes_successors = await na_adapter.get_successors(node_id=src_id, edge_label=relationship)
assert len(nodes_successors) > 0
await na_adapter.remove_connection_to_successors_of(
node_ids=[dest_id], edge_label=relationship
)
await na_adapter.remove_connection_to_successors_of(node_ids=[dest_id], edge_label=relationship)
nodes_successors_after = await na_adapter.get_successors(
node_id=src_id, edge_label=relationship
)
assert len(nodes_successors_after) == 0
# no-op
await na_adapter.project_entire_graph()
await na_adapter.drop_graph()

View file

@@ -8,11 +8,13 @@ from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
)
# Set up Amazon credentials in .env file and get the values from environment variables
load_dotenv()
graph_id = os.getenv('GRAPH_ID', "")
graph_id = os.getenv("GRAPH_ID", "")
# get the default embedder
embedding_engine = get_embedding_engine()
@@ -23,6 +25,7 @@ collection = "test_collection"
logger = get_logger("test_neptune_analytics_hybrid")
def setup_data():
# Define nodes data before the main function
# These nodes were defined using openAI from the following prompt:
@@ -34,35 +37,35 @@ def setup_data():
# stored in Amazon S3.
document = TextDocument(
name='text.txt',
raw_data_location='git/cognee/examples/database_examples/data_storage/data/text.txt',
external_metadata='{}',
mime_type='text/plain'
name="text.txt",
raw_data_location="git/cognee/examples/database_examples/data_storage/data/text.txt",
external_metadata="{}",
mime_type="text/plain",
)
document_chunk = DocumentChunk(
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
chunk_size=187,
chunk_index=0,
cut_type='paragraph_end',
cut_type="paragraph_end",
is_part_of=document,
)
graph_database = EntityType(name='graph database', description='graph database')
graph_database = EntityType(name="graph database", description="graph database")
neptune_analytics_entity = Entity(
name='neptune analytics',
description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
name="neptune analytics",
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
)
neptune_database_entity = Entity(
name='amazon neptune database',
description='A popular managed graph database that complements Neptune Analytics.',
name="amazon neptune database",
description="A popular managed graph database that complements Neptune Analytics.",
)
storage = EntityType(name='storage', description='storage')
storage = EntityType(name="storage", description="storage")
storage_entity = Entity(
name='amazon s3',
description='A storage service provided by Amazon Web Services that allows storing graph data.',
name="amazon s3",
description="A storage service provided by Amazon Web Services that allows storing graph data.",
)
nodes_data = [
document,
document_chunk,
@@ -77,41 +80,42 @@ def setup_data():
(
str(document_chunk.id),
str(storage_entity.id),
'contains',
"contains",
),
(
str(storage_entity.id),
str(storage.id),
'is_a',
"is_a",
),
(
str(document_chunk.id),
str(neptune_database_entity.id),
'contains',
"contains",
),
(
str(neptune_database_entity.id),
str(graph_database.id),
'is_a',
"is_a",
),
(
str(document_chunk.id),
str(document.id),
'is_part_of',
"is_part_of",
),
(
str(document_chunk.id),
str(neptune_analytics_entity.id),
'contains',
"contains",
),
(
str(neptune_analytics_entity.id),
str(graph_database.id),
'is_a',
"is_a",
),
]
return nodes_data, edges_data
async def test_add_graph_then_vector_data():
logger.info("------test_add_graph_then_vector_data-------")
(nodes, edges) = setup_data()
@@ -134,6 +138,7 @@ async def test_add_graph_then_vector_data():
assert len(edges) == 0
logger.info("------PASSED-------")
async def test_add_vector_then_node_data():
logger.info("------test_add_vector_then_node_data-------")
(nodes, edges) = setup_data()
@@ -156,6 +161,7 @@ async def test_add_vector_then_node_data():
assert len(edges) == 0
logger.info("------PASSED-------")
def main():
"""
Example script that uses Neptune Analytics as the graph and vector (hybrid) store with small sample data.
@@ -165,5 +171,6 @@ def main():
asyncio.run(test_add_graph_then_vector_data())
asyncio.run(test_add_vector_then_node_data())
if __name__ == "__main__":
main()

View file

@@ -8,13 +8,16 @@ from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, IndexSchema
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
IndexSchema,
)
logger = get_logger()
async def main():
graph_id = os.getenv('GRAPH_ID', "")
graph_id = os.getenv("GRAPH_ID", "")
cognee.config.set_vector_db_provider("neptune_analytics")
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
data_directory_path = str(
@@ -87,6 +90,7 @@ async def main():
await cognee.prune.prune_system(metadata=True)
async def vector_backend_api_test():
cognee.config.set_vector_db_provider("neptune_analytics")
@@ -101,7 +105,7 @@ async def vector_backend_api_test():
get_vector_engine()
# Return a valid engine object with valid URL.
graph_id = os.getenv('GRAPH_ID', "")
graph_id = os.getenv("GRAPH_ID", "")
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
engine = get_vector_engine()
assert isinstance(engine, NeptuneAnalyticsAdapter)
@@ -133,28 +137,22 @@ async def vector_backend_api_test():
query_text=TEST_TEXT,
query_vector=None,
limit=10,
with_vector=True)
assert (len(result_search) == 2)
with_vector=True,
)
assert len(result_search) == 2
# # Retrieve data-points
result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
assert any(
str(r.id) == TEST_UUID and r.payload['text'] == TEST_TEXT
for r in result
)
assert any(
str(r.id) == TEST_UUID_2 and r.payload['text'] == TEST_TEXT_2
for r in result
)
assert any(str(r.id) == TEST_UUID and r.payload["text"] == TEST_TEXT for r in result)
assert any(str(r.id) == TEST_UUID_2 and r.payload["text"] == TEST_TEXT_2 for r in result)
# Search multiple
result_search_batch = await engine.batch_search(
collection_name=TEST_COLLECTION_NAME,
query_texts=[TEST_TEXT, TEST_TEXT_2],
limit=10,
with_vectors=False
with_vectors=False,
)
assert (len(result_search_batch) == 2 and
all(len(batch) == 2 for batch in result_search_batch))
assert len(result_search_batch) == 2 and all(len(batch) == 2 for batch in result_search_batch)
# Delete datapoint from vector store
await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
@@ -163,6 +161,7 @@ async def vector_backend_api_test():
result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID])
assert result_deleted == []
if __name__ == "__main__":
import asyncio

View file

@@ -9,6 +9,7 @@ from dotenv import load_dotenv
load_dotenv()
async def main():
"""
Example script demonstrating how to use Cognee with Amazon Neptune Analytics
@@ -22,7 +23,7 @@ async def main():
"""
# Set up Amazon credentials in .env file and get the values from environment variables
graph_endpoint_url = "neptune-graph://" + os.getenv('GRAPH_ID', "")
graph_endpoint_url = "neptune-graph://" + os.getenv("GRAPH_ID", "")
# Configure Neptune Analytics as the graph & vector database provider
cognee.config.set_graph_db_config(
@@ -77,7 +78,9 @@ async def main():
# Now let's perform some searches
# 1. Search for insights related to "Neptune Analytics"
insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neptune Analytics")
insights_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text="Neptune Analytics"
)
print("\n========Insights about Neptune Analytics========:")
for result in insights_results:
print(f"- {result}")

File diff suppressed because one or more lines are too long