added linting and formatting

This commit is contained in:
vasilije 2025-08-02 16:58:18 +02:00
parent ae8fc7f0c9
commit c0d60eef25
11 changed files with 355 additions and 326 deletions

View file

@ -149,7 +149,9 @@ def create_graph_engine(
from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL
if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL): if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
raise ValueError(f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>") raise ValueError(
f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
)
graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "") graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")
@ -174,10 +176,15 @@ def create_graph_engine(
if not graph_database_url: if not graph_database_url:
raise EnvironmentError("Missing Neptune endpoint.") raise EnvironmentError("Missing Neptune endpoint.")
from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
NEPTUNE_ANALYTICS_ENDPOINT_URL,
)
if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL): if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'") raise ValueError(
f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
)
graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "") graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

View file

@ -29,6 +29,7 @@ logger = get_logger("NeptuneGraphDB")
try: try:
from langchain_aws import NeptuneAnalyticsGraph from langchain_aws import NeptuneAnalyticsGraph
LANGCHAIN_AWS_AVAILABLE = True LANGCHAIN_AWS_AVAILABLE = True
except ImportError: except ImportError:
logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.") logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.")
@ -36,11 +37,13 @@ except ImportError:
NEPTUNE_ENDPOINT_URL = "neptune-graph://" NEPTUNE_ENDPOINT_URL = "neptune-graph://"
class NeptuneGraphDB(GraphDBInterface): class NeptuneGraphDB(GraphDBInterface):
""" """
Adapter for interacting with Amazon Neptune Analytics graph store. Adapter for interacting with Amazon Neptune Analytics graph store.
This class provides methods for querying, adding, deleting nodes and edges using the aws_langchain library. This class provides methods for querying, adding, deleting nodes and edges using the aws_langchain library.
""" """
_GRAPH_NODE_LABEL = "COGNEE_NODE" _GRAPH_NODE_LABEL = "COGNEE_NODE"
def __init__( def __init__(
@ -68,14 +71,16 @@ class NeptuneGraphDB(GraphDBInterface):
""" """
# validate import # validate import
if not LANGCHAIN_AWS_AVAILABLE: if not LANGCHAIN_AWS_AVAILABLE:
raise ImportError("langchain_aws is not available. Please install it to use Neptune Analytics.") raise ImportError(
"langchain_aws is not available. Please install it to use Neptune Analytics."
)
# Validate configuration # Validate configuration
if not validate_graph_id(graph_id): if not validate_graph_id(graph_id):
raise NeptuneAnalyticsConfigurationError(message=f"Invalid graph ID: \"{graph_id}\"") raise NeptuneAnalyticsConfigurationError(message=f'Invalid graph ID: "{graph_id}"')
if region and not validate_aws_region(region): if region and not validate_aws_region(region):
raise NeptuneAnalyticsConfigurationError(message=f"Invalid AWS region: \"{region}\"") raise NeptuneAnalyticsConfigurationError(message=f'Invalid AWS region: "{region}"')
self.graph_id = graph_id self.graph_id = graph_id
self.region = region self.region = region
@ -94,7 +99,9 @@ class NeptuneGraphDB(GraphDBInterface):
# Initialize Neptune Analytics client using langchain_aws # Initialize Neptune Analytics client using langchain_aws
self._client: NeptuneAnalyticsGraph = self._initialize_client() self._client: NeptuneAnalyticsGraph = self._initialize_client()
logger.info(f"Initialized Neptune Analytics adapter for graph: \"{graph_id}\" in region: \"{self.region}\"") logger.info(
f'Initialized Neptune Analytics adapter for graph: "{graph_id}" in region: "{self.region}"'
)
def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]: def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]:
""" """
@ -108,9 +115,7 @@ class NeptuneGraphDB(GraphDBInterface):
# Initialize the Neptune Analytics Graph client # Initialize the Neptune Analytics Graph client
client_config = { client_config = {
"graph_identifier": self.graph_id, "graph_identifier": self.graph_id,
"config": Config( "config": Config(user_agent_appid="Cognee"),
user_agent_appid='Cognee'
)
} }
# Add AWS credentials if provided # Add AWS credentials if provided
if self.region: if self.region:
@ -127,7 +132,9 @@ class NeptuneGraphDB(GraphDBInterface):
return client return client
except Exception as e: except Exception as e:
raise NeptuneAnalyticsConfigurationError(message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}") from e raise NeptuneAnalyticsConfigurationError(
message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}"
) from e
@staticmethod @staticmethod
def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]: def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]:
@ -215,6 +222,7 @@ class NeptuneGraphDB(GraphDBInterface):
result = await self.query(query, params) result = await self.query(query, params)
logger.debug(f"Successfully added/updated node: {node.id}") logger.debug(f"Successfully added/updated node: {node.id}")
logger.debug(f"Successfully gotten: {str(result)}")
except Exception as e: except Exception as e:
error_msg = format_neptune_error(e) error_msg = format_neptune_error(e)
@ -256,7 +264,7 @@ class NeptuneGraphDB(GraphDBInterface):
} }
result = await self.query(query, params) result = await self.query(query, params)
processed_count = result[0].get('nodes_processed', 0) if result else 0 processed_count = result[0].get("nodes_processed", 0) if result else 0
logger.debug(f"Successfully processed {processed_count} nodes in bulk operation") logger.debug(f"Successfully processed {processed_count} nodes in bulk operation")
except Exception as e: except Exception as e:
@ -268,7 +276,9 @@ class NeptuneGraphDB(GraphDBInterface):
try: try:
await self.add_node(node) await self.add_node(node)
except Exception as node_error: except Exception as node_error:
logger.error(f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}") logger.error(
f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}"
)
continue continue
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
@ -287,9 +297,7 @@ class NeptuneGraphDB(GraphDBInterface):
DETACH DELETE n DETACH DELETE n
""" """
params = { params = {"node_id": node_id}
"node_id": node_id
}
await self.query(query, params) await self.query(query, params)
logger.debug(f"Successfully deleted node: {node_id}") logger.debug(f"Successfully deleted node: {node_id}")
@ -333,7 +341,9 @@ class NeptuneGraphDB(GraphDBInterface):
try: try:
await self.delete_node(node_id) await self.delete_node(node_id)
except Exception as node_error: except Exception as node_error:
logger.error(f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}") logger.error(
f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}"
)
continue continue
async def get_node(self, node_id: str) -> Optional[NodeData]: async def get_node(self, node_id: str) -> Optional[NodeData]:
@ -355,7 +365,7 @@ class NeptuneGraphDB(GraphDBInterface):
WHERE id(n) = $node_id WHERE id(n) = $node_id
RETURN n RETURN n
""" """
params = {'node_id': node_id} params = {"node_id": node_id}
result = await self.query(query, params) result = await self.query(query, params)
@ -406,7 +416,9 @@ class NeptuneGraphDB(GraphDBInterface):
# Extract node data from results # Extract node data from results
nodes = [record["n"] for record in result] nodes = [record["n"] for record in result]
logger.debug(f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested") logger.debug(
f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested"
)
return nodes return nodes
except Exception as e: except Exception as e:
@ -421,11 +433,12 @@ class NeptuneGraphDB(GraphDBInterface):
if node_data: if node_data:
nodes.append(node_data) nodes.append(node_data)
except Exception as node_error: except Exception as node_error:
logger.error(f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}") logger.error(
f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}"
)
continue continue
return nodes return nodes
async def extract_node(self, node_id: str): async def extract_node(self, node_id: str):
""" """
Retrieve a single node based on its ID. Retrieve a single node based on its ID.
@ -512,7 +525,9 @@ class NeptuneGraphDB(GraphDBInterface):
"properties": serialized_properties, "properties": serialized_properties,
} }
await self.query(query, params) await self.query(query, params)
logger.debug(f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}") logger.debug(
f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}"
)
except Exception as e: except Exception as e:
error_msg = format_neptune_error(e) error_msg = format_neptune_error(e)
@ -557,20 +572,24 @@ class NeptuneGraphDB(GraphDBInterface):
""" """
# Prepare edges data for bulk operation # Prepare edges data for bulk operation
params = {"edges": params = {
[ "edges": [
{ {
"from_node": str(edge[0]), "from_node": str(edge[0]),
"to_node": str(edge[1]), "to_node": str(edge[1]),
"relationship_name": relationship_name, "relationship_name": relationship_name,
"properties": self._serialize_properties(edge[3] if len(edge) > 3 and edge[3] else {}), "properties": self._serialize_properties(
edge[3] if len(edge) > 3 and edge[3] else {}
),
} }
for edge in edges_for_relationship for edge in edges_for_relationship
] ]
} }
results[relationship_name] = await self.query(query, params) results[relationship_name] = await self.query(query, params)
except Exception as e: except Exception as e:
logger.error(f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}") logger.error(
f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}"
)
logger.info("Falling back to individual edge creation") logger.info("Falling back to individual edge creation")
for edge in edges_by_relationship: for edge in edges_by_relationship:
try: try:
@ -578,15 +597,16 @@ class NeptuneGraphDB(GraphDBInterface):
properties = edge[3] if len(edge) > 3 else {} properties = edge[3] if len(edge) > 3 else {}
await self.add_edge(source_id, target_id, relationship_name, properties) await self.add_edge(source_id, target_id, relationship_name, properties)
except Exception as edge_error: except Exception as edge_error:
logger.error(f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}") logger.error(
f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}"
)
continue continue
processed_count = 0 processed_count = 0
for result in results.values(): for result in results.values():
processed_count += result[0].get('edges_processed', 0) if result else 0 processed_count += result[0].get("edges_processed", 0) if result else 0
logger.debug(f"Successfully processed {processed_count} edges in bulk operation") logger.debug(f"Successfully processed {processed_count} edges in bulk operation")
async def delete_graph(self) -> None: async def delete_graph(self) -> None:
""" """
Delete all nodes and edges from the graph database. Delete all nodes and edges from the graph database.
@ -599,14 +619,12 @@ class NeptuneGraphDB(GraphDBInterface):
# Build openCypher query to delete the graph # Build openCypher query to delete the graph
query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n" query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n"
await self.query(query) await self.query(query)
logger.debug(f"Successfully deleted all edges and nodes from the graph")
except Exception as e: except Exception as e:
error_msg = format_neptune_error(e) error_msg = format_neptune_error(e)
logger.error(f"Failed to delete graph: {error_msg}") logger.error(f"Failed to delete graph: {error_msg}")
raise Exception(f"Failed to delete graph: {error_msg}") from e raise Exception(f"Failed to delete graph: {error_msg}") from e
async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]: async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
""" """
Retrieve all nodes and edges within the graph. Retrieve all nodes and edges within the graph.
@ -633,13 +651,7 @@ class NeptuneGraphDB(GraphDBInterface):
edges_result = await self.query(edges_query) edges_result = await self.query(edges_query)
# Format nodes as (node_id, properties) tuples # Format nodes as (node_id, properties) tuples
nodes = [ nodes = [(result["node_id"], result["properties"]) for result in nodes_result]
(
result["node_id"],
result["properties"]
)
for result in nodes_result
]
# Format edges as (source_id, target_id, relationship_name, properties) tuples # Format edges as (source_id, target_id, relationship_name, properties) tuples
edges = [ edges = [
@ -647,7 +659,7 @@ class NeptuneGraphDB(GraphDBInterface):
result["source_id"], result["source_id"],
result["target_id"], result["target_id"],
result["relationship_name"], result["relationship_name"],
result["properties"] result["properties"],
) )
for result in edges_result for result in edges_result
] ]
@ -679,9 +691,11 @@ class NeptuneGraphDB(GraphDBInterface):
"num_nodes": num_nodes, "num_nodes": num_nodes,
"num_edges": num_edges, "num_edges": num_edges,
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None, "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
"edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1)) if num_nodes != 0 else None, "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1))
if num_nodes != 0
else None,
"num_connected_components": num_cluster, "num_connected_components": num_cluster,
"sizes_of_connected_components": list_clsuter_size "sizes_of_connected_components": list_clsuter_size,
} }
optional_metrics = { optional_metrics = {
@ -692,7 +706,7 @@ class NeptuneGraphDB(GraphDBInterface):
} }
if include_optional: if include_optional:
optional_metrics['num_selfloops'] = await self._count_self_loops() optional_metrics["num_selfloops"] = await self._count_self_loops()
# Unsupported due to long-running queries when computing the shortest path for each node in the graph: # Unsupported due to long-running queries when computing the shortest path for each node in the graph:
# optional_metrics['diameter'] # optional_metrics['diameter']
# optional_metrics['avg_shortest_path_length'] # optional_metrics['avg_shortest_path_length']
@ -732,9 +746,11 @@ class NeptuneGraphDB(GraphDBInterface):
result = await self.query(query, params) result = await self.query(query, params)
if result and len(result) > 0: if result and len(result) > 0:
edge_exists = result.pop().get('edge_exists', False) edge_exists = result.pop().get("edge_exists", False)
logger.debug(f"Edge existence check for " logger.debug(
f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}") f"Edge existence check for "
f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}"
)
return edge_exists return edge_exists
else: else:
return False return False
@ -863,10 +879,7 @@ class NeptuneGraphDB(GraphDBInterface):
RETURN predecessor RETURN predecessor
""" """
results = await self.query( results = await self.query(query, {"node_id": node_id})
query,
{"node_id": node_id}
)
return [result["predecessor"] for result in results] return [result["predecessor"] for result in results]
@ -893,14 +906,10 @@ class NeptuneGraphDB(GraphDBInterface):
RETURN successor RETURN successor
""" """
results = await self.query( results = await self.query(query, {"node_id": node_id})
query,
{"node_id": node_id}
)
return [result["successor"] for result in results] return [result["successor"] for result in results]
async def get_neighbors(self, node_id: str) -> List[NodeData]: async def get_neighbors(self, node_id: str) -> List[NodeData]:
""" """
Get all neighboring nodes connected to the specified node. Get all neighboring nodes connected to the specified node.
@ -926,11 +935,7 @@ class NeptuneGraphDB(GraphDBInterface):
# Format neighbors as NodeData objects # Format neighbors as NodeData objects
neighbors = [ neighbors = [
{ {"id": neighbor["neighbor_id"], **neighbor["properties"]} for neighbor in result
"id": neighbor["neighbor_id"],
**neighbor["properties"]
}
for neighbor in result
] ]
logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}") logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}")
@ -942,9 +947,7 @@ class NeptuneGraphDB(GraphDBInterface):
raise Exception(f"Failed to get neighbors: {error_msg}") from e raise Exception(f"Failed to get neighbors: {error_msg}") from e
async def get_nodeset_subgraph( async def get_nodeset_subgraph(
self, self, node_type: Type[Any], node_name: List[str]
node_type: Type[Any],
node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
""" """
Fetch a subgraph consisting of a specific set of nodes and their relationships. Fetch a subgraph consisting of a specific set of nodes and their relationships.
@ -987,10 +990,7 @@ class NeptuneGraphDB(GraphDBInterface):
}}] AS rawRels }}] AS rawRels
""" """
params = { params = {"names": node_name, "type": node_type.__name__}
"names": node_name,
"type": node_type.__name__
}
result = await self.query(query, params) result = await self.query(query, params)
@ -1002,18 +1002,14 @@ class NeptuneGraphDB(GraphDBInterface):
raw_rels = result[0]["rawRels"] raw_rels = result[0]["rawRels"]
# Format nodes as (node_id, properties) tuples # Format nodes as (node_id, properties) tuples
nodes = [ nodes = [(n["id"], n["properties"]) for n in raw_nodes]
(n["id"], n["properties"])
for n in raw_nodes
]
# Format edges as (source_id, target_id, relationship_name, properties) tuples # Format edges as (source_id, target_id, relationship_name, properties) tuples
edges = [ edges = [(r["source_id"], r["target_id"], r["type"], r["properties"]) for r in raw_rels]
(r["source_id"], r["target_id"], r["type"], r["properties"])
for r in raw_rels
]
logger.debug(f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}") logger.debug(
f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}"
)
return (nodes, edges) return (nodes, edges)
except Exception as e: except Exception as e:
@ -1055,18 +1051,12 @@ class NeptuneGraphDB(GraphDBInterface):
# Return as (source_node, relationship, target_node) # Return as (source_node, relationship, target_node)
connections.append( connections.append(
( (
{ {"id": record["source_id"], **record["source_props"]},
"id": record["source_id"],
**record["source_props"]
},
{ {
"relationship_name": record["relationship_name"], "relationship_name": record["relationship_name"],
**record["relationship_props"] **record["relationship_props"],
}, },
{ {"id": record["target_id"], **record["target_props"]},
"id": record["target_id"],
**record["target_props"]
}
) )
) )
@ -1078,10 +1068,7 @@ class NeptuneGraphDB(GraphDBInterface):
logger.error(f"Failed to get connections for node {node_id}: {error_msg}") logger.error(f"Failed to get connections for node {node_id}: {error_msg}")
raise Exception(f"Failed to get connections: {error_msg}") from e raise Exception(f"Failed to get connections: {error_msg}") from e
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str):
async def remove_connection_to_predecessors_of(
self, node_ids: list[str], edge_label: str
):
""" """
Remove connections (edges) to all predecessors of specified nodes based on edge label. Remove connections (edges) to all predecessors of specified nodes based on edge label.
@ -1101,10 +1088,7 @@ class NeptuneGraphDB(GraphDBInterface):
params = {"node_ids": node_ids} params = {"node_ids": node_ids}
await self.query(query, params) await self.query(query, params)
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str):
async def remove_connection_to_successors_of(
self, node_ids: list[str], edge_label: str
):
""" """
Remove connections (edges) to all successors of specified nodes based on edge label. Remove connections (edges) to all successors of specified nodes based on edge label.
@ -1124,7 +1108,6 @@ class NeptuneGraphDB(GraphDBInterface):
params = {"node_ids": node_ids} params = {"node_ids": node_ids}
await self.query(query, params) await self.query(query, params)
async def get_node_labels_string(self): async def get_node_labels_string(self):
""" """
Fetch all node labels from the database and return them as a formatted string. Fetch all node labels from the database and return them as a formatted string.
@ -1138,7 +1121,9 @@ class NeptuneGraphDB(GraphDBInterface):
------- -------
ValueError: If no node labels are found in the database. ValueError: If no node labels are found in the database.
""" """
node_labels_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels " node_labels_query = (
"CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
)
node_labels_result = await self.query(node_labels_query) node_labels_result = await self.query(node_labels_query)
node_labels = node_labels_result[0]["labels"] if node_labels_result else [] node_labels = node_labels_result[0]["labels"] if node_labels_result else []
@ -1156,7 +1141,9 @@ class NeptuneGraphDB(GraphDBInterface):
A formatted string of relationship types. A formatted string of relationship types.
""" """
relationship_types_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships " relationship_types_query = (
"CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
)
relationship_types_result = await self.query(relationship_types_query) relationship_types_result = await self.query(relationship_types_query)
relationship_types = ( relationship_types = (
relationship_types_result[0]["relationships"] if relationship_types_result else [] relationship_types_result[0]["relationships"] if relationship_types_result else []
@ -1172,7 +1159,6 @@ class NeptuneGraphDB(GraphDBInterface):
) )
return relationship_types_undirected_str return relationship_types_undirected_str
async def drop_graph(self, graph_name="myGraph"): async def drop_graph(self, graph_name="myGraph"):
""" """
Drop an existing graph from the database based on its name. Drop an existing graph from the database based on its name.
@ -1208,7 +1194,6 @@ class NeptuneGraphDB(GraphDBInterface):
""" """
pass pass
async def project_entire_graph(self, graph_name="myGraph"): async def project_entire_graph(self, graph_name="myGraph"):
""" """
Project all node labels and relationship types into an in-memory graph using GDS. Project all node labels and relationship types into an in-memory graph using GDS.
@ -1280,7 +1265,6 @@ class NeptuneGraphDB(GraphDBInterface):
return (nodes, edges) return (nodes, edges)
async def get_degree_one_nodes(self, node_type: str): async def get_degree_one_nodes(self, node_type: str):
""" """
Fetch nodes of a specified type that have exactly one connection. Fetch nodes of a specified type that have exactly one connection.
@ -1361,8 +1345,6 @@ class NeptuneGraphDB(GraphDBInterface):
result = await self.query(query, {"content_hash": content_hash}) result = await self.query(query, {"content_hash": content_hash})
return result[0] if result else None return result[0] if result else None
async def _get_model_independent_graph_data(self): async def _get_model_independent_graph_data(self):
""" """
Retrieve the basic graph data without considering the model specifics, returning nodes Retrieve the basic graph data without considering the model specifics, returning nodes
@ -1380,8 +1362,8 @@ class NeptuneGraphDB(GraphDBInterface):
RETURN nodeCount AS numVertices, count(r) AS numEdges RETURN nodeCount AS numVertices, count(r) AS numEdges
""" """
query_response = await self.query(query_string) query_response = await self.query(query_string)
num_nodes = query_response[0].get('numVertices') num_nodes = query_response[0].get("numVertices")
num_edges = query_response[0].get('numEdges') num_edges = query_response[0].get("numEdges")
return (num_nodes, num_edges) return (num_nodes, num_edges)
@ -1437,4 +1419,9 @@ class NeptuneGraphDB(GraphDBInterface):
@staticmethod @staticmethod
def _convert_relationship_to_edge(relationship: dict) -> EdgeData: def _convert_relationship_to_edge(relationship: dict) -> EdgeData:
return relationship["source_id"], relationship["target_id"], relationship["relationship_name"], relationship["properties"] return (
relationship["source_id"],
relationship["target_id"],
relationship["relationship_name"],
relationship["properties"],
)

View file

@ -2,106 +2,114 @@
This module defines custom exceptions for Neptune Analytics operations. This module defines custom exceptions for Neptune Analytics operations.
""" """
from cognee.exceptions import CogneeApiError from cognee.exceptions import CogneeApiError
from fastapi import status from fastapi import status
class NeptuneAnalyticsError(CogneeApiError): class NeptuneAnalyticsError(CogneeApiError):
"""Base exception for Neptune Analytics operations.""" """Base exception for Neptune Analytics operations."""
def __init__( def __init__(
self, self,
message: str = "Neptune Analytics error.", message: str = "Neptune Analytics error.",
name: str = "NeptuneAnalyticsError", name: str = "NeptuneAnalyticsError",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError): class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError):
"""Exception raised when connection to Neptune Analytics fails.""" """Exception raised when connection to Neptune Analytics fails."""
def __init__( def __init__(
self, self,
message: str = "Unable to connect to Neptune Analytics. Please check the endpoint and network connectivity.", message: str = "Unable to connect to Neptune Analytics. Please check the endpoint and network connectivity.",
name: str = "NeptuneAnalyticsConnectionError", name: str = "NeptuneAnalyticsConnectionError",
status_code=status.HTTP_404_NOT_FOUND status_code=status.HTTP_404_NOT_FOUND,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsQueryError(NeptuneAnalyticsError): class NeptuneAnalyticsQueryError(NeptuneAnalyticsError):
"""Exception raised when a query execution fails.""" """Exception raised when a query execution fails."""
def __init__( def __init__(
self, self,
message: str = "The query execution failed due to invalid syntax or semantic issues.", message: str = "The query execution failed due to invalid syntax or semantic issues.",
name: str = "NeptuneAnalyticsQueryError", name: str = "NeptuneAnalyticsQueryError",
status_code=status.HTTP_400_BAD_REQUEST status_code=status.HTTP_400_BAD_REQUEST,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError): class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError):
"""Exception raised when authentication with Neptune Analytics fails.""" """Exception raised when authentication with Neptune Analytics fails."""
def __init__( def __init__(
self, self,
message: str = "Authentication with Neptune Analytics failed. Please verify your credentials.", message: str = "Authentication with Neptune Analytics failed. Please verify your credentials.",
name: str = "NeptuneAnalyticsAuthenticationError", name: str = "NeptuneAnalyticsAuthenticationError",
status_code=status.HTTP_401_UNAUTHORIZED status_code=status.HTTP_401_UNAUTHORIZED,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError): class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError):
"""Exception raised when Neptune Analytics configuration is invalid.""" """Exception raised when Neptune Analytics configuration is invalid."""
def __init__( def __init__(
self, self,
message: str = "Neptune Analytics configuration is invalid or incomplete. Please review your setup.", message: str = "Neptune Analytics configuration is invalid or incomplete. Please review your setup.",
name: str = "NeptuneAnalyticsConfigurationError", name: str = "NeptuneAnalyticsConfigurationError",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError): class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError):
"""Exception raised when a Neptune Analytics operation times out.""" """Exception raised when a Neptune Analytics operation times out."""
def __init__( def __init__(
self, self,
message: str = "The operation timed out while communicating with Neptune Analytics.", message: str = "The operation timed out while communicating with Neptune Analytics.",
name: str = "NeptuneAnalyticsTimeoutError", name: str = "NeptuneAnalyticsTimeoutError",
status_code=status.HTTP_504_GATEWAY_TIMEOUT status_code=status.HTTP_504_GATEWAY_TIMEOUT,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError): class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError):
"""Exception raised when requests are throttled by Neptune Analytics.""" """Exception raised when requests are throttled by Neptune Analytics."""
def __init__( def __init__(
self, self,
message: str = "Request was throttled by Neptune Analytics due to exceeding rate limits.", message: str = "Request was throttled by Neptune Analytics due to exceeding rate limits.",
name: str = "NeptuneAnalyticsThrottlingError", name: str = "NeptuneAnalyticsThrottlingError",
status_code=status.HTTP_429_TOO_MANY_REQUESTS status_code=status.HTTP_429_TOO_MANY_REQUESTS,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError): class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError):
"""Exception raised when a Neptune Analytics resource is not found.""" """Exception raised when a Neptune Analytics resource is not found."""
def __init__( def __init__(
self, self,
message: str = "The requested Neptune Analytics resource could not be found.", message: str = "The requested Neptune Analytics resource could not be found.",
name: str = "NeptuneAnalyticsResourceNotFoundError", name: str = "NeptuneAnalyticsResourceNotFoundError",
status_code=status.HTTP_404_NOT_FOUND status_code=status.HTTP_404_NOT_FOUND,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError): class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError):
"""Exception raised when invalid parameters are provided to Neptune Analytics.""" """Exception raised when invalid parameters are provided to Neptune Analytics."""
def __init__( def __init__(
self, self,
message: str = "One or more parameters provided to Neptune Analytics are invalid or missing.", message: str = "One or more parameters provided to Neptune Analytics are invalid or missing.",
name: str = "NeptuneAnalyticsInvalidParameterError", name: str = "NeptuneAnalyticsInvalidParameterError",
status_code=status.HTTP_400_BAD_REQUEST status_code=status.HTTP_400_BAD_REQUEST,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)

View file

@ -38,15 +38,17 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
if parsed.scheme != "neptune-graph": if parsed.scheme != "neptune-graph":
raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'") raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'")
graph_id = parsed.hostname or parsed.path.lstrip('/') graph_id = parsed.hostname or parsed.path.lstrip("/")
if not graph_id: if not graph_id:
raise ValueError("Graph ID not found in URL") raise ValueError("Graph ID not found in URL")
# Extract region from query parameters # Extract region from query parameters
region = "us-east-1" # default region region = "us-east-1" # default region
if parsed.query: if parsed.query:
query_params = dict(param.split('=') for param in parsed.query.split('&') if '=' in param) query_params = dict(
region = query_params.get('region', region) param.split("=") for param in parsed.query.split("&") if "=" in param
)
region = query_params.get("region", region)
return graph_id, region return graph_id, region
@ -73,7 +75,7 @@ def validate_graph_id(graph_id: str) -> bool:
# Neptune Analytics graph IDs should be alphanumeric with hyphens # Neptune Analytics graph IDs should be alphanumeric with hyphens
# and between 1-63 characters # and between 1-63 characters
pattern = r'^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$' pattern = r"^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$"
return bool(re.match(pattern, graph_id)) return bool(re.match(pattern, graph_id))
@ -93,7 +95,7 @@ def validate_aws_region(region: str) -> bool:
return False return False
# AWS regions follow the pattern: us-east-1, eu-west-1, etc. # AWS regions follow the pattern: us-east-1, eu-west-1, etc.
pattern = r'^[a-z]{2,3}-[a-z]+-\d+$' pattern = r"^[a-z]{2,3}-[a-z]+-\d+$"
return bool(re.match(pattern, region)) return bool(re.match(pattern, region))
@ -103,7 +105,7 @@ def build_neptune_config(
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None, aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None, aws_session_token: Optional[str] = None,
**kwargs **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Build a configuration dictionary for Neptune Analytics connection. Build a configuration dictionary for Neptune Analytics connection.
@ -194,6 +196,7 @@ def format_neptune_error(error: Exception) -> str:
return error_msg return error_msg
def get_default_query_timeout() -> int: def get_default_query_timeout() -> int:
""" """
Get the default query timeout for Neptune Analytics operations. Get the default query timeout for Neptune Analytics operations.

View file

@ -17,6 +17,7 @@ from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredRes
logger = get_logger("NeptuneAnalyticsAdapter") logger = get_logger("NeptuneAnalyticsAdapter")
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
""" """
Represents a schema for an index data point containing an ID and text. Represents a schema for an index data point containing an ID and text.
@ -27,12 +28,15 @@ class IndexSchema(DataPoint):
- metadata: A dictionary with default index fields for the schema, currently configured - metadata: A dictionary with default index fields for the schema, currently configured
to include 'text'. to include 'text'.
""" """
id: str id: str
text: str text: str
metadata: dict = {"index_fields": ["text"]} metadata: dict = {"index_fields": ["text"]}
NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://" NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://"
class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
""" """
Hybrid adapter that combines Neptune Analytics Vector and Graph functionality. Hybrid adapter that combines Neptune Analytics Vector and Graph functionality.
@ -48,14 +52,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
_TOPK_UPPER_BOUND = 10 _TOPK_UPPER_BOUND = 10
def __init__( def __init__(
self, self,
graph_id: str, graph_id: str,
embedding_engine: Optional[EmbeddingEngine] = None, embedding_engine: Optional[EmbeddingEngine] = None,
region: Optional[str] = None, region: Optional[str] = None,
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None, aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None, aws_session_token: Optional[str] = None,
): ):
""" """
Initialize the Neptune Analytics hybrid adapter. Initialize the Neptune Analytics hybrid adapter.
@ -74,12 +78,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
region=region, region=region,
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token aws_session_token=aws_session_token,
) )
# Add vector-specific attributes # Add vector-specific attributes
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
logger.info(f"Initialized Neptune Analytics hybrid adapter for graph: \"{graph_id}\" in region: \"{self.region}\"") logger.info(
f'Initialized Neptune Analytics hybrid adapter for graph: "{graph_id}" in region: "{self.region}"'
)
# VectorDBInterface methods implementation # VectorDBInterface methods implementation
@ -161,7 +167,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
# Fetch embeddings # Fetch embeddings
texts = [DataPoint.get_embeddable_data(t) for t in data_points] texts = [DataPoint.get_embeddable_data(t) for t in data_points]
data_vectors = (await self.embedding_engine.embed_text(texts)) data_vectors = await self.embedding_engine.embed_text(texts)
for index, data_point in enumerate(data_points): for index, data_point in enumerate(data_points):
node_id = data_point.id node_id = data_point.id
@ -172,23 +178,24 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
properties = self._serialize_properties(data_point.model_dump()) properties = self._serialize_properties(data_point.model_dump())
properties[self._COLLECTION_PREFIX] = collection_name properties[self._COLLECTION_PREFIX] = collection_name
params = dict( params = dict(
node_id = str(node_id), node_id=str(node_id),
properties = properties, properties=properties,
embedding = data_vector, embedding=data_vector,
collection_name = collection_name collection_name=collection_name,
) )
# Compose the query and send # Compose the query and send
query_string = ( query_string = (
f"MERGE (n " f"MERGE (n "
f":{self._VECTOR_NODE_LABEL} " f":{self._VECTOR_NODE_LABEL} "
f" {{`~id`: $node_id}}) " f" {{`~id`: $node_id}}) "
f"ON CREATE SET n = $properties, n.updated_at = timestamp() " f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
f"ON MATCH SET n += $properties, n.updated_at = timestamp() " f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
f"WITH n, $embedding AS embedding " f"WITH n, $embedding AS embedding "
f"CALL neptune.algo.vectors.upsert(n, embedding) " f"CALL neptune.algo.vectors.upsert(n, embedding) "
f"YIELD success " f"YIELD success "
f"RETURN success ") f"RETURN success "
)
try: try:
self._client.query(query_string, params) self._client.query(query_string, params)
@ -208,10 +215,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
""" """
# Do the fetch for each node # Do the fetch for each node
params = dict(node_ids=data_point_ids, collection_name=collection_name) params = dict(node_ids=data_point_ids, collection_name=collection_name)
query_string = (f"MATCH( n :{self._VECTOR_NODE_LABEL}) " query_string = (
f"WHERE id(n) in $node_ids AND " f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
f"n.{self._COLLECTION_PREFIX} = $collection_name " f"WHERE id(n) in $node_ids AND "
f"RETURN n as payload ") f"n.{self._COLLECTION_PREFIX} = $collection_name "
f"RETURN n as payload "
)
try: try:
result = self._client.query(query_string, params) result = self._client.query(query_string, params)
@ -259,7 +268,8 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND: if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
logger.warning( logger.warning(
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). " "Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
"Defaulting to limit=10.", limit "Defaulting to limit=10.",
limit,
) )
limit = self._TOPK_UPPER_BOUND limit = self._TOPK_UPPER_BOUND
@ -272,7 +282,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
elif query_vector: elif query_vector:
embedding = query_vector embedding = query_vector
else: else:
data_vectors = (await self.embedding_engine.embed_text([query_text])) data_vectors = await self.embedding_engine.embed_text([query_text])
embedding = data_vectors[0] embedding = data_vectors[0]
# Compose the parameters map # Compose the parameters map
@ -305,9 +315,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
try: try:
query_response = self._client.query(query_string, params) query_response = self._client.query(query_string, params)
return [self._get_scored_result( return [self._get_scored_result(item=item, with_score=True) for item in query_response]
item = item, with_score = True
) for item in query_response]
except Exception as e: except Exception as e:
self._na_exception_handler(e, query_string) self._na_exception_handler(e, query_string)
@ -332,11 +340,13 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
self._validate_embedding_engine() self._validate_embedding_engine()
# Convert text to embedding array in batch # Convert text to embedding array in batch
data_vectors = (await self.embedding_engine.embed_text(query_texts)) data_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(*[ return await asyncio.gather(
self.search(collection_name, None, vector, limit, with_vectors) *[
for vector in data_vectors self.search(collection_name, None, vector, limit, with_vectors)
]) for vector in data_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
""" """
@ -350,10 +360,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
- data_point_ids (list[str]): A list of IDs of the data points to delete. - data_point_ids (list[str]): A list of IDs of the data points to delete.
""" """
params = dict(node_ids=data_point_ids, collection_name=collection_name) params = dict(node_ids=data_point_ids, collection_name=collection_name)
query_string = (f"MATCH (n :{self._VECTOR_NODE_LABEL}) " query_string = (
f"WHERE id(n) IN $node_ids " f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
f"AND n.{self._COLLECTION_PREFIX} = $collection_name " f"WHERE id(n) IN $node_ids "
f"DETACH DELETE n") f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
f"DETACH DELETE n"
)
try: try:
self._client.query(query_string, params) self._client.query(query_string, params)
except Exception as e: except Exception as e:
@ -370,7 +382,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
await self.create_collection(f"{index_name}_{index_property_name}") await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points( async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint] self, index_name: str, index_property_name: str, data_points: list[DataPoint]
): ):
""" """
Indexes a list of data points into Neptune Analytics by creating them as nodes. Indexes a list of data points into Neptune Analytics by creating them as nodes.
@ -402,29 +414,28 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
Remove obsolete or unnecessary data from the database. Remove obsolete or unnecessary data from the database.
""" """
# Run actual truncate # Run actual truncate
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) " self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n")
f"DETACH DELETE n")
pass pass
@staticmethod @staticmethod
def _get_scored_result(item: dict, with_vector: bool = False, with_score: bool = False) -> ScoredResult: def _get_scored_result(
item: dict, with_vector: bool = False, with_score: bool = False
) -> ScoredResult:
""" """
Util method to simplify the object creation of ScoredResult base on incoming NX payload response. Util method to simplify the object creation of ScoredResult base on incoming NX payload response.
""" """
return ScoredResult( return ScoredResult(
id=item.get('payload').get('~id'), id=item.get("payload").get("~id"),
payload=item.get('payload').get('~properties'), payload=item.get("payload").get("~properties"),
score=item.get('score') if with_score else 0, score=item.get("score") if with_score else 0,
vector=item.get('embedding') if with_vector else None vector=item.get("embedding") if with_vector else None,
) )
def _na_exception_handler(self, ex, query_string: str): def _na_exception_handler(self, ex, query_string: str):
""" """
Generic exception handler for NA langchain. Generic exception handler for NA langchain.
""" """
logger.error( logger.error("Neptune Analytics query failed: %s | Query: [%s]", ex, query_string)
"Neptune Analytics query failed: %s | Query: [%s]", ex, query_string
)
raise ex raise ex
def _validate_embedding_engine(self): def _validate_embedding_engine(self):
@ -433,4 +444,6 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
:raises: ValueError if this object does not have a valid embedding_engine :raises: ValueError if this object does not have a valid embedding_engine
""" """
if self.embedding_engine is None: if self.embedding_engine is None:
raise ValueError("Neptune Analytics requires an embedder defined to make vector operations") raise ValueError(
"Neptune Analytics requires an embedder defined to make vector operations"
)

View file

@ -125,9 +125,15 @@ def create_vector_engine(
if not vector_db_url: if not vector_db_url:
raise EnvironmentError("Missing Neptune endpoint.") raise EnvironmentError("Missing Neptune endpoint.")
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
NEPTUNE_ANALYTICS_ENDPOINT_URL,
)
if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL): if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'") raise ValueError(
f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
)
graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "") graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

View file

@ -8,7 +8,7 @@ from cognee.modules.data.processing.document_types import TextDocument
# Set up Amazon credentials in .env file and get the values from environment variables # Set up Amazon credentials in .env file and get the values from environment variables
load_dotenv() load_dotenv()
graph_id = os.getenv('GRAPH_ID', "") graph_id = os.getenv("GRAPH_ID", "")
na_adapter = NeptuneGraphDB(graph_id) na_adapter = NeptuneGraphDB(graph_id)
@ -24,33 +24,33 @@ def setup():
# stored in Amazon S3. # stored in Amazon S3.
document = TextDocument( document = TextDocument(
name='text_test.txt', name="text_test.txt",
raw_data_location='git/cognee/examples/database_examples/data_storage/data/text_test.txt', raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt",
external_metadata='{}', external_metadata="{}",
mime_type='text/plain' mime_type="text/plain",
) )
document_chunk = DocumentChunk( document_chunk = DocumentChunk(
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
chunk_size=187, chunk_size=187,
chunk_index=0, chunk_index=0,
cut_type='paragraph_end', cut_type="paragraph_end",
is_part_of=document, is_part_of=document,
) )
graph_database = EntityType(name='graph database', description='graph database') graph_database = EntityType(name="graph database", description="graph database")
neptune_analytics_entity = Entity( neptune_analytics_entity = Entity(
name='neptune analytics', name="neptune analytics",
description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.', description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
) )
neptune_database_entity = Entity( neptune_database_entity = Entity(
name='amazon neptune database', name="amazon neptune database",
description='A popular managed graph database that complements Neptune Analytics.', description="A popular managed graph database that complements Neptune Analytics.",
) )
storage = EntityType(name='storage', description='storage') storage = EntityType(name="storage", description="storage")
storage_entity = Entity( storage_entity = Entity(
name='amazon s3', name="amazon s3",
description='A storage service provided by Amazon Web Services that allows storing graph data.', description="A storage service provided by Amazon Web Services that allows storing graph data.",
) )
nodes_data = [ nodes_data = [
@ -67,37 +67,37 @@ def setup():
( (
str(document_chunk.id), str(document_chunk.id),
str(storage_entity.id), str(storage_entity.id),
'contains', "contains",
), ),
( (
str(storage_entity.id), str(storage_entity.id),
str(storage.id), str(storage.id),
'is_a', "is_a",
), ),
( (
str(document_chunk.id), str(document_chunk.id),
str(neptune_database_entity.id), str(neptune_database_entity.id),
'contains', "contains",
), ),
( (
str(neptune_database_entity.id), str(neptune_database_entity.id),
str(graph_database.id), str(graph_database.id),
'is_a', "is_a",
), ),
( (
str(document_chunk.id), str(document_chunk.id),
str(document.id), str(document.id),
'is_part_of', "is_part_of",
), ),
( (
str(document_chunk.id), str(document_chunk.id),
str(neptune_analytics_entity.id), str(neptune_analytics_entity.id),
'contains', "contains",
), ),
( (
str(neptune_analytics_entity.id), str(neptune_analytics_entity.id),
str(graph_database.id), str(graph_database.id),
'is_a', "is_a",
), ),
] ]
@ -155,42 +155,44 @@ async def pipeline_method():
print("------NEIGHBORING NODES-------") print("------NEIGHBORING NODES-------")
center_node = nodes[2] center_node = nodes[2]
neighbors = await na_adapter.get_neighbors(str(center_node.id)) neighbors = await na_adapter.get_neighbors(str(center_node.id))
print(f"found {len(neighbors)} neighbors for node \"{center_node.name}\"") print(f'found {len(neighbors)} neighbors for node "{center_node.name}"')
for neighbor in neighbors: for neighbor in neighbors:
print(neighbor) print(neighbor)
print("------NEIGHBORING EDGES-------") print("------NEIGHBORING EDGES-------")
center_node = nodes[2] center_node = nodes[2]
neighbouring_edges = await na_adapter.get_edges(str(center_node.id)) neighbouring_edges = await na_adapter.get_edges(str(center_node.id))
print(f"found {len(neighbouring_edges)} edges neighbouring node \"{center_node.name}\"") print(f'found {len(neighbouring_edges)} edges neighbouring node "{center_node.name}"')
for edge in neighbouring_edges: for edge in neighbouring_edges:
print(edge) print(edge)
print("------GET CONNECTIONS (SOURCE NODE)-------") print("------GET CONNECTIONS (SOURCE NODE)-------")
document_chunk_node = nodes[0] document_chunk_node = nodes[0]
connections = await na_adapter.get_connections(str(document_chunk_node.id)) connections = await na_adapter.get_connections(str(document_chunk_node.id))
print(f"found {len(connections)} connections for node \"{document_chunk_node.type}\"") print(f'found {len(connections)} connections for node "{document_chunk_node.type}"')
for connection in connections: for connection in connections:
src, relationship, tgt = connection src, relationship, tgt = connection
src = src.get("name", src.get("type", "unknown")) src = src.get("name", src.get("type", "unknown"))
relationship = relationship["relationship_name"] relationship = relationship["relationship_name"]
tgt = tgt.get("name", tgt.get("type", "unknown")) tgt = tgt.get("name", tgt.get("type", "unknown"))
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"") print(f'"{src}"-[{relationship}]->"{tgt}"')
print("------GET CONNECTIONS (TARGET NODE)-------") print("------GET CONNECTIONS (TARGET NODE)-------")
connections = await na_adapter.get_connections(str(center_node.id)) connections = await na_adapter.get_connections(str(center_node.id))
print(f"found {len(connections)} connections for node \"{center_node.name}\"") print(f'found {len(connections)} connections for node "{center_node.name}"')
for connection in connections: for connection in connections:
src, relationship, tgt = connection src, relationship, tgt = connection
src = src.get("name", src.get("type", "unknown")) src = src.get("name", src.get("type", "unknown"))
relationship = relationship["relationship_name"] relationship = relationship["relationship_name"]
tgt = tgt.get("name", tgt.get("type", "unknown")) tgt = tgt.get("name", tgt.get("type", "unknown"))
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"") print(f'"{src}"-[{relationship}]->"{tgt}"')
print("------SUBGRAPH-------") print("------SUBGRAPH-------")
node_names = ["neptune analytics", "amazon neptune database"] node_names = ["neptune analytics", "amazon neptune database"]
subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names) subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names)
print(f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}") print(
f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}"
)
for subgraph_node in subgraph_nodes: for subgraph_node in subgraph_nodes:
print(subgraph_node) print(subgraph_node)
for subgraph_edge in subgraph_edges: for subgraph_edge in subgraph_edges:
@ -199,17 +201,17 @@ async def pipeline_method():
print("------STAT-------") print("------STAT-------")
stat = await na_adapter.get_graph_metrics(include_optional=True) stat = await na_adapter.get_graph_metrics(include_optional=True)
assert type(stat) is dict assert type(stat) is dict
assert stat['num_nodes'] == 7 assert stat["num_nodes"] == 7
assert stat['num_edges'] == 7 assert stat["num_edges"] == 7
assert stat['mean_degree'] == 2.0 assert stat["mean_degree"] == 2.0
assert round(stat['edge_density'], 3) == 0.167 assert round(stat["edge_density"], 3) == 0.167
assert stat['num_connected_components'] == [7] assert stat["num_connected_components"] == [7]
assert stat['sizes_of_connected_components'] == 1 assert stat["sizes_of_connected_components"] == 1
assert stat['num_selfloops'] == 0 assert stat["num_selfloops"] == 0
# Unsupported optional metrics # Unsupported optional metrics
assert stat['diameter'] == -1 assert stat["diameter"] == -1
assert stat['avg_shortest_path_length'] == -1 assert stat["avg_shortest_path_length"] == -1
assert stat['avg_clustering'] == -1 assert stat["avg_clustering"] == -1
print("------DELETE-------") print("------DELETE-------")
# delete all nodes and edges: # delete all nodes and edges:
@ -253,7 +255,9 @@ async def misc_methods():
print(edge_labels) print(edge_labels)
print("------Get Filtered Graph-------") print("------Get Filtered Graph-------")
filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data([{'name': ['text_test.txt']}]) filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data(
[{"name": ["text_test.txt"]}]
)
print(filtered_nodes, filtered_edges) print(filtered_nodes, filtered_edges)
print("------Get Degree one nodes-------") print("------Get Degree one nodes-------")
@ -261,15 +265,13 @@ async def misc_methods():
print(degree_one_nodes) print(degree_one_nodes)
print("------Get Doc sub-graph-------") print("------Get Doc sub-graph-------")
doc_sub_graph = await na_adapter.get_document_subgraph('test.txt') doc_sub_graph = await na_adapter.get_document_subgraph("test.txt")
print(doc_sub_graph) print(doc_sub_graph)
print("------Fetch and Remove connections (Predecessors)-------") print("------Fetch and Remove connections (Predecessors)-------")
# Fetch test edge # Fetch test edge
(src_id, dest_id, relationship) = edges[0] (src_id, dest_id, relationship) = edges[0]
nodes_predecessors = await na_adapter.get_predecessors( nodes_predecessors = await na_adapter.get_predecessors(node_id=dest_id, edge_label=relationship)
node_id=dest_id, edge_label=relationship
)
assert len(nodes_predecessors) > 0 assert len(nodes_predecessors) > 0
await na_adapter.remove_connection_to_predecessors_of( await na_adapter.remove_connection_to_predecessors_of(
@ -281,25 +283,19 @@ async def misc_methods():
# Return empty after relationship being deleted. # Return empty after relationship being deleted.
assert len(nodes_predecessors_after) == 0 assert len(nodes_predecessors_after) == 0
print("------Fetch and Remove connections (Successors)-------") print("------Fetch and Remove connections (Successors)-------")
_, edges_suc = await na_adapter.get_graph_data() _, edges_suc = await na_adapter.get_graph_data()
(src_id, dest_id, relationship, _) = edges_suc[0] (src_id, dest_id, relationship, _) = edges_suc[0]
nodes_successors = await na_adapter.get_successors( nodes_successors = await na_adapter.get_successors(node_id=src_id, edge_label=relationship)
node_id=src_id, edge_label=relationship
)
assert len(nodes_successors) > 0 assert len(nodes_successors) > 0
await na_adapter.remove_connection_to_successors_of( await na_adapter.remove_connection_to_successors_of(node_ids=[dest_id], edge_label=relationship)
node_ids=[dest_id], edge_label=relationship
)
nodes_successors_after = await na_adapter.get_successors( nodes_successors_after = await na_adapter.get_successors(
node_id=src_id, edge_label=relationship node_id=src_id, edge_label=relationship
) )
assert len(nodes_successors_after) == 0 assert len(nodes_successors_after) == 0
# no-op # no-op
await na_adapter.project_entire_graph() await na_adapter.project_entire_graph()
await na_adapter.drop_graph() await na_adapter.drop_graph()

View file

@ -8,11 +8,13 @@ from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.data.processing.document_types import TextDocument
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
)
# Set up Amazon credentials in .env file and get the values from environment variables # Set up Amazon credentials in .env file and get the values from environment variables
load_dotenv() load_dotenv()
graph_id = os.getenv('GRAPH_ID', "") graph_id = os.getenv("GRAPH_ID", "")
# get the default embedder # get the default embedder
embedding_engine = get_embedding_engine() embedding_engine = get_embedding_engine()
@ -23,6 +25,7 @@ collection = "test_collection"
logger = get_logger("test_neptune_analytics_hybrid") logger = get_logger("test_neptune_analytics_hybrid")
def setup_data(): def setup_data():
# Define nodes data before the main function # Define nodes data before the main function
# These nodes were defined using openAI from the following prompt: # These nodes were defined using openAI from the following prompt:
@ -34,33 +37,33 @@ def setup_data():
# stored in Amazon S3. # stored in Amazon S3.
document = TextDocument( document = TextDocument(
name='text.txt', name="text.txt",
raw_data_location='git/cognee/examples/database_examples/data_storage/data/text.txt', raw_data_location="git/cognee/examples/database_examples/data_storage/data/text.txt",
external_metadata='{}', external_metadata="{}",
mime_type='text/plain' mime_type="text/plain",
) )
document_chunk = DocumentChunk( document_chunk = DocumentChunk(
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
chunk_size=187, chunk_size=187,
chunk_index=0, chunk_index=0,
cut_type='paragraph_end', cut_type="paragraph_end",
is_part_of=document, is_part_of=document,
) )
graph_database = EntityType(name='graph database', description='graph database') graph_database = EntityType(name="graph database", description="graph database")
neptune_analytics_entity = Entity( neptune_analytics_entity = Entity(
name='neptune analytics', name="neptune analytics",
description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.', description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
) )
neptune_database_entity = Entity( neptune_database_entity = Entity(
name='amazon neptune database', name="amazon neptune database",
description='A popular managed graph database that complements Neptune Analytics.', description="A popular managed graph database that complements Neptune Analytics.",
) )
storage = EntityType(name='storage', description='storage') storage = EntityType(name="storage", description="storage")
storage_entity = Entity( storage_entity = Entity(
name='amazon s3', name="amazon s3",
description='A storage service provided by Amazon Web Services that allows storing graph data.', description="A storage service provided by Amazon Web Services that allows storing graph data.",
) )
nodes_data = [ nodes_data = [
@ -77,41 +80,42 @@ def setup_data():
( (
str(document_chunk.id), str(document_chunk.id),
str(storage_entity.id), str(storage_entity.id),
'contains', "contains",
), ),
( (
str(storage_entity.id), str(storage_entity.id),
str(storage.id), str(storage.id),
'is_a', "is_a",
), ),
( (
str(document_chunk.id), str(document_chunk.id),
str(neptune_database_entity.id), str(neptune_database_entity.id),
'contains', "contains",
), ),
( (
str(neptune_database_entity.id), str(neptune_database_entity.id),
str(graph_database.id), str(graph_database.id),
'is_a', "is_a",
), ),
( (
str(document_chunk.id), str(document_chunk.id),
str(document.id), str(document.id),
'is_part_of', "is_part_of",
), ),
( (
str(document_chunk.id), str(document_chunk.id),
str(neptune_analytics_entity.id), str(neptune_analytics_entity.id),
'contains', "contains",
), ),
( (
str(neptune_analytics_entity.id), str(neptune_analytics_entity.id),
str(graph_database.id), str(graph_database.id),
'is_a', "is_a",
), ),
] ]
return nodes_data, edges_data return nodes_data, edges_data
async def test_add_graph_then_vector_data(): async def test_add_graph_then_vector_data():
logger.info("------test_add_graph_then_vector_data-------") logger.info("------test_add_graph_then_vector_data-------")
(nodes, edges) = setup_data() (nodes, edges) = setup_data()
@ -134,6 +138,7 @@ async def test_add_graph_then_vector_data():
assert len(edges) == 0 assert len(edges) == 0
logger.info("------PASSED-------") logger.info("------PASSED-------")
async def test_add_vector_then_node_data(): async def test_add_vector_then_node_data():
logger.info("------test_add_vector_then_node_data-------") logger.info("------test_add_vector_then_node_data-------")
(nodes, edges) = setup_data() (nodes, edges) = setup_data()
@ -156,6 +161,7 @@ async def test_add_vector_then_node_data():
assert len(edges) == 0 assert len(edges) == 0
logger.info("------PASSED-------") logger.info("------PASSED-------")
def main(): def main():
""" """
Example script uses neptune analytics for the graph and vector (hybrid) store with small sample data Example script uses neptune analytics for the graph and vector (hybrid) store with small sample data
@ -165,5 +171,6 @@ def main():
asyncio.run(test_add_graph_then_vector_data()) asyncio.run(test_add_graph_then_vector_data())
asyncio.run(test_add_vector_then_node_data()) asyncio.run(test_add_vector_then_node_data())
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -8,13 +8,16 @@ from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, IndexSchema from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
NeptuneAnalyticsAdapter,
IndexSchema,
)
logger = get_logger() logger = get_logger()
async def main(): async def main():
graph_id = os.getenv('GRAPH_ID', "") graph_id = os.getenv("GRAPH_ID", "")
cognee.config.set_vector_db_provider("neptune_analytics") cognee.config.set_vector_db_provider("neptune_analytics")
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}") cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
data_directory_path = str( data_directory_path = str(
@ -87,6 +90,7 @@ async def main():
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
async def vector_backend_api_test(): async def vector_backend_api_test():
cognee.config.set_vector_db_provider("neptune_analytics") cognee.config.set_vector_db_provider("neptune_analytics")
@ -101,7 +105,7 @@ async def vector_backend_api_test():
get_vector_engine() get_vector_engine()
# Return a valid engine object with valid URL. # Return a valid engine object with valid URL.
graph_id = os.getenv('GRAPH_ID', "") graph_id = os.getenv("GRAPH_ID", "")
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}") cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
engine = get_vector_engine() engine = get_vector_engine()
assert isinstance(engine, NeptuneAnalyticsAdapter) assert isinstance(engine, NeptuneAnalyticsAdapter)
@ -133,28 +137,22 @@ async def vector_backend_api_test():
query_text=TEST_TEXT, query_text=TEST_TEXT,
query_vector=None, query_vector=None,
limit=10, limit=10,
with_vector=True) with_vector=True,
assert (len(result_search) == 2) )
assert len(result_search) == 2
# # Retrieve data-points # # Retrieve data-points
result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2]) result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
assert any( assert any(str(r.id) == TEST_UUID and r.payload["text"] == TEST_TEXT for r in result)
str(r.id) == TEST_UUID and r.payload['text'] == TEST_TEXT assert any(str(r.id) == TEST_UUID_2 and r.payload["text"] == TEST_TEXT_2 for r in result)
for r in result
)
assert any(
str(r.id) == TEST_UUID_2 and r.payload['text'] == TEST_TEXT_2
for r in result
)
# Search multiple # Search multiple
result_search_batch = await engine.batch_search( result_search_batch = await engine.batch_search(
collection_name=TEST_COLLECTION_NAME, collection_name=TEST_COLLECTION_NAME,
query_texts=[TEST_TEXT, TEST_TEXT_2], query_texts=[TEST_TEXT, TEST_TEXT_2],
limit=10, limit=10,
with_vectors=False with_vectors=False,
) )
assert (len(result_search_batch) == 2 and assert len(result_search_batch) == 2 and all(len(batch) == 2 for batch in result_search_batch)
all(len(batch) == 2 for batch in result_search_batch))
# Delete datapoint from vector store # Delete datapoint from vector store
await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2]) await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
@ -163,6 +161,7 @@ async def vector_backend_api_test():
result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID]) result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID])
assert result_deleted == [] assert result_deleted == []
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio

View file

@ -9,6 +9,7 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
async def main(): async def main():
""" """
Example script demonstrating how to use Cognee with Amazon Neptune Analytics Example script demonstrating how to use Cognee with Amazon Neptune Analytics
@ -22,7 +23,7 @@ async def main():
""" """
# Set up Amazon credentials in .env file and get the values from environment variables # Set up Amazon credentials in .env file and get the values from environment variables
graph_endpoint_url = "neptune-graph://" + os.getenv('GRAPH_ID', "") graph_endpoint_url = "neptune-graph://" + os.getenv("GRAPH_ID", "")
# Configure Neptune Analytics as the graph & vector database provider # Configure Neptune Analytics as the graph & vector database provider
cognee.config.set_graph_db_config( cognee.config.set_graph_db_config(
@ -77,7 +78,9 @@ async def main():
# Now let's perform some searches # Now let's perform some searches
# 1. Search for insights related to "Neptune Analytics" # 1. Search for insights related to "Neptune Analytics"
insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neptune Analytics") insights_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text="Neptune Analytics"
)
print("\n========Insights about Neptune Analytics========:") print("\n========Insights about Neptune Analytics========:")
for result in insights_results: for result in insights_results:
print(f"- {result}") print(f"- {result}")

File diff suppressed because one or more lines are too long