fixes to formatting

This commit is contained in:
parent 7d2bf78c81
commit 11422f362f

10 changed files with 353 additions and 324 deletions
@@ -149,7 +149,9 @@ def create_graph_engine(
     from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL

     if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
-        raise ValueError(f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>")
+        raise ValueError(
+            f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
+        )

     graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")

@@ -174,10 +176,15 @@ def create_graph_engine(
     if not graph_database_url:
         raise EnvironmentError("Missing Neptune endpoint.")

-    from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL
+    from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
+        NeptuneAnalyticsAdapter,
+        NEPTUNE_ANALYTICS_ENDPOINT_URL,
+    )

     if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
-        raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")
+        raise ValueError(
+            f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
+        )

     graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

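Note: both hunks above only re-wrap long lines; the validation logic is unchanged. For context, the check reduces to a plain prefix test plus a strip. A minimal standalone sketch of that pattern (the constant's value matches the adapter hunks below; the function name here is illustrative, not from the codebase):

# Illustrative sketch of the endpoint validation re-wrapped above.
NEPTUNE_ENDPOINT_URL = "neptune-graph://"

def extract_graph_identifier(graph_database_url: str) -> str:
    # Reject anything that does not use the neptune-graph:// scheme.
    if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
        raise ValueError(
            f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
        )
    # Strip the scheme prefix to obtain the bare graph identifier.
    return graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")

assert extract_graph_identifier("neptune-graph://g-abc123") == "g-abc123"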
@@ -29,6 +29,7 @@ logger = get_logger("NeptuneGraphDB")

 try:
     from langchain_aws import NeptuneAnalyticsGraph
+
     LANGCHAIN_AWS_AVAILABLE = True
 except ImportError:
     logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.")

@@ -36,11 +37,13 @@ except ImportError:
 NEPTUNE_ENDPOINT_URL = "neptune-graph://"

+
 class NeptuneGraphDB(GraphDBInterface):
     """
     Adapter for interacting with Amazon Neptune Analytics graph store.
     This class provides methods for querying, adding, deleting nodes and edges using the aws_langchain library.
     """
+
     _GRAPH_NODE_LABEL = "COGNEE_NODE"

     def __init__(

@@ -61,28 +64,30 @@ class NeptuneGraphDB(GraphDBInterface):
         - aws_access_key_id (Optional[str]): AWS access key ID
         - aws_secret_access_key (Optional[str]): AWS secret access key
         - aws_session_token (Optional[str]): AWS session token for temporary credentials

         Raises:
         -------
         - NeptuneAnalyticsConfigurationError: If configuration parameters are invalid
         """
         # validate import
         if not LANGCHAIN_AWS_AVAILABLE:
-            raise ImportError("langchain_aws is not available. Please install it to use Neptune Analytics.")
+            raise ImportError(
+                "langchain_aws is not available. Please install it to use Neptune Analytics."
+            )

         # Validate configuration
         if not validate_graph_id(graph_id):
-            raise NeptuneAnalyticsConfigurationError(message=f"Invalid graph ID: \"{graph_id}\"")
+            raise NeptuneAnalyticsConfigurationError(message=f'Invalid graph ID: "{graph_id}"')

         if region and not validate_aws_region(region):
-            raise NeptuneAnalyticsConfigurationError(message=f"Invalid AWS region: \"{region}\"")
+            raise NeptuneAnalyticsConfigurationError(message=f'Invalid AWS region: "{region}"')

         self.graph_id = graph_id
         self.region = region
         self.aws_access_key_id = aws_access_key_id
         self.aws_secret_access_key = aws_secret_access_key
         self.aws_session_token = aws_session_token

         # Build configuration
         self.config = build_neptune_config(
             graph_id=self.graph_id,

@@ -91,15 +96,17 @@ class NeptuneGraphDB(GraphDBInterface):
             aws_secret_access_key=self.aws_secret_access_key,
             aws_session_token=self.aws_session_token,
         )

         # Initialize Neptune Analytics client using langchain_aws
         self._client: NeptuneAnalyticsGraph = self._initialize_client()
-        logger.info(f"Initialized Neptune Analytics adapter for graph: \"{graph_id}\" in region: \"{self.region}\"")
+        logger.info(
+            f'Initialized Neptune Analytics adapter for graph: "{graph_id}" in region: "{self.region}"'
+        )

     def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]:
         """
         Initialize the Neptune Analytics client using langchain_aws.

         Returns:
         --------
         - Optional[Any]: The Neptune Analytics client or None if not available

@@ -108,9 +115,7 @@ class NeptuneGraphDB(GraphDBInterface):
             # Initialize the Neptune Analytics Graph client
             client_config = {
                 "graph_identifier": self.graph_id,
-                "config": Config(
-                    user_agent_appid='Cognee'
-                )
+                "config": Config(user_agent_appid="Cognee"),
             }
             # Add AWS credentials if provided
             if self.region:

@@ -121,13 +126,15 @@ class NeptuneGraphDB(GraphDBInterface):
                 client_config["aws_secret_access_key"] = self.aws_secret_access_key
             if self.aws_session_token:
                 client_config["aws_session_token"] = self.aws_session_token

             client = NeptuneAnalyticsGraph(**client_config)
             logger.info("Successfully initialized Neptune Analytics client")
             return client

         except Exception as e:
-            raise NeptuneAnalyticsConfigurationError(message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}") from e
+            raise NeptuneAnalyticsConfigurationError(
+                message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}"
+            ) from e

     @staticmethod
     def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]:

@@ -175,7 +182,7 @@ class NeptuneGraphDB(GraphDBInterface):
                 params = {}
             logger.debug(f"executing na query:\nquery={query}\n")
             result = self._client.query(query, params)

             # Convert the result to list format expected by the interface
             if isinstance(result, list):
                 return result

@@ -183,7 +190,7 @@ class NeptuneGraphDB(GraphDBInterface):
                 return [result]
             else:
                 return [{"result": result}]

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Neptune Analytics query failed: {error_msg}")

@@ -212,10 +219,11 @@ class NeptuneGraphDB(GraphDBInterface):
                 "node_id": str(node.id),
                 "properties": serialized_properties,
             }

             result = await self.query(query, params)
-            logger.debug(f"Successfully added/updated node: {node.id}")
+            logger.debug(f"Successfully added/updated node: {str(result)}")

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Failed to add node {node.id}: {error_msg}")

@@ -256,7 +264,7 @@ class NeptuneGraphDB(GraphDBInterface):
             }
             result = await self.query(query, params)

-            processed_count = result[0].get('nodes_processed', 0) if result else 0
+            processed_count = result[0].get("nodes_processed", 0) if result else 0
             logger.debug(f"Successfully processed {processed_count} nodes in bulk operation")

         except Exception as e:

@@ -268,7 +276,9 @@ class NeptuneGraphDB(GraphDBInterface):
                 try:
                     await self.add_node(node)
                 except Exception as node_error:
-                    logger.error(f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}")
+                    logger.error(
+                        f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}"
+                    )
                     continue

     async def delete_node(self, node_id: str) -> None:

@@ -287,13 +297,11 @@ class NeptuneGraphDB(GraphDBInterface):
             DETACH DELETE n
             """

-            params = {
-                "node_id": node_id
-            }
+            params = {"node_id": node_id}

             await self.query(query, params)
             logger.debug(f"Successfully deleted node: {node_id}")

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Failed to delete node {node_id}: {error_msg}")

@@ -333,7 +341,9 @@ class NeptuneGraphDB(GraphDBInterface):
                 try:
                     await self.delete_node(node_id)
                 except Exception as node_error:
-                    logger.error(f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}")
+                    logger.error(
+                        f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}"
+                    )
                     continue

     async def get_node(self, node_id: str) -> Optional[NodeData]:

@@ -355,10 +365,10 @@ class NeptuneGraphDB(GraphDBInterface):
             WHERE id(n) = $node_id
             RETURN n
             """
-            params = {'node_id': node_id}
+            params = {"node_id": node_id}

             result = await self.query(query, params)

             if result and len(result) == 1:
                 # Extract node properties from the result
                 logger.debug(f"Successfully retrieved node: {node_id}")

@@ -369,7 +379,7 @@ class NeptuneGraphDB(GraphDBInterface):
             elif len(result) > 1:
                 logger.debug(f"Only one node expected, multiple returned: {node_id}")
                 return None

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Failed to get node {node_id}: {error_msg}")

@@ -406,7 +416,9 @@ class NeptuneGraphDB(GraphDBInterface):
             # Extract node data from results
             nodes = [record["n"] for record in result]

-            logger.debug(f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested")
+            logger.debug(
+                f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested"
+            )
             return nodes

         except Exception as e:

@@ -421,11 +433,12 @@ class NeptuneGraphDB(GraphDBInterface):
                     if node_data:
                         nodes.append(node_data)
                 except Exception as node_error:
-                    logger.error(f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}")
+                    logger.error(
+                        f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}"
+                    )
                     continue
             return nodes

     async def extract_node(self, node_id: str):
         """
         Retrieve a single node based on its ID.

@@ -512,8 +525,10 @@ class NeptuneGraphDB(GraphDBInterface):
                 "properties": serialized_properties,
             }
             await self.query(query, params)
-            logger.debug(f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}")
+            logger.debug(
+                f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}"
+            )

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}")

@@ -557,20 +572,24 @@ class NeptuneGraphDB(GraphDBInterface):
                 """

                 # Prepare edges data for bulk operation
-                params = {"edges":
-                    [
+                params = {
+                    "edges": [
                         {
                             "from_node": str(edge[0]),
                             "to_node": str(edge[1]),
                             "relationship_name": relationship_name,
-                            "properties": self._serialize_properties(edge[3] if len(edge) > 3 and edge[3] else {}),
+                            "properties": self._serialize_properties(
+                                edge[3] if len(edge) > 3 and edge[3] else {}
+                            ),
                         }
                         for edge in edges_for_relationship
                     ]
                 }
                 results[relationship_name] = await self.query(query, params)
             except Exception as e:
-                logger.error(f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}")
+                logger.error(
+                    f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}"
+                )
                 logger.info("Falling back to individual edge creation")
                 for edge in edges_by_relationship:
                     try:

@@ -578,15 +597,16 @@ class NeptuneGraphDB(GraphDBInterface):
                         properties = edge[3] if len(edge) > 3 else {}
                         await self.add_edge(source_id, target_id, relationship_name, properties)
                     except Exception as edge_error:
-                        logger.error(f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}")
+                        logger.error(
+                            f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}"
+                        )
                         continue

         processed_count = 0
         for result in results.values():
-            processed_count += result[0].get('edges_processed', 0) if result else 0
+            processed_count += result[0].get("edges_processed", 0) if result else 0
         logger.debug(f"Successfully processed {processed_count} edges in bulk operation")

     async def delete_graph(self) -> None:
         """
         Delete all nodes and edges from the graph database.

@@ -599,14 +619,12 @@ class NeptuneGraphDB(GraphDBInterface):
             # Build openCypher query to delete the graph
             query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n"
             await self.query(query)
             logger.debug(f"Successfully deleted all edges and nodes from the graph")

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Failed to delete graph: {error_msg}")
             raise Exception(f"Failed to delete graph: {error_msg}") from e

-
     async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
         """
         Retrieve all nodes and edges within the graph.

@@ -633,13 +651,7 @@ class NeptuneGraphDB(GraphDBInterface):
             edges_result = await self.query(edges_query)

             # Format nodes as (node_id, properties) tuples
-            nodes = [
-                (
-                    result["node_id"],
-                    result["properties"]
-                )
-                for result in nodes_result
-            ]
+            nodes = [(result["node_id"], result["properties"]) for result in nodes_result]

             # Format edges as (source_id, target_id, relationship_name, properties) tuples
             edges = [

@@ -647,7 +659,7 @@ class NeptuneGraphDB(GraphDBInterface):
                     result["source_id"],
                     result["target_id"],
                     result["relationship_name"],
-                    result["properties"]
+                    result["properties"],
                 )
                 for result in edges_result
             ]

@@ -679,9 +691,11 @@ class NeptuneGraphDB(GraphDBInterface):
             "num_nodes": num_nodes,
             "num_edges": num_edges,
             "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
-            "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1)) if num_nodes != 0 else None,
+            "edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1))
+            if num_nodes != 0
+            else None,
             "num_connected_components": num_cluster,
-            "sizes_of_connected_components": list_clsuter_size
+            "sizes_of_connected_components": list_clsuter_size,
         }

         optional_metrics = {

@@ -692,7 +706,7 @@ class NeptuneGraphDB(GraphDBInterface):
         }

         if include_optional:
-            optional_metrics['num_selfloops'] = await self._count_self_loops()
+            optional_metrics["num_selfloops"] = await self._count_self_loops()
             # Unsupported due to long-running queries when computing the shortest path for each node in the graph:
             # optional_metrics['diameter']
             # optional_metrics['avg_shortest_path_length']

@@ -728,17 +742,19 @@ class NeptuneGraphDB(GraphDBInterface):
                 "source_id": source_id,
                 "target_id": target_id,
             }

             result = await self.query(query, params)

             if result and len(result) > 0:
-                edge_exists = result.pop().get('edge_exists', False)
-                logger.debug(f"Edge existence check for "
-                             f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}")
+                edge_exists = result.pop().get("edge_exists", False)
+                logger.debug(
+                    f"Edge existence check for "
+                    f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}"
+                )
                 return edge_exists
             else:
                 return False

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Failed to check edge existence {source_id} -> {target_id}: {error_msg}")

@@ -778,7 +794,7 @@ class NeptuneGraphDB(GraphDBInterface):
             results = await self.query(query, params)
             logger.debug(f"Found {len(results)} existing edges out of {len(edges)} checked")
             return [result["edge_exists"] for result in results]

         except Exception as e:
             error_msg = format_neptune_error(e)
             logger.error(f"Failed to check edges existence: {error_msg}")

@@ -863,10 +879,7 @@ class NeptuneGraphDB(GraphDBInterface):
         RETURN predecessor
         """

-        results = await self.query(
-            query,
-            {"node_id": node_id}
-        )
+        results = await self.query(query, {"node_id": node_id})

         return [result["predecessor"] for result in results]

@@ -893,14 +906,10 @@ class NeptuneGraphDB(GraphDBInterface):
         RETURN successor
         """

-        results = await self.query(
-            query,
-            {"node_id": node_id}
-        )
+        results = await self.query(query, {"node_id": node_id})

         return [result["successor"] for result in results]

     async def get_neighbors(self, node_id: str) -> List[NodeData]:
         """
         Get all neighboring nodes connected to the specified node.

@@ -926,11 +935,7 @@ class NeptuneGraphDB(GraphDBInterface):

             # Format neighbors as NodeData objects
             neighbors = [
-                {
-                    "id": neighbor["neighbor_id"],
-                    **neighbor["properties"]
-                }
-                for neighbor in result
+                {"id": neighbor["neighbor_id"], **neighbor["properties"]} for neighbor in result
             ]

             logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}")

@@ -942,9 +947,7 @@ class NeptuneGraphDB(GraphDBInterface):
             raise Exception(f"Failed to get neighbors: {error_msg}") from e

     async def get_nodeset_subgraph(
-        self,
-        node_type: Type[Any],
-        node_name: List[str]
+        self, node_type: Type[Any], node_name: List[str]
     ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
         """
         Fetch a subgraph consisting of a specific set of nodes and their relationships.

@@ -987,10 +990,7 @@ class NeptuneGraphDB(GraphDBInterface):
         }}] AS rawRels
         """

-        params = {
-            "names": node_name,
-            "type": node_type.__name__
-        }
+        params = {"names": node_name, "type": node_type.__name__}

         result = await self.query(query, params)

@@ -1002,18 +1002,14 @@ class NeptuneGraphDB(GraphDBInterface):
         raw_rels = result[0]["rawRels"]

         # Format nodes as (node_id, properties) tuples
-        nodes = [
-            (n["id"], n["properties"])
-            for n in raw_nodes
-        ]
+        nodes = [(n["id"], n["properties"]) for n in raw_nodes]

         # Format edges as (source_id, target_id, relationship_name, properties) tuples
-        edges = [
-            (r["source_id"], r["target_id"], r["type"], r["properties"])
-            for r in raw_rels
-        ]
+        edges = [(r["source_id"], r["target_id"], r["type"], r["properties"]) for r in raw_rels]

-        logger.debug(f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}")
+        logger.debug(
+            f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}"
+        )
         return (nodes, edges)

     except Exception as e:

@@ -1055,18 +1051,12 @@ class NeptuneGraphDB(GraphDBInterface):
                 # Return as (source_node, relationship, target_node)
                 connections.append(
                     (
-                        {
-                            "id": record["source_id"],
-                            **record["source_props"]
-                        },
+                        {"id": record["source_id"], **record["source_props"]},
                         {
                             "relationship_name": record["relationship_name"],
-                            **record["relationship_props"]
+                            **record["relationship_props"],
                         },
-                        {
-                            "id": record["target_id"],
-                            **record["target_props"]
-                        }
+                        {"id": record["target_id"], **record["target_props"]},
                     )
                 )

@@ -1078,10 +1068,7 @@ class NeptuneGraphDB(GraphDBInterface):
             logger.error(f"Failed to get connections for node {node_id}: {error_msg}")
             raise Exception(f"Failed to get connections: {error_msg}") from e

-    async def remove_connection_to_predecessors_of(
-        self, node_ids: list[str], edge_label: str
-    ):
+    async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str):
         """
         Remove connections (edges) to all predecessors of specified nodes based on edge label.

@@ -1101,10 +1088,7 @@ class NeptuneGraphDB(GraphDBInterface):
         params = {"node_ids": node_ids}
         await self.query(query, params)

-    async def remove_connection_to_successors_of(
-        self, node_ids: list[str], edge_label: str
-    ):
+    async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str):
         """
         Remove connections (edges) to all successors of specified nodes based on edge label.

@@ -1124,7 +1108,6 @@ class NeptuneGraphDB(GraphDBInterface):
         params = {"node_ids": node_ids}
         await self.query(query, params)

-
     async def get_node_labels_string(self):
         """
         Fetch all node labels from the database and return them as a formatted string.

@@ -1138,7 +1121,9 @@ class NeptuneGraphDB(GraphDBInterface):
         -------
         ValueError: If no node labels are found in the database.
         """
-        node_labels_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
+        node_labels_query = (
+            "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
+        )
         node_labels_result = await self.query(node_labels_query)
         node_labels = node_labels_result[0]["labels"] if node_labels_result else []

@@ -1156,7 +1141,9 @@ class NeptuneGraphDB(GraphDBInterface):

         A formatted string of relationship types.
         """
-        relationship_types_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
+        relationship_types_query = (
+            "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
+        )
         relationship_types_result = await self.query(relationship_types_query)
         relationship_types = (
             relationship_types_result[0]["relationships"] if relationship_types_result else []

@@ -1172,7 +1159,6 @@ class NeptuneGraphDB(GraphDBInterface):
         )
         return relationship_types_undirected_str

-
     async def drop_graph(self, graph_name="myGraph"):
         """
         Drop an existing graph from the database based on its name.

@@ -1208,7 +1194,6 @@ class NeptuneGraphDB(GraphDBInterface):
         """
         pass

-
     async def project_entire_graph(self, graph_name="myGraph"):
         """
         Project all node labels and relationship types into an in-memory graph using GDS.

@@ -1280,7 +1265,6 @@ class NeptuneGraphDB(GraphDBInterface):

         return (nodes, edges)

-
     async def get_degree_one_nodes(self, node_type: str):
         """
         Fetch nodes of a specified type that have exactly one connection.

@@ -1361,8 +1345,6 @@ class NeptuneGraphDB(GraphDBInterface):
         result = await self.query(query, {"content_hash": content_hash})
         return result[0] if result else None

-
-
     async def _get_model_independent_graph_data(self):
         """
         Retrieve the basic graph data without considering the model specifics, returning nodes

@@ -1380,8 +1362,8 @@ class NeptuneGraphDB(GraphDBInterface):
         RETURN nodeCount AS numVertices, count(r) AS numEdges
         """
         query_response = await self.query(query_string)
-        num_nodes = query_response[0].get('numVertices')
-        num_edges = query_response[0].get('numEdges')
+        num_nodes = query_response[0].get("numVertices")
+        num_edges = query_response[0].get("numEdges")

         return (num_nodes, num_edges)

@@ -1437,4 +1419,9 @@ class NeptuneGraphDB(GraphDBInterface):

     @staticmethod
     def _convert_relationship_to_edge(relationship: dict) -> EdgeData:
-        return relationship["source_id"], relationship["target_id"], relationship["relationship_name"], relationship["properties"]
+        return (
+            relationship["source_id"],
+            relationship["target_id"],
+            relationship["relationship_name"],
+            relationship["properties"],
+        )

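Aside from line wrapping and quote normalization, the only behavioral change in this file is the add_node debug log, which now prints the query result instead of the node ID. For context, a hypothetical usage sketch of the adapter's async surface, using only methods visible in the hunks above (the graph ID and node ID are placeholders, and AWS credentials are assumed to be configured):

import asyncio

async def demo(adapter) -> None:
    # `adapter` is assumed to be a NeptuneGraphDB instance from the diff above.
    nodes, edges = await adapter.get_graph_data()  # [(id, props)], [(src, tgt, rel, props)]
    print(f"{len(nodes)} nodes, {len(edges)} edges")
    metrics = await adapter.get_graph_metrics(include_optional=True)
    print(metrics["num_nodes"], metrics["num_edges"])
    neighbors = await adapter.get_neighbors("some-node-id")  # placeholder ID
    print(neighbors)

# asyncio.run(demo(NeptuneGraphDB("g-abc123")))  # requires a real graph ID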
|
@ -2,106 +2,114 @@
|
|||
|
||||
This module defines custom exceptions for Neptune Analytics operations.
|
||||
"""
|
||||
|
||||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
|
||||
class NeptuneAnalyticsError(CogneeApiError):
|
||||
"""Base exception for Neptune Analytics operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Neptune Analytics error.",
|
||||
name: str = "NeptuneAnalyticsError",
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
|
||||
class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError):
|
||||
"""Exception raised when connection to Neptune Analytics fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Unable to connect to Neptune Analytics. Please check the endpoint and network connectivity.",
|
||||
name: str = "NeptuneAnalyticsConnectionError",
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class NeptuneAnalyticsQueryError(NeptuneAnalyticsError):
|
||||
"""Exception raised when a query execution fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "The query execution failed due to invalid syntax or semantic issues.",
|
||||
name: str = "NeptuneAnalyticsQueryError",
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError):
|
||||
"""Exception raised when authentication with Neptune Analytics fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Authentication with Neptune Analytics failed. Please verify your credentials.",
|
||||
name: str = "NeptuneAnalyticsAuthenticationError",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError):
|
||||
"""Exception raised when Neptune Analytics configuration is invalid."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Neptune Analytics configuration is invalid or incomplete. Please review your setup.",
|
||||
name: str = "NeptuneAnalyticsConfigurationError",
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError):
|
||||
"""Exception raised when a Neptune Analytics operation times out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "The operation timed out while communicating with Neptune Analytics.",
|
||||
name: str = "NeptuneAnalyticsTimeoutError",
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError):
|
||||
"""Exception raised when requests are throttled by Neptune Analytics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Request was throttled by Neptune Analytics due to exceeding rate limits.",
|
||||
name: str = "NeptuneAnalyticsThrottlingError",
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError):
|
||||
"""Exception raised when a Neptune Analytics resource is not found."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "The requested Neptune Analytics resource could not be found.",
|
||||
name: str = "NeptuneAnalyticsResourceNotFoundError",
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError):
|
||||
"""Exception raised when invalid parameters are provided to Neptune Analytics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "One or more parameters provided to Neptune Analytics are invalid or missing.",
|
||||
name: str = "NeptuneAnalyticsInvalidParameterError",
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
|
|
|||
|
|
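Every class in this file follows the same shape: a default message, an exception name, and an HTTP status code forwarded to CogneeApiError, so the commit's only change here is a trailing comma after each status_code argument. Because they all derive from NeptuneAnalyticsError, callers can handle the whole family with one except clause; a small hypothetical sketch:

# Hypothetical handling sketch; assumes the exception classes defined above.
def run_with_retry_hint(operation):
    try:
        return operation()
    except NeptuneAnalyticsThrottlingError:
        # 429 Too Many Requests: the caller may retry with backoff.
        raise
    except NeptuneAnalyticsError as error:
        # Any other Neptune Analytics failure is a CogneeApiError subclass
        # carrying a message, a name, and an HTTP status code.
        print(f"Neptune Analytics operation failed: {error}")
        raise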
@@ -16,40 +16,42 @@ logger = get_logger("NeptuneUtils")
 def parse_neptune_url(url: str) -> Tuple[str, str]:
     """
     Parse a Neptune Analytics URL to extract graph ID and region.

     Expected format: neptune-graph://<GRAPH_ID>?region=<REGION>
     or neptune-graph://<GRAPH_ID> (defaults to us-east-1)

     Parameters:
     -----------
     - url (str): The Neptune Analytics URL to parse

     Returns:
     --------
     - Tuple[str, str]: A tuple containing (graph_id, region)

     Raises:
     -------
     - ValueError: If the URL format is invalid
     """
     try:
         parsed = urlparse(url)

         if parsed.scheme != "neptune-graph":
             raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'")

-        graph_id = parsed.hostname or parsed.path.lstrip('/')
+        graph_id = parsed.hostname or parsed.path.lstrip("/")
         if not graph_id:
             raise ValueError("Graph ID not found in URL")

         # Extract region from query parameters
         region = "us-east-1"  # default region
         if parsed.query:
-            query_params = dict(param.split('=') for param in parsed.query.split('&') if '=' in param)
-            region = query_params.get('region', region)
+            query_params = dict(
+                param.split("=") for param in parsed.query.split("&") if "=" in param
+            )
+            region = query_params.get("region", region)

         return graph_id, region

     except Exception as e:
         raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}")

@@ -57,43 +59,43 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
 def validate_graph_id(graph_id: str) -> bool:
     """
     Validate a Neptune Analytics graph ID format.

     Graph IDs should follow AWS naming conventions.

     Parameters:
     -----------
     - graph_id (str): The graph ID to validate

     Returns:
     --------
     - bool: True if the graph ID is valid, False otherwise
     """
     if not graph_id:
         return False

     # Neptune Analytics graph IDs should be alphanumeric with hyphens
     # and between 1-63 characters
-    pattern = r'^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$'
+    pattern = r"^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$"
     return bool(re.match(pattern, graph_id))


 def validate_aws_region(region: str) -> bool:
     """
     Validate an AWS region format.

     Parameters:
     -----------
     - region (str): The AWS region to validate

     Returns:
     --------
     - bool: True if the region format is valid, False otherwise
     """
     if not region:
         return False

     # AWS regions follow the pattern: us-east-1, eu-west-1, etc.
-    pattern = r'^[a-z]{2,3}-[a-z]+-\d+$'
+    pattern = r"^[a-z]{2,3}-[a-z]+-\d+$"
     return bool(re.match(pattern, region))

@@ -103,11 +105,11 @@ def build_neptune_config(
     aws_access_key_id: Optional[str] = None,
     aws_secret_access_key: Optional[str] = None,
     aws_session_token: Optional[str] = None,
-    **kwargs
+    **kwargs,
 ) -> Dict[str, Any]:
     """
     Build a configuration dictionary for Neptune Analytics connection.

     Parameters:
     -----------
     - graph_id (str): The Neptune Analytics graph identifier

@@ -116,11 +118,11 @@ def build_neptune_config(
     - aws_secret_access_key (Optional[str]): AWS secret access key
     - aws_session_token (Optional[str]): AWS session token for temporary credentials
     - **kwargs: Additional configuration parameters

     Returns:
     --------
     - Dict[str, Any]: Configuration dictionary for Neptune Analytics

     Raises:
     -------
     - ValueError: If required parameters are invalid

@@ -129,35 +131,35 @@ def build_neptune_config(
         "graph_id": graph_id,
         "service_name": "neptune-graph",
     }

     # Add AWS credentials if provided
     if region:
         config["region"] = region

     if aws_access_key_id:
         config["aws_access_key_id"] = aws_access_key_id

     if aws_secret_access_key:
         config["aws_secret_access_key"] = aws_secret_access_key

     if aws_session_token:
         config["aws_session_token"] = aws_session_token

     # Add any additional configuration
     config.update(kwargs)

     return config


 def get_neptune_endpoint_url(graph_id: str, region: str) -> str:
     """
     Construct the Neptune Analytics endpoint URL for a given graph and region.

     Parameters:
     -----------
     - graph_id (str): The Neptune Analytics graph identifier
     - region (str): AWS region where the graph is located

     Returns:
     --------
     - str: The Neptune Analytics endpoint URL

@@ -168,17 +170,17 @@ def get_neptune_endpoint_url(graph_id: str, region: str) -> str:
 def format_neptune_error(error: Exception) -> str:
     """
     Format Neptune Analytics specific errors for better readability.

     Parameters:
     -----------
     - error (Exception): The exception to format

     Returns:
     --------
     - str: Formatted error message
     """
     error_msg = str(error)

     # Common Neptune Analytics error patterns and their user-friendly messages
     error_mappings = {
         "AccessDenied": "Access denied. Please check your AWS credentials and permissions.",

@@ -187,17 +189,18 @@ def format_neptune_error(error: Exception) -> str:
         "ThrottlingException": "Request was throttled. Please retry with exponential backoff.",
         "InternalServerError": "Internal server error occurred. Please try again later.",
     }

     for error_type, friendly_msg in error_mappings.items():
         if error_type in error_msg:
             return f"{friendly_msg} Original error: {error_msg}"

     return error_msg


 def get_default_query_timeout() -> int:
     """
     Get the default query timeout for Neptune Analytics operations.

     Returns:
     --------
     - int: Default timeout in seconds

@@ -208,7 +211,7 @@ def get_default_query_timeout() -> int:
 def get_default_connection_config() -> Dict[str, Any]:
     """
     Get default connection configuration for Neptune Analytics.

     Returns:
     --------
     - Dict[str, Any]: Default connection configuration

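The docstrings above pin down the URL grammar: neptune-graph://<GRAPH_ID>?region=<REGION>, defaulting to us-east-1, with the graph ID limited to 63 alphanumeric-or-hyphen characters. A standalone re-implementation of the parse, mirroring the diffed code, to illustrate the expected behavior:

from urllib.parse import urlparse

def parse_neptune_url_sketch(url: str):
    """Standalone sketch mirroring parse_neptune_url in the hunks above."""
    parsed = urlparse(url)
    if parsed.scheme != "neptune-graph":
        raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'")
    graph_id = parsed.hostname or parsed.path.lstrip("/")
    if not graph_id:
        raise ValueError("Graph ID not found in URL")
    region = "us-east-1"  # default region, as documented above
    if parsed.query:
        # Split "region=eu-west-1&..." into a dict of query parameters.
        query_params = dict(
            param.split("=") for param in parsed.query.split("&") if "=" in param
        )
        region = query_params.get("region", region)
    return graph_id, region

assert parse_neptune_url_sketch("neptune-graph://g-abc123?region=eu-west-1") == ("g-abc123", "eu-west-1")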
@@ -17,6 +17,7 @@ from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

 logger = get_logger("NeptuneAnalyticsAdapter")

+
 class IndexSchema(DataPoint):
     """
     Represents a schema for an index data point containing an ID and text.

@@ -27,16 +28,19 @@ class IndexSchema(DataPoint):
     - metadata: A dictionary with default index fields for the schema, currently configured
       to include 'text'.
     """
+
     id: str
     text: str
     metadata: dict = {"index_fields": ["text"]}

+
 NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://"

+
 class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
     """
     Hybrid adapter that combines Neptune Analytics Vector and Graph functionality.

     This adapter extends NeptuneGraphDB and implements VectorDBInterface to provide
     a unified interface for working with Neptune Analytics as both a vector store
     and a graph database.

@@ -48,14 +52,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
     _TOPK_UPPER_BOUND = 10

     def __init__(
-            self,
-            graph_id: str,
-            embedding_engine: Optional[EmbeddingEngine] = None,
-            region: Optional[str] = None,
-            aws_access_key_id: Optional[str] = None,
-            aws_secret_access_key: Optional[str] = None,
-            aws_session_token: Optional[str] = None,
-    ):
+        self,
+        graph_id: str,
+        embedding_engine: Optional[EmbeddingEngine] = None,
+        region: Optional[str] = None,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+    ):
         """
         Initialize the Neptune Analytics hybrid adapter.

@@ -74,12 +78,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
             region=region,
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
-            aws_session_token=aws_session_token
+            aws_session_token=aws_session_token,
         )

         # Add vector-specific attributes
         self.embedding_engine = embedding_engine
-        logger.info(f"Initialized Neptune Analytics hybrid adapter for graph: \"{graph_id}\" in region: \"{self.region}\"")
+        logger.info(
+            f'Initialized Neptune Analytics hybrid adapter for graph: "{graph_id}" in region: "{self.region}"'
+        )

     # VectorDBInterface methods implementation

@@ -161,7 +167,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):

         # Fetch embeddings
         texts = [DataPoint.get_embeddable_data(t) for t in data_points]
-        data_vectors = (await self.embedding_engine.embed_text(texts))
+        data_vectors = await self.embedding_engine.embed_text(texts)

         for index, data_point in enumerate(data_points):
             node_id = data_point.id

@@ -172,23 +178,24 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
             properties = self._serialize_properties(data_point.model_dump())
             properties[self._COLLECTION_PREFIX] = collection_name
             params = dict(
-                node_id = str(node_id),
-                properties = properties,
-                embedding = data_vector,
-                collection_name = collection_name
+                node_id=str(node_id),
+                properties=properties,
+                embedding=data_vector,
+                collection_name=collection_name,
             )

             # Compose the query and send
             query_string = (
-                    f"MERGE (n "
-                    f":{self._VECTOR_NODE_LABEL} "
-                    f" {{`~id`: $node_id}}) "
-                    f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
-                    f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
-                    f"WITH n, $embedding AS embedding "
-                    f"CALL neptune.algo.vectors.upsert(n, embedding) "
-                    f"YIELD success "
-                    f"RETURN success ")
+                f"MERGE (n "
+                f":{self._VECTOR_NODE_LABEL} "
+                f" {{`~id`: $node_id}}) "
+                f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
+                f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
+                f"WITH n, $embedding AS embedding "
+                f"CALL neptune.algo.vectors.upsert(n, embedding) "
+                f"YIELD success "
+                f"RETURN success "
+            )

             try:
                 self._client.query(query_string, params)

@@ -208,10 +215,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         """
         # Do the fetch for each node
         params = dict(node_ids=data_point_ids, collection_name=collection_name)
-        query_string = (f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
-                        f"WHERE id(n) in $node_ids AND "
-                        f"n.{self._COLLECTION_PREFIX} = $collection_name "
-                        f"RETURN n as payload ")
+        query_string = (
+            f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
+            f"WHERE id(n) in $node_ids AND "
+            f"n.{self._COLLECTION_PREFIX} = $collection_name "
+            f"RETURN n as payload "
+        )

         try:
             result = self._client.query(query_string, params)

@@ -259,7 +268,8 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
             logger.warning(
                 "Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
-                "Defaulting to limit=10.", limit
+                "Defaulting to limit=10.",
+                limit,
             )
             limit = self._TOPK_UPPER_BOUND

@@ -272,7 +282,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         elif query_vector:
             embedding = query_vector
         else:
-            data_vectors = (await self.embedding_engine.embed_text([query_text]))
+            data_vectors = await self.embedding_engine.embed_text([query_text])
             embedding = data_vectors[0]

         # Compose the parameters map

@@ -305,9 +315,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):

         try:
             query_response = self._client.query(query_string, params)
-            return [self._get_scored_result(
-                item = item, with_score = True
-            ) for item in query_response]
+            return [self._get_scored_result(item=item, with_score=True) for item in query_response]
         except Exception as e:
             self._na_exception_handler(e, query_string)

@@ -332,11 +340,13 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         self._validate_embedding_engine()

         # Convert text to embedding array in batch
-        data_vectors = (await self.embedding_engine.embed_text(query_texts))
-        return await asyncio.gather(*[
-            self.search(collection_name, None, vector, limit, with_vectors)
-            for vector in data_vectors
-        ])
+        data_vectors = await self.embedding_engine.embed_text(query_texts)
+        return await asyncio.gather(
+            *[
+                self.search(collection_name, None, vector, limit, with_vectors)
+                for vector in data_vectors
+            ]
+        )

     async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
         """

@@ -350,10 +360,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         - data_point_ids (list[str]): A list of IDs of the data points to delete.
         """
         params = dict(node_ids=data_point_ids, collection_name=collection_name)
-        query_string = (f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
-                        f"WHERE id(n) IN $node_ids "
-                        f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
-                        f"DETACH DELETE n")
+        query_string = (
+            f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
+            f"WHERE id(n) IN $node_ids "
+            f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
+            f"DETACH DELETE n"
+        )
         try:
             self._client.query(query_string, params)
         except Exception as e:

@@ -370,7 +382,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         await self.create_collection(f"{index_name}_{index_property_name}")

     async def index_data_points(
-            self, index_name: str, index_property_name: str, data_points: list[DataPoint]
+        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
     ):
         """
         Indexes a list of data points into Neptune Analytics by creating them as nodes.

@@ -402,29 +414,28 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         Remove obsolete or unnecessary data from the database.
         """
         # Run actual truncate
-        self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
-                           f"DETACH DELETE n")
+        self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n")
         pass

     @staticmethod
-    def _get_scored_result(item: dict, with_vector: bool = False, with_score: bool = False) -> ScoredResult:
+    def _get_scored_result(
+        item: dict, with_vector: bool = False, with_score: bool = False
+    ) -> ScoredResult:
         """
         Util method to simplify the object creation of ScoredResult base on incoming NX payload response.
         """
         return ScoredResult(
-            id=item.get('payload').get('~id'),
-            payload=item.get('payload').get('~properties'),
-            score=item.get('score') if with_score else 0,
-            vector=item.get('embedding') if with_vector else None
+            id=item.get("payload").get("~id"),
+            payload=item.get("payload").get("~properties"),
+            score=item.get("score") if with_score else 0,
+            vector=item.get("embedding") if with_vector else None,
         )

     def _na_exception_handler(self, ex, query_string: str):
         """
         Generic exception handler for NA langchain.
         """
-        logger.error(
-            "Neptune Analytics query failed: %s | Query: [%s]", ex, query_string
-        )
+        logger.error("Neptune Analytics query failed: %s | Query: [%s]", ex, query_string)
         raise ex

     def _validate_embedding_engine(self):

@@ -433,4 +444,6 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
         :raises: ValueError if this object does not have a valid embedding_engine
         """
         if self.embedding_engine is None:
-            raise ValueError("Neptune Analytics requires an embedder defined to make vector operations")
+            raise ValueError(
+                "Neptune Analytics requires an embedder defined to make vector operations"
+            )

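The hybrid adapter hunks above keep the same flow throughout: embed text with the configured embedding engine, then issue an openCypher query that calls neptune.algo.vectors.upsert or a similarity search. A hypothetical calling sketch (the search signature is inferred from the batch_search hunk, where it is invoked as search(collection_name, None, vector, limit, with_vectors)):

import asyncio

async def vector_demo(adapter) -> None:
    # `adapter` is assumed to be a NeptuneAnalyticsAdapter constructed with an
    # embedding engine; otherwise _validate_embedding_engine raises ValueError.
    results = await adapter.search("test_collection", "graph databases", None, 5, False)
    for scored in results:
        # ScoredResult carries the node ID, payload properties, and score.
        print(scored.id, scored.score)

# asyncio.run(vector_demo(NeptuneAnalyticsAdapter("g-abc123", embedding_engine=engine)))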
@@ -125,9 +125,15 @@ def create_vector_engine(
     if not vector_db_url:
         raise EnvironmentError("Missing Neptune endpoint.")

-    from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL
+    from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
+        NeptuneAnalyticsAdapter,
+        NEPTUNE_ANALYTICS_ENDPOINT_URL,
+    )

     if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
-        raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")
+        raise ValueError(
+            f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
+        )

     graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

@ -8,7 +8,7 @@ from cognee.modules.data.processing.document_types import TextDocument
|
|||
|
||||
# Set up Amazon credentials in .env file and get the values from environment variables
|
||||
load_dotenv()
|
||||
graph_id = os.getenv('GRAPH_ID', "")
|
||||
graph_id = os.getenv("GRAPH_ID", "")
|
||||
|
||||
na_adapter = NeptuneGraphDB(graph_id)
|
||||
|
||||
|
|
@ -24,33 +24,33 @@ def setup():
|
|||
# stored in Amazon S3.
|
||||
|
||||
document = TextDocument(
|
||||
name='text_test.txt',
|
||||
raw_data_location='git/cognee/examples/database_examples/data_storage/data/text_test.txt',
|
||||
external_metadata='{}',
|
||||
mime_type='text/plain'
|
||||
name="text_test.txt",
|
||||
raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt",
|
||||
external_metadata="{}",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
document_chunk = DocumentChunk(
|
||||
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
||||
chunk_size=187,
|
||||
chunk_index=0,
|
||||
cut_type='paragraph_end',
|
||||
cut_type="paragraph_end",
|
||||
is_part_of=document,
|
||||
)
|
||||
|
||||
graph_database = EntityType(name='graph database', description='graph database')
|
||||
graph_database = EntityType(name="graph database", description="graph database")
|
||||
neptune_analytics_entity = Entity(
|
||||
name='neptune analytics',
|
||||
description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
|
||||
name="neptune analytics",
|
||||
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
|
||||
)
|
||||
neptune_database_entity = Entity(
|
||||
name='amazon neptune database',
|
||||
description='A popular managed graph database that complements Neptune Analytics.',
|
||||
name="amazon neptune database",
|
||||
description="A popular managed graph database that complements Neptune Analytics.",
|
||||
)
|
||||
|
||||
storage = EntityType(name='storage', description='storage')
|
||||
storage = EntityType(name="storage", description="storage")
|
||||
storage_entity = Entity(
|
||||
name='amazon s3',
|
||||
description='A storage service provided by Amazon Web Services that allows storing graph data.',
|
||||
name="amazon s3",
|
||||
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
||||
)
|
||||
|
||||
nodes_data = [
|
||||
|
|
@ -67,37 +67,37 @@ def setup():
|
|||
(
|
||||
str(document_chunk.id),
|
||||
str(storage_entity.id),
|
||||
'contains',
|
||||
"contains",
|
||||
),
|
||||
(
|
||||
str(storage_entity.id),
|
||||
str(storage.id),
|
||||
'is_a',
|
||||
"is_a",
|
||||
),
|
||||
(
|
||||
str(document_chunk.id),
|
||||
str(neptune_database_entity.id),
|
||||
'contains',
|
||||
"contains",
|
||||
),
|
||||
(
|
||||
str(neptune_database_entity.id),
|
||||
str(graph_database.id),
|
||||
'is_a',
|
||||
"is_a",
|
||||
),
|
||||
(
|
||||
str(document_chunk.id),
|
||||
str(document.id),
|
||||
'is_part_of',
|
||||
"is_part_of",
|
||||
),
|
||||
(
|
||||
str(document_chunk.id),
|
||||
str(neptune_analytics_entity.id),
|
||||
'contains',
|
||||
"contains",
|
||||
),
|
||||
(
|
||||
str(neptune_analytics_entity.id),
|
||||
str(graph_database.id),
|
||||
'is_a',
|
||||
"is_a",
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -155,42 +155,44 @@ async def pipeline_method():
|
|||
print("------NEIGHBORING NODES-------")
|
||||
center_node = nodes[2]
|
||||
neighbors = await na_adapter.get_neighbors(str(center_node.id))
|
||||
print(f"found {len(neighbors)} neighbors for node \"{center_node.name}\"")
|
||||
print(f'found {len(neighbors)} neighbors for node "{center_node.name}"')
|
||||
for neighbor in neighbors:
|
||||
print(neighbor)
|
||||
|
||||
print("------NEIGHBORING EDGES-------")
|
||||
center_node = nodes[2]
|
||||
neighbouring_edges = await na_adapter.get_edges(str(center_node.id))
|
||||
print(f"found {len(neighbouring_edges)} edges neighbouring node \"{center_node.name}\"")
|
||||
print(f'found {len(neighbouring_edges)} edges neighbouring node "{center_node.name}"')
|
||||
for edge in neighbouring_edges:
|
||||
print(edge)
|
||||
|
||||
print("------GET CONNECTIONS (SOURCE NODE)-------")
|
||||
document_chunk_node = nodes[0]
|
||||
connections = await na_adapter.get_connections(str(document_chunk_node.id))
|
||||
print(f"found {len(connections)} connections for node \"{document_chunk_node.type}\"")
|
||||
print(f'found {len(connections)} connections for node "{document_chunk_node.type}"')
|
||||
for connection in connections:
|
||||
src, relationship, tgt = connection
|
||||
src = src.get("name", src.get("type", "unknown"))
|
||||
relationship = relationship["relationship_name"]
|
||||
tgt = tgt.get("name", tgt.get("type", "unknown"))
|
||||
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")
|
||||
print(f'"{src}"-[{relationship}]->"{tgt}"')
|
||||
|
||||
print("------GET CONNECTIONS (TARGET NODE)-------")
|
||||
connections = await na_adapter.get_connections(str(center_node.id))
|
||||
print(f"found {len(connections)} connections for node \"{center_node.name}\"")
|
||||
print(f'found {len(connections)} connections for node "{center_node.name}"')
|
||||
for connection in connections:
|
||||
src, relationship, tgt = connection
|
||||
src = src.get("name", src.get("type", "unknown"))
|
||||
relationship = relationship["relationship_name"]
|
||||
tgt = tgt.get("name", tgt.get("type", "unknown"))
|
||||
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")
|
||||
print(f'"{src}"-[{relationship}]->"{tgt}"')
|
||||
|
||||
print("------SUBGRAPH-------")
|
||||
node_names = ["neptune analytics", "amazon neptune database"]
|
||||
subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names)
|
||||
print(f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}")
|
||||
print(
|
||||
f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}"
|
||||
)
|
||||
for subgraph_node in subgraph_nodes:
|
||||
print(subgraph_node)
|
||||
for subgraph_edge in subgraph_edges:
|
||||
|
|
@ -199,17 +201,17 @@ async def pipeline_method():
|
|||
print("------STAT-------")
|
||||
stat = await na_adapter.get_graph_metrics(include_optional=True)
|
||||
assert type(stat) is dict
|
||||
assert stat['num_nodes'] == 7
|
||||
assert stat['num_edges'] == 7
|
||||
assert stat['mean_degree'] == 2.0
|
||||
assert round(stat['edge_density'], 3) == 0.167
|
||||
assert stat['num_connected_components'] == [7]
|
||||
assert stat['sizes_of_connected_components'] == 1
|
||||
assert stat['num_selfloops'] == 0
|
||||
assert stat["num_nodes"] == 7
|
||||
assert stat["num_edges"] == 7
|
||||
assert stat["mean_degree"] == 2.0
|
||||
assert round(stat["edge_density"], 3) == 0.167
|
||||
assert stat["num_connected_components"] == [7]
|
||||
assert stat["sizes_of_connected_components"] == 1
|
||||
assert stat["num_selfloops"] == 0
|
||||
# Unsupported optional metrics
|
||||
assert stat['diameter'] == -1
|
||||
assert stat['avg_shortest_path_length'] == -1
|
||||
assert stat['avg_clustering'] == -1
|
||||
assert stat["diameter"] == -1
|
||||
assert stat["avg_shortest_path_length"] == -1
|
||||
assert stat["avg_clustering"] == -1
|
||||
|
||||
print("------DELETE-------")
|
||||
# delete all nodes and edges:
|
||||
|
|
@@ -253,7 +255,9 @@ async def misc_methods():
    print(edge_labels)

    print("------Get Filtered Graph-------")
-    filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data([{'name': ['text_test.txt']}])
+    filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data(
+        [{"name": ["text_test.txt"]}]
+    )
    print(filtered_nodes, filtered_edges)

    print("------Get Degree one nodes-------")
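# The filter argument above is a list of {attribute: allowed_values} mappings;
# a node matches when its attribute value is in the allowed list. A hedged
# variant with two allowed names (the second file name is made up for
# illustration):
filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data(
    [{"name": ["text_test.txt", "text.txt"]}]
)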
@@ -261,15 +265,13 @@ async def misc_methods():
    print(degree_one_nodes)

    print("------Get Doc sub-graph-------")
-    doc_sub_graph = await na_adapter.get_document_subgraph('test.txt')
+    doc_sub_graph = await na_adapter.get_document_subgraph("test.txt")
    print(doc_sub_graph)

    print("------Fetch and Remove connections (Predecessors)-------")
    # Fetch test edge
    (src_id, dest_id, relationship) = edges[0]
-    nodes_predecessors = await na_adapter.get_predecessors(
-        node_id=dest_id, edge_label=relationship
-    )
+    nodes_predecessors = await na_adapter.get_predecessors(node_id=dest_id, edge_label=relationship)
    assert len(nodes_predecessors) > 0

    await na_adapter.remove_connection_to_predecessors_of(
@@ -281,25 +283,19 @@ async def misc_methods():
    # Return empty after relationship being deleted.
    assert len(nodes_predecessors_after) == 0

-
    print("------Fetch and Remove connections (Successors)-------")
    _, edges_suc = await na_adapter.get_graph_data()
    (src_id, dest_id, relationship, _) = edges_suc[0]

-    nodes_successors = await na_adapter.get_successors(
-        node_id=src_id, edge_label=relationship
-    )
+    nodes_successors = await na_adapter.get_successors(node_id=src_id, edge_label=relationship)
    assert len(nodes_successors) > 0

-    await na_adapter.remove_connection_to_successors_of(
-        node_ids=[dest_id], edge_label=relationship
-    )
+    await na_adapter.remove_connection_to_successors_of(node_ids=[dest_id], edge_label=relationship)
    nodes_successors_after = await na_adapter.get_successors(
        node_id=src_id, edge_label=relationship
    )
    assert len(nodes_successors_after) == 0

-
    # no-op
    await na_adapter.project_entire_graph()
    await na_adapter.drop_graph()
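# The remove-connection pattern the hunk above exercises, condensed into a
# hedged helper sketch (the helper name is hypothetical; the adapter calls and
# keyword signatures are exactly the ones the test uses):
async def assert_successor_connection_removed(adapter, src_id, dest_id, relationship):
    # fetch -> remove -> re-fetch-and-expect-empty
    assert len(await adapter.get_successors(node_id=src_id, edge_label=relationship)) > 0
    await adapter.remove_connection_to_successors_of(node_ids=[dest_id], edge_label=relationship)
    assert len(await adapter.get_successors(node_id=src_id, edge_label=relationship)) == 0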
@@ -8,11 +8,13 @@ from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.shared.logging_utils import get_logger
-from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter
+from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
+    NeptuneAnalyticsAdapter,
+)

# Set up Amazon credentials in .env file and get the values from environment variables
load_dotenv()
-graph_id = os.getenv('GRAPH_ID', "")
+graph_id = os.getenv("GRAPH_ID", "")

# get the default embedder
embedding_engine = get_embedding_engine()

@@ -23,6 +25,7 @@ collection = "test_collection"

logger = get_logger("test_neptune_analytics_hybrid")

+
def setup_data():
    # Define nodes data before the main function
    # These nodes were defined using OpenAI from the following prompt:
@@ -34,35 +37,35 @@ def setup_data():
    # stored in Amazon S3.

    document = TextDocument(
-        name='text.txt',
-        raw_data_location='git/cognee/examples/database_examples/data_storage/data/text.txt',
-        external_metadata='{}',
-        mime_type='text/plain'
+        name="text.txt",
+        raw_data_location="git/cognee/examples/database_examples/data_storage/data/text.txt",
+        external_metadata="{}",
+        mime_type="text/plain",
    )
    document_chunk = DocumentChunk(
        text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
        chunk_size=187,
        chunk_index=0,
-        cut_type='paragraph_end',
+        cut_type="paragraph_end",
        is_part_of=document,
    )

-    graph_database = EntityType(name='graph database', description='graph database')
+    graph_database = EntityType(name="graph database", description="graph database")
    neptune_analytics_entity = Entity(
-        name='neptune analytics',
-        description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
+        name="neptune analytics",
+        description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
    )
    neptune_database_entity = Entity(
-        name='amazon neptune database',
-        description='A popular managed graph database that complements Neptune Analytics.',
+        name="amazon neptune database",
+        description="A popular managed graph database that complements Neptune Analytics.",
    )

-    storage = EntityType(name='storage', description='storage')
+    storage = EntityType(name="storage", description="storage")
    storage_entity = Entity(
-        name='amazon s3',
-        description='A storage service provided by Amazon Web Services that allows storing graph data.',
+        name="amazon s3",
+        description="A storage service provided by Amazon Web Services that allows storing graph data.",
    )


    nodes_data = [
        document,
        document_chunk,
@@ -77,41 +80,42 @@ def setup_data():
        (
            str(document_chunk.id),
            str(storage_entity.id),
-            'contains',
+            "contains",
        ),
        (
            str(storage_entity.id),
            str(storage.id),
-            'is_a',
+            "is_a",
        ),
        (
            str(document_chunk.id),
            str(neptune_database_entity.id),
-            'contains',
+            "contains",
        ),
        (
            str(neptune_database_entity.id),
            str(graph_database.id),
-            'is_a',
+            "is_a",
        ),
        (
            str(document_chunk.id),
            str(document.id),
-            'is_part_of',
+            "is_part_of",
        ),
        (
            str(document_chunk.id),
            str(neptune_analytics_entity.id),
-            'contains',
+            "contains",
        ),
        (
            str(neptune_analytics_entity.id),
            str(graph_database.id),
-            'is_a',
+            "is_a",
        ),
    ]
    return nodes_data, edges_data

+
async def test_add_graph_then_vector_data():
    logger.info("------test_add_graph_then_vector_data-------")
    (nodes, edges) = setup_data()
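# A hedged sanity sketch (not part of the commit): each edge triple produced by
# setup_data() is (source_id, target_id, relationship_name), and the ids are
# expected to reference nodes from the same fixture.
nodes_data, edges_data = setup_data()
node_ids = {str(node.id) for node in nodes_data}
for source_id, target_id, relationship_name in edges_data:
    assert source_id in node_ids and target_id in node_ids
    assert relationship_name in {"contains", "is_a", "is_part_of"}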
@@ -134,6 +138,7 @@ async def test_add_graph_then_vector_data():
    assert len(edges) == 0
    logger.info("------PASSED-------")

+
async def test_add_vector_then_node_data():
    logger.info("------test_add_vector_then_node_data-------")
    (nodes, edges) = setup_data()

@@ -156,6 +161,7 @@ async def test_add_vector_then_node_data():
    assert len(edges) == 0
    logger.info("------PASSED-------")

+
def main():
    """
    Example script that uses Neptune Analytics for the graph and vector (hybrid) store with small sample data

@@ -165,5 +171,6 @@ def main():
    asyncio.run(test_add_graph_then_vector_data())
    asyncio.run(test_add_vector_then_node_data())

+
if __name__ == "__main__":
    main()
@@ -8,13 +8,16 @@ from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, IndexSchema
+from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
+    NeptuneAnalyticsAdapter,
+    IndexSchema,
+)

logger = get_logger()


async def main():
-    graph_id = os.getenv('GRAPH_ID', "")
+    graph_id = os.getenv("GRAPH_ID", "")
    cognee.config.set_vector_db_provider("neptune_analytics")
    cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
    data_directory_path = str(

@@ -87,6 +90,7 @@ async def main():

    await cognee.prune.prune_system(metadata=True)

+
async def vector_backend_api_test():
    cognee.config.set_vector_db_provider("neptune_analytics")

@@ -101,7 +105,7 @@ async def vector_backend_api_test():
    get_vector_engine()

    # Return a valid engine object with valid URL.
-    graph_id = os.getenv('GRAPH_ID', "")
+    graph_id = os.getenv("GRAPH_ID", "")
    cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
    engine = get_vector_engine()
    assert isinstance(engine, NeptuneAnalyticsAdapter)
@@ -133,28 +137,22 @@ async def vector_backend_api_test():
        query_text=TEST_TEXT,
        query_vector=None,
        limit=10,
-        with_vector=True)
-    assert (len(result_search) == 2)
+        with_vector=True,
+    )
+    assert len(result_search) == 2

    # Retrieve data-points
    result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
-    assert any(
-        str(r.id) == TEST_UUID and r.payload['text'] == TEST_TEXT
-        for r in result
-    )
-    assert any(
-        str(r.id) == TEST_UUID_2 and r.payload['text'] == TEST_TEXT_2
-        for r in result
-    )
+    assert any(str(r.id) == TEST_UUID and r.payload["text"] == TEST_TEXT for r in result)
+    assert any(str(r.id) == TEST_UUID_2 and r.payload["text"] == TEST_TEXT_2 for r in result)
    # Search multiple
    result_search_batch = await engine.batch_search(
        collection_name=TEST_COLLECTION_NAME,
        query_texts=[TEST_TEXT, TEST_TEXT_2],
        limit=10,
-        with_vectors=False
+        with_vectors=False,
    )
-    assert (len(result_search_batch) == 2 and
-            all(len(batch) == 2 for batch in result_search_batch))
+    assert len(result_search_batch) == 2 and all(len(batch) == 2 for batch in result_search_batch)

    # Delete datapoint from vector store
    await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
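# A hedged usage sketch of the engine calls exercised above (the TEST_*
# constants and `engine` are the test's own): run one search, then read ids
# and payload text off the results.
result_search = await engine.search(
    collection_name=TEST_COLLECTION_NAME,
    query_text=TEST_TEXT,
    query_vector=None,
    limit=10,
    with_vector=True,
)
for r in result_search:
    print(r.id, r.payload["text"])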
@@ -163,6 +161,7 @@ async def vector_backend_api_test():
    result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID])
    assert result_deleted == []

+
if __name__ == "__main__":
    import asyncio

@@ -9,6 +9,7 @@ from dotenv import load_dotenv

load_dotenv()

+
async def main():
    """
    Example script demonstrating how to use Cognee with Amazon Neptune Analytics

@@ -22,7 +23,7 @@ async def main():
    """

    # Set up Amazon credentials in .env file and get the values from environment variables
-    graph_endpoint_url = "neptune-graph://" + os.getenv('GRAPH_ID', "")
+    graph_endpoint_url = "neptune-graph://" + os.getenv("GRAPH_ID", "")

    # Configure Neptune Analytics as the graph & vector database provider
    cognee.config.set_graph_db_config(
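# A minimal environment sketch (assumed .env keys, following the comment above):
# GRAPH_ID plus standard AWS credentials are resolved from the environment after
# load_dotenv(), and the endpoint is the graph id behind a neptune-graph:// prefix.
import os
from dotenv import load_dotenv

load_dotenv()
graph_endpoint_url = "neptune-graph://" + os.getenv("GRAPH_ID", "")
assert graph_endpoint_url != "neptune-graph://", "GRAPH_ID is not set in .env"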
@@ -77,7 +78,9 @@ async def main():

    # Now let's perform some searches
    # 1. Search for insights related to "Neptune Analytics"
-    insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neptune Analytics")
+    insights_results = await cognee.search(
+        query_type=SearchType.INSIGHTS, query_text="Neptune Analytics"
+    )
    print("\n========Insights about Neptune Analytics========:")
    for result in insights_results:
        print(f"- {result}")