added linting and formatting
This commit is contained in:
parent
ae8fc7f0c9
commit
c0d60eef25
11 changed files with 355 additions and 326 deletions
|
|
@ -149,7 +149,9 @@ def create_graph_engine(
|
||||||
from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL
|
from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL
|
||||||
|
|
||||||
if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
|
if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
|
||||||
raise ValueError(f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>")
|
raise ValueError(
|
||||||
|
f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>"
|
||||||
|
)
|
||||||
|
|
||||||
graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")
|
graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")
|
||||||
|
|
||||||
|
|
@ -174,10 +176,15 @@ def create_graph_engine(
|
||||||
if not graph_database_url:
|
if not graph_database_url:
|
||||||
raise EnvironmentError("Missing Neptune endpoint.")
|
raise EnvironmentError("Missing Neptune endpoint.")
|
||||||
|
|
||||||
from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL
|
from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
|
||||||
|
NeptuneAnalyticsAdapter,
|
||||||
|
NEPTUNE_ANALYTICS_ENDPOINT_URL,
|
||||||
|
)
|
||||||
|
|
||||||
if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
|
if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
|
||||||
raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")
|
raise ValueError(
|
||||||
|
f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
|
||||||
|
)
|
||||||
|
|
||||||
graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")
|
graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ logger = get_logger("NeptuneGraphDB")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_aws import NeptuneAnalyticsGraph
|
from langchain_aws import NeptuneAnalyticsGraph
|
||||||
|
|
||||||
LANGCHAIN_AWS_AVAILABLE = True
|
LANGCHAIN_AWS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.")
|
logger.warning("langchain_aws not available. Neptune Analytics functionality will be limited.")
|
||||||
|
|
@ -36,11 +37,13 @@ except ImportError:
|
||||||
|
|
||||||
NEPTUNE_ENDPOINT_URL = "neptune-graph://"
|
NEPTUNE_ENDPOINT_URL = "neptune-graph://"
|
||||||
|
|
||||||
|
|
||||||
class NeptuneGraphDB(GraphDBInterface):
|
class NeptuneGraphDB(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
Adapter for interacting with Amazon Neptune Analytics graph store.
|
Adapter for interacting with Amazon Neptune Analytics graph store.
|
||||||
This class provides methods for querying, adding, deleting nodes and edges using the aws_langchain library.
|
This class provides methods for querying, adding, deleting nodes and edges using the aws_langchain library.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_GRAPH_NODE_LABEL = "COGNEE_NODE"
|
_GRAPH_NODE_LABEL = "COGNEE_NODE"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -68,14 +71,16 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
# validate import
|
# validate import
|
||||||
if not LANGCHAIN_AWS_AVAILABLE:
|
if not LANGCHAIN_AWS_AVAILABLE:
|
||||||
raise ImportError("langchain_aws is not available. Please install it to use Neptune Analytics.")
|
raise ImportError(
|
||||||
|
"langchain_aws is not available. Please install it to use Neptune Analytics."
|
||||||
|
)
|
||||||
|
|
||||||
# Validate configuration
|
# Validate configuration
|
||||||
if not validate_graph_id(graph_id):
|
if not validate_graph_id(graph_id):
|
||||||
raise NeptuneAnalyticsConfigurationError(message=f"Invalid graph ID: \"{graph_id}\"")
|
raise NeptuneAnalyticsConfigurationError(message=f'Invalid graph ID: "{graph_id}"')
|
||||||
|
|
||||||
if region and not validate_aws_region(region):
|
if region and not validate_aws_region(region):
|
||||||
raise NeptuneAnalyticsConfigurationError(message=f"Invalid AWS region: \"{region}\"")
|
raise NeptuneAnalyticsConfigurationError(message=f'Invalid AWS region: "{region}"')
|
||||||
|
|
||||||
self.graph_id = graph_id
|
self.graph_id = graph_id
|
||||||
self.region = region
|
self.region = region
|
||||||
|
|
@ -94,7 +99,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
|
|
||||||
# Initialize Neptune Analytics client using langchain_aws
|
# Initialize Neptune Analytics client using langchain_aws
|
||||||
self._client: NeptuneAnalyticsGraph = self._initialize_client()
|
self._client: NeptuneAnalyticsGraph = self._initialize_client()
|
||||||
logger.info(f"Initialized Neptune Analytics adapter for graph: \"{graph_id}\" in region: \"{self.region}\"")
|
logger.info(
|
||||||
|
f'Initialized Neptune Analytics adapter for graph: "{graph_id}" in region: "{self.region}"'
|
||||||
|
)
|
||||||
|
|
||||||
def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]:
|
def _initialize_client(self) -> Optional[NeptuneAnalyticsGraph]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -108,9 +115,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
# Initialize the Neptune Analytics Graph client
|
# Initialize the Neptune Analytics Graph client
|
||||||
client_config = {
|
client_config = {
|
||||||
"graph_identifier": self.graph_id,
|
"graph_identifier": self.graph_id,
|
||||||
"config": Config(
|
"config": Config(user_agent_appid="Cognee"),
|
||||||
user_agent_appid='Cognee'
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
# Add AWS credentials if provided
|
# Add AWS credentials if provided
|
||||||
if self.region:
|
if self.region:
|
||||||
|
|
@ -127,7 +132,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
return client
|
return client
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise NeptuneAnalyticsConfigurationError(message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}") from e
|
raise NeptuneAnalyticsConfigurationError(
|
||||||
|
message=f"Failed to initialize Neptune Analytics client: {format_neptune_error(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]:
|
def _serialize_properties(properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|
@ -215,6 +222,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
|
|
||||||
result = await self.query(query, params)
|
result = await self.query(query, params)
|
||||||
logger.debug(f"Successfully added/updated node: {node.id}")
|
logger.debug(f"Successfully added/updated node: {node.id}")
|
||||||
|
logger.debug(f"Successfully gotten: {str(result)}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = format_neptune_error(e)
|
error_msg = format_neptune_error(e)
|
||||||
|
|
@ -256,7 +264,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
}
|
}
|
||||||
result = await self.query(query, params)
|
result = await self.query(query, params)
|
||||||
|
|
||||||
processed_count = result[0].get('nodes_processed', 0) if result else 0
|
processed_count = result[0].get("nodes_processed", 0) if result else 0
|
||||||
logger.debug(f"Successfully processed {processed_count} nodes in bulk operation")
|
logger.debug(f"Successfully processed {processed_count} nodes in bulk operation")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -268,7 +276,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
try:
|
try:
|
||||||
await self.add_node(node)
|
await self.add_node(node)
|
||||||
except Exception as node_error:
|
except Exception as node_error:
|
||||||
logger.error(f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}")
|
logger.error(
|
||||||
|
f"Failed to add individual node {node.id}: {format_neptune_error(node_error)}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def delete_node(self, node_id: str) -> None:
|
async def delete_node(self, node_id: str) -> None:
|
||||||
|
|
@ -287,9 +297,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {"node_id": node_id}
|
||||||
"node_id": node_id
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.query(query, params)
|
await self.query(query, params)
|
||||||
logger.debug(f"Successfully deleted node: {node_id}")
|
logger.debug(f"Successfully deleted node: {node_id}")
|
||||||
|
|
@ -333,7 +341,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
try:
|
try:
|
||||||
await self.delete_node(node_id)
|
await self.delete_node(node_id)
|
||||||
except Exception as node_error:
|
except Exception as node_error:
|
||||||
logger.error(f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}")
|
logger.error(
|
||||||
|
f"Failed to delete individual node {node_id}: {format_neptune_error(node_error)}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Optional[NodeData]:
|
async def get_node(self, node_id: str) -> Optional[NodeData]:
|
||||||
|
|
@ -355,7 +365,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
WHERE id(n) = $node_id
|
WHERE id(n) = $node_id
|
||||||
RETURN n
|
RETURN n
|
||||||
"""
|
"""
|
||||||
params = {'node_id': node_id}
|
params = {"node_id": node_id}
|
||||||
|
|
||||||
result = await self.query(query, params)
|
result = await self.query(query, params)
|
||||||
|
|
||||||
|
|
@ -406,7 +416,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
# Extract node data from results
|
# Extract node data from results
|
||||||
nodes = [record["n"] for record in result]
|
nodes = [record["n"] for record in result]
|
||||||
|
|
||||||
logger.debug(f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested")
|
logger.debug(
|
||||||
|
f"Successfully retrieved {len(nodes)} nodes out of {len(node_ids)} requested"
|
||||||
|
)
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -421,11 +433,12 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
if node_data:
|
if node_data:
|
||||||
nodes.append(node_data)
|
nodes.append(node_data)
|
||||||
except Exception as node_error:
|
except Exception as node_error:
|
||||||
logger.error(f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}")
|
logger.error(
|
||||||
|
f"Failed to get individual node {node_id}: {format_neptune_error(node_error)}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
async def extract_node(self, node_id: str):
|
async def extract_node(self, node_id: str):
|
||||||
"""
|
"""
|
||||||
Retrieve a single node based on its ID.
|
Retrieve a single node based on its ID.
|
||||||
|
|
@ -512,7 +525,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
"properties": serialized_properties,
|
"properties": serialized_properties,
|
||||||
}
|
}
|
||||||
await self.query(query, params)
|
await self.query(query, params)
|
||||||
logger.debug(f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}")
|
logger.debug(
|
||||||
|
f"Successfully added edge: {source_id} -[{relationship_name}]-> {target_id}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = format_neptune_error(e)
|
error_msg = format_neptune_error(e)
|
||||||
|
|
@ -557,20 +572,24 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prepare edges data for bulk operation
|
# Prepare edges data for bulk operation
|
||||||
params = {"edges":
|
params = {
|
||||||
[
|
"edges": [
|
||||||
{
|
{
|
||||||
"from_node": str(edge[0]),
|
"from_node": str(edge[0]),
|
||||||
"to_node": str(edge[1]),
|
"to_node": str(edge[1]),
|
||||||
"relationship_name": relationship_name,
|
"relationship_name": relationship_name,
|
||||||
"properties": self._serialize_properties(edge[3] if len(edge) > 3 and edge[3] else {}),
|
"properties": self._serialize_properties(
|
||||||
|
edge[3] if len(edge) > 3 and edge[3] else {}
|
||||||
|
),
|
||||||
}
|
}
|
||||||
for edge in edges_for_relationship
|
for edge in edges_for_relationship
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
results[relationship_name] = await self.query(query, params)
|
results[relationship_name] = await self.query(query, params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}")
|
logger.error(
|
||||||
|
f"Failed to add edges for relationship {relationship_name}: {format_neptune_error(e)}"
|
||||||
|
)
|
||||||
logger.info("Falling back to individual edge creation")
|
logger.info("Falling back to individual edge creation")
|
||||||
for edge in edges_by_relationship:
|
for edge in edges_by_relationship:
|
||||||
try:
|
try:
|
||||||
|
|
@ -578,15 +597,16 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
properties = edge[3] if len(edge) > 3 else {}
|
properties = edge[3] if len(edge) > 3 else {}
|
||||||
await self.add_edge(source_id, target_id, relationship_name, properties)
|
await self.add_edge(source_id, target_id, relationship_name, properties)
|
||||||
except Exception as edge_error:
|
except Exception as edge_error:
|
||||||
logger.error(f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}")
|
logger.error(
|
||||||
|
f"Failed to add individual edge {edge[0]} -> {edge[1]}: {format_neptune_error(edge_error)}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
processed_count = 0
|
processed_count = 0
|
||||||
for result in results.values():
|
for result in results.values():
|
||||||
processed_count += result[0].get('edges_processed', 0) if result else 0
|
processed_count += result[0].get("edges_processed", 0) if result else 0
|
||||||
logger.debug(f"Successfully processed {processed_count} edges in bulk operation")
|
logger.debug(f"Successfully processed {processed_count} edges in bulk operation")
|
||||||
|
|
||||||
|
|
||||||
async def delete_graph(self) -> None:
|
async def delete_graph(self) -> None:
|
||||||
"""
|
"""
|
||||||
Delete all nodes and edges from the graph database.
|
Delete all nodes and edges from the graph database.
|
||||||
|
|
@ -599,14 +619,12 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
# Build openCypher query to delete the graph
|
# Build openCypher query to delete the graph
|
||||||
query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n"
|
query = f"MATCH (n:{self._GRAPH_NODE_LABEL}) DETACH DELETE n"
|
||||||
await self.query(query)
|
await self.query(query)
|
||||||
logger.debug(f"Successfully deleted all edges and nodes from the graph")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = format_neptune_error(e)
|
error_msg = format_neptune_error(e)
|
||||||
logger.error(f"Failed to delete graph: {error_msg}")
|
logger.error(f"Failed to delete graph: {error_msg}")
|
||||||
raise Exception(f"Failed to delete graph: {error_msg}") from e
|
raise Exception(f"Failed to delete graph: {error_msg}") from e
|
||||||
|
|
||||||
|
|
||||||
async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
|
async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
|
||||||
"""
|
"""
|
||||||
Retrieve all nodes and edges within the graph.
|
Retrieve all nodes and edges within the graph.
|
||||||
|
|
@ -633,13 +651,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
edges_result = await self.query(edges_query)
|
edges_result = await self.query(edges_query)
|
||||||
|
|
||||||
# Format nodes as (node_id, properties) tuples
|
# Format nodes as (node_id, properties) tuples
|
||||||
nodes = [
|
nodes = [(result["node_id"], result["properties"]) for result in nodes_result]
|
||||||
(
|
|
||||||
result["node_id"],
|
|
||||||
result["properties"]
|
|
||||||
)
|
|
||||||
for result in nodes_result
|
|
||||||
]
|
|
||||||
|
|
||||||
# Format edges as (source_id, target_id, relationship_name, properties) tuples
|
# Format edges as (source_id, target_id, relationship_name, properties) tuples
|
||||||
edges = [
|
edges = [
|
||||||
|
|
@ -647,7 +659,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
result["source_id"],
|
result["source_id"],
|
||||||
result["target_id"],
|
result["target_id"],
|
||||||
result["relationship_name"],
|
result["relationship_name"],
|
||||||
result["properties"]
|
result["properties"],
|
||||||
)
|
)
|
||||||
for result in edges_result
|
for result in edges_result
|
||||||
]
|
]
|
||||||
|
|
@ -679,9 +691,11 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
"num_nodes": num_nodes,
|
"num_nodes": num_nodes,
|
||||||
"num_edges": num_edges,
|
"num_edges": num_edges,
|
||||||
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
|
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
|
||||||
"edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1)) if num_nodes != 0 else None,
|
"edge_density": num_edges * 1.0 / (num_nodes * (num_nodes - 1))
|
||||||
|
if num_nodes != 0
|
||||||
|
else None,
|
||||||
"num_connected_components": num_cluster,
|
"num_connected_components": num_cluster,
|
||||||
"sizes_of_connected_components": list_clsuter_size
|
"sizes_of_connected_components": list_clsuter_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
optional_metrics = {
|
optional_metrics = {
|
||||||
|
|
@ -692,7 +706,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
}
|
}
|
||||||
|
|
||||||
if include_optional:
|
if include_optional:
|
||||||
optional_metrics['num_selfloops'] = await self._count_self_loops()
|
optional_metrics["num_selfloops"] = await self._count_self_loops()
|
||||||
# Unsupported due to long-running queries when computing the shortest path for each node in the graph:
|
# Unsupported due to long-running queries when computing the shortest path for each node in the graph:
|
||||||
# optional_metrics['diameter']
|
# optional_metrics['diameter']
|
||||||
# optional_metrics['avg_shortest_path_length']
|
# optional_metrics['avg_shortest_path_length']
|
||||||
|
|
@ -732,9 +746,11 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
result = await self.query(query, params)
|
result = await self.query(query, params)
|
||||||
|
|
||||||
if result and len(result) > 0:
|
if result and len(result) > 0:
|
||||||
edge_exists = result.pop().get('edge_exists', False)
|
edge_exists = result.pop().get("edge_exists", False)
|
||||||
logger.debug(f"Edge existence check for "
|
logger.debug(
|
||||||
f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}")
|
f"Edge existence check for "
|
||||||
|
f"{source_id} -[{relationship_name}]-> {target_id}: {edge_exists}"
|
||||||
|
)
|
||||||
return edge_exists
|
return edge_exists
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
@ -863,10 +879,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
RETURN predecessor
|
RETURN predecessor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await self.query(
|
results = await self.query(query, {"node_id": node_id})
|
||||||
query,
|
|
||||||
{"node_id": node_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
return [result["predecessor"] for result in results]
|
return [result["predecessor"] for result in results]
|
||||||
|
|
||||||
|
|
@ -893,14 +906,10 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
RETURN successor
|
RETURN successor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await self.query(
|
results = await self.query(query, {"node_id": node_id})
|
||||||
query,
|
|
||||||
{"node_id": node_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
return [result["successor"] for result in results]
|
return [result["successor"] for result in results]
|
||||||
|
|
||||||
|
|
||||||
async def get_neighbors(self, node_id: str) -> List[NodeData]:
|
async def get_neighbors(self, node_id: str) -> List[NodeData]:
|
||||||
"""
|
"""
|
||||||
Get all neighboring nodes connected to the specified node.
|
Get all neighboring nodes connected to the specified node.
|
||||||
|
|
@ -926,11 +935,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
|
|
||||||
# Format neighbors as NodeData objects
|
# Format neighbors as NodeData objects
|
||||||
neighbors = [
|
neighbors = [
|
||||||
{
|
{"id": neighbor["neighbor_id"], **neighbor["properties"]} for neighbor in result
|
||||||
"id": neighbor["neighbor_id"],
|
|
||||||
**neighbor["properties"]
|
|
||||||
}
|
|
||||||
for neighbor in result
|
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}")
|
logger.debug(f"Retrieved {len(neighbors)} neighbors for node: {node_id}")
|
||||||
|
|
@ -942,9 +947,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
raise Exception(f"Failed to get neighbors: {error_msg}") from e
|
raise Exception(f"Failed to get neighbors: {error_msg}") from e
|
||||||
|
|
||||||
async def get_nodeset_subgraph(
|
async def get_nodeset_subgraph(
|
||||||
self,
|
self, node_type: Type[Any], node_name: List[str]
|
||||||
node_type: Type[Any],
|
|
||||||
node_name: List[str]
|
|
||||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||||
"""
|
"""
|
||||||
Fetch a subgraph consisting of a specific set of nodes and their relationships.
|
Fetch a subgraph consisting of a specific set of nodes and their relationships.
|
||||||
|
|
@ -987,10 +990,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
}}] AS rawRels
|
}}] AS rawRels
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {"names": node_name, "type": node_type.__name__}
|
||||||
"names": node_name,
|
|
||||||
"type": node_type.__name__
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await self.query(query, params)
|
result = await self.query(query, params)
|
||||||
|
|
||||||
|
|
@ -1002,18 +1002,14 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
raw_rels = result[0]["rawRels"]
|
raw_rels = result[0]["rawRels"]
|
||||||
|
|
||||||
# Format nodes as (node_id, properties) tuples
|
# Format nodes as (node_id, properties) tuples
|
||||||
nodes = [
|
nodes = [(n["id"], n["properties"]) for n in raw_nodes]
|
||||||
(n["id"], n["properties"])
|
|
||||||
for n in raw_nodes
|
|
||||||
]
|
|
||||||
|
|
||||||
# Format edges as (source_id, target_id, relationship_name, properties) tuples
|
# Format edges as (source_id, target_id, relationship_name, properties) tuples
|
||||||
edges = [
|
edges = [(r["source_id"], r["target_id"], r["type"], r["properties"]) for r in raw_rels]
|
||||||
(r["source_id"], r["target_id"], r["type"], r["properties"])
|
|
||||||
for r in raw_rels
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}")
|
logger.debug(
|
||||||
|
f"Retrieved subgraph with {len(nodes)} nodes and {len(edges)} edges for type {node_type.__name__}"
|
||||||
|
)
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1055,18 +1051,12 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
# Return as (source_node, relationship, target_node)
|
# Return as (source_node, relationship, target_node)
|
||||||
connections.append(
|
connections.append(
|
||||||
(
|
(
|
||||||
{
|
{"id": record["source_id"], **record["source_props"]},
|
||||||
"id": record["source_id"],
|
|
||||||
**record["source_props"]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"relationship_name": record["relationship_name"],
|
"relationship_name": record["relationship_name"],
|
||||||
**record["relationship_props"]
|
**record["relationship_props"],
|
||||||
},
|
},
|
||||||
{
|
{"id": record["target_id"], **record["target_props"]},
|
||||||
"id": record["target_id"],
|
|
||||||
**record["target_props"]
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1078,10 +1068,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
logger.error(f"Failed to get connections for node {node_id}: {error_msg}")
|
logger.error(f"Failed to get connections for node {node_id}: {error_msg}")
|
||||||
raise Exception(f"Failed to get connections: {error_msg}") from e
|
raise Exception(f"Failed to get connections: {error_msg}") from e
|
||||||
|
|
||||||
|
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str):
|
||||||
async def remove_connection_to_predecessors_of(
|
|
||||||
self, node_ids: list[str], edge_label: str
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Remove connections (edges) to all predecessors of specified nodes based on edge label.
|
Remove connections (edges) to all predecessors of specified nodes based on edge label.
|
||||||
|
|
||||||
|
|
@ -1101,10 +1088,7 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
params = {"node_ids": node_ids}
|
params = {"node_ids": node_ids}
|
||||||
await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
|
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str):
|
||||||
async def remove_connection_to_successors_of(
|
|
||||||
self, node_ids: list[str], edge_label: str
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Remove connections (edges) to all successors of specified nodes based on edge label.
|
Remove connections (edges) to all successors of specified nodes based on edge label.
|
||||||
|
|
||||||
|
|
@ -1124,7 +1108,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
params = {"node_ids": node_ids}
|
params = {"node_ids": node_ids}
|
||||||
await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
|
|
||||||
async def get_node_labels_string(self):
|
async def get_node_labels_string(self):
|
||||||
"""
|
"""
|
||||||
Fetch all node labels from the database and return them as a formatted string.
|
Fetch all node labels from the database and return them as a formatted string.
|
||||||
|
|
@ -1138,7 +1121,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
-------
|
-------
|
||||||
ValueError: If no node labels are found in the database.
|
ValueError: If no node labels are found in the database.
|
||||||
"""
|
"""
|
||||||
node_labels_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
|
node_labels_query = (
|
||||||
|
"CALL neptune.graph.pg_schema() YIELD schema RETURN schema.nodeLabels as labels "
|
||||||
|
)
|
||||||
node_labels_result = await self.query(node_labels_query)
|
node_labels_result = await self.query(node_labels_query)
|
||||||
node_labels = node_labels_result[0]["labels"] if node_labels_result else []
|
node_labels = node_labels_result[0]["labels"] if node_labels_result else []
|
||||||
|
|
||||||
|
|
@ -1156,7 +1141,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
|
|
||||||
A formatted string of relationship types.
|
A formatted string of relationship types.
|
||||||
"""
|
"""
|
||||||
relationship_types_query = "CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
|
relationship_types_query = (
|
||||||
|
"CALL neptune.graph.pg_schema() YIELD schema RETURN schema.edgeLabels as relationships "
|
||||||
|
)
|
||||||
relationship_types_result = await self.query(relationship_types_query)
|
relationship_types_result = await self.query(relationship_types_query)
|
||||||
relationship_types = (
|
relationship_types = (
|
||||||
relationship_types_result[0]["relationships"] if relationship_types_result else []
|
relationship_types_result[0]["relationships"] if relationship_types_result else []
|
||||||
|
|
@ -1172,7 +1159,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return relationship_types_undirected_str
|
return relationship_types_undirected_str
|
||||||
|
|
||||||
|
|
||||||
async def drop_graph(self, graph_name="myGraph"):
|
async def drop_graph(self, graph_name="myGraph"):
|
||||||
"""
|
"""
|
||||||
Drop an existing graph from the database based on its name.
|
Drop an existing graph from the database based on its name.
|
||||||
|
|
@ -1208,7 +1194,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def project_entire_graph(self, graph_name="myGraph"):
|
async def project_entire_graph(self, graph_name="myGraph"):
|
||||||
"""
|
"""
|
||||||
Project all node labels and relationship types into an in-memory graph using GDS.
|
Project all node labels and relationship types into an in-memory graph using GDS.
|
||||||
|
|
@ -1280,7 +1265,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
|
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
|
|
||||||
async def get_degree_one_nodes(self, node_type: str):
|
async def get_degree_one_nodes(self, node_type: str):
|
||||||
"""
|
"""
|
||||||
Fetch nodes of a specified type that have exactly one connection.
|
Fetch nodes of a specified type that have exactly one connection.
|
||||||
|
|
@ -1361,8 +1345,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
result = await self.query(query, {"content_hash": content_hash})
|
result = await self.query(query, {"content_hash": content_hash})
|
||||||
return result[0] if result else None
|
return result[0] if result else None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_model_independent_graph_data(self):
|
async def _get_model_independent_graph_data(self):
|
||||||
"""
|
"""
|
||||||
Retrieve the basic graph data without considering the model specifics, returning nodes
|
Retrieve the basic graph data without considering the model specifics, returning nodes
|
||||||
|
|
@ -1380,8 +1362,8 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
RETURN nodeCount AS numVertices, count(r) AS numEdges
|
RETURN nodeCount AS numVertices, count(r) AS numEdges
|
||||||
"""
|
"""
|
||||||
query_response = await self.query(query_string)
|
query_response = await self.query(query_string)
|
||||||
num_nodes = query_response[0].get('numVertices')
|
num_nodes = query_response[0].get("numVertices")
|
||||||
num_edges = query_response[0].get('numEdges')
|
num_edges = query_response[0].get("numEdges")
|
||||||
|
|
||||||
return (num_nodes, num_edges)
|
return (num_nodes, num_edges)
|
||||||
|
|
||||||
|
|
@ -1437,4 +1419,9 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_relationship_to_edge(relationship: dict) -> EdgeData:
|
def _convert_relationship_to_edge(relationship: dict) -> EdgeData:
|
||||||
return relationship["source_id"], relationship["target_id"], relationship["relationship_name"], relationship["properties"]
|
return (
|
||||||
|
relationship["source_id"],
|
||||||
|
relationship["target_id"],
|
||||||
|
relationship["relationship_name"],
|
||||||
|
relationship["properties"],
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,106 +2,114 @@
|
||||||
|
|
||||||
This module defines custom exceptions for Neptune Analytics operations.
|
This module defines custom exceptions for Neptune Analytics operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from cognee.exceptions import CogneeApiError
|
from cognee.exceptions import CogneeApiError
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsError(CogneeApiError):
|
class NeptuneAnalyticsError(CogneeApiError):
|
||||||
"""Base exception for Neptune Analytics operations."""
|
"""Base exception for Neptune Analytics operations."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "Neptune Analytics error.",
|
message: str = "Neptune Analytics error.",
|
||||||
name: str = "NeptuneAnalyticsError",
|
name: str = "NeptuneAnalyticsError",
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when connection to Neptune Analytics fails."""
|
"""Exception raised when connection to Neptune Analytics fails."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "Unable to connect to Neptune Analytics. Please check the endpoint and network connectivity.",
|
message: str = "Unable to connect to Neptune Analytics. Please check the endpoint and network connectivity.",
|
||||||
name: str = "NeptuneAnalyticsConnectionError",
|
name: str = "NeptuneAnalyticsConnectionError",
|
||||||
status_code=status.HTTP_404_NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsQueryError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsQueryError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when a query execution fails."""
|
"""Exception raised when a query execution fails."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "The query execution failed due to invalid syntax or semantic issues.",
|
message: str = "The query execution failed due to invalid syntax or semantic issues.",
|
||||||
name: str = "NeptuneAnalyticsQueryError",
|
name: str = "NeptuneAnalyticsQueryError",
|
||||||
status_code=status.HTTP_400_BAD_REQUEST
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when authentication with Neptune Analytics fails."""
|
"""Exception raised when authentication with Neptune Analytics fails."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "Authentication with Neptune Analytics failed. Please verify your credentials.",
|
message: str = "Authentication with Neptune Analytics failed. Please verify your credentials.",
|
||||||
name: str = "NeptuneAnalyticsAuthenticationError",
|
name: str = "NeptuneAnalyticsAuthenticationError",
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when Neptune Analytics configuration is invalid."""
|
"""Exception raised when Neptune Analytics configuration is invalid."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "Neptune Analytics configuration is invalid or incomplete. Please review your setup.",
|
message: str = "Neptune Analytics configuration is invalid or incomplete. Please review your setup.",
|
||||||
name: str = "NeptuneAnalyticsConfigurationError",
|
name: str = "NeptuneAnalyticsConfigurationError",
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when a Neptune Analytics operation times out."""
|
"""Exception raised when a Neptune Analytics operation times out."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "The operation timed out while communicating with Neptune Analytics.",
|
message: str = "The operation timed out while communicating with Neptune Analytics.",
|
||||||
name: str = "NeptuneAnalyticsTimeoutError",
|
name: str = "NeptuneAnalyticsTimeoutError",
|
||||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT
|
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when requests are throttled by Neptune Analytics."""
|
"""Exception raised when requests are throttled by Neptune Analytics."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "Request was throttled by Neptune Analytics due to exceeding rate limits.",
|
message: str = "Request was throttled by Neptune Analytics due to exceeding rate limits.",
|
||||||
name: str = "NeptuneAnalyticsThrottlingError",
|
name: str = "NeptuneAnalyticsThrottlingError",
|
||||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when a Neptune Analytics resource is not found."""
|
"""Exception raised when a Neptune Analytics resource is not found."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "The requested Neptune Analytics resource could not be found.",
|
message: str = "The requested Neptune Analytics resource could not be found.",
|
||||||
name: str = "NeptuneAnalyticsResourceNotFoundError",
|
name: str = "NeptuneAnalyticsResourceNotFoundError",
|
||||||
status_code=status.HTTP_404_NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError):
|
class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError):
|
||||||
"""Exception raised when invalid parameters are provided to Neptune Analytics."""
|
"""Exception raised when invalid parameters are provided to Neptune Analytics."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "One or more parameters provided to Neptune Analytics are invalid or missing.",
|
message: str = "One or more parameters provided to Neptune Analytics are invalid or missing.",
|
||||||
name: str = "NeptuneAnalyticsInvalidParameterError",
|
name: str = "NeptuneAnalyticsInvalidParameterError",
|
||||||
status_code=status.HTTP_400_BAD_REQUEST
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
):
|
):
|
||||||
super().__init__(message, name, status_code)
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,15 +38,17 @@ def parse_neptune_url(url: str) -> Tuple[str, str]:
|
||||||
if parsed.scheme != "neptune-graph":
|
if parsed.scheme != "neptune-graph":
|
||||||
raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'")
|
raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'")
|
||||||
|
|
||||||
graph_id = parsed.hostname or parsed.path.lstrip('/')
|
graph_id = parsed.hostname or parsed.path.lstrip("/")
|
||||||
if not graph_id:
|
if not graph_id:
|
||||||
raise ValueError("Graph ID not found in URL")
|
raise ValueError("Graph ID not found in URL")
|
||||||
|
|
||||||
# Extract region from query parameters
|
# Extract region from query parameters
|
||||||
region = "us-east-1" # default region
|
region = "us-east-1" # default region
|
||||||
if parsed.query:
|
if parsed.query:
|
||||||
query_params = dict(param.split('=') for param in parsed.query.split('&') if '=' in param)
|
query_params = dict(
|
||||||
region = query_params.get('region', region)
|
param.split("=") for param in parsed.query.split("&") if "=" in param
|
||||||
|
)
|
||||||
|
region = query_params.get("region", region)
|
||||||
|
|
||||||
return graph_id, region
|
return graph_id, region
|
||||||
|
|
||||||
|
|
@ -73,7 +75,7 @@ def validate_graph_id(graph_id: str) -> bool:
|
||||||
|
|
||||||
# Neptune Analytics graph IDs should be alphanumeric with hyphens
|
# Neptune Analytics graph IDs should be alphanumeric with hyphens
|
||||||
# and between 1-63 characters
|
# and between 1-63 characters
|
||||||
pattern = r'^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$'
|
pattern = r"^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$"
|
||||||
return bool(re.match(pattern, graph_id))
|
return bool(re.match(pattern, graph_id))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -93,7 +95,7 @@ def validate_aws_region(region: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# AWS regions follow the pattern: us-east-1, eu-west-1, etc.
|
# AWS regions follow the pattern: us-east-1, eu-west-1, etc.
|
||||||
pattern = r'^[a-z]{2,3}-[a-z]+-\d+$'
|
pattern = r"^[a-z]{2,3}-[a-z]+-\d+$"
|
||||||
return bool(re.match(pattern, region))
|
return bool(re.match(pattern, region))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -103,7 +105,7 @@ def build_neptune_config(
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
aws_session_token: Optional[str] = None,
|
aws_session_token: Optional[str] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Build a configuration dictionary for Neptune Analytics connection.
|
Build a configuration dictionary for Neptune Analytics connection.
|
||||||
|
|
@ -194,6 +196,7 @@ def format_neptune_error(error: Exception) -> str:
|
||||||
|
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
|
|
||||||
def get_default_query_timeout() -> int:
|
def get_default_query_timeout() -> int:
|
||||||
"""
|
"""
|
||||||
Get the default query timeout for Neptune Analytics operations.
|
Get the default query timeout for Neptune Analytics operations.
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredRes
|
||||||
|
|
||||||
logger = get_logger("NeptuneAnalyticsAdapter")
|
logger = get_logger("NeptuneAnalyticsAdapter")
|
||||||
|
|
||||||
|
|
||||||
class IndexSchema(DataPoint):
|
class IndexSchema(DataPoint):
|
||||||
"""
|
"""
|
||||||
Represents a schema for an index data point containing an ID and text.
|
Represents a schema for an index data point containing an ID and text.
|
||||||
|
|
@ -27,12 +28,15 @@ class IndexSchema(DataPoint):
|
||||||
- metadata: A dictionary with default index fields for the schema, currently configured
|
- metadata: A dictionary with default index fields for the schema, currently configured
|
||||||
to include 'text'.
|
to include 'text'.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
text: str
|
text: str
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
||||||
|
|
||||||
NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://"
|
NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://"
|
||||||
|
|
||||||
|
|
||||||
class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
Hybrid adapter that combines Neptune Analytics Vector and Graph functionality.
|
Hybrid adapter that combines Neptune Analytics Vector and Graph functionality.
|
||||||
|
|
@ -48,14 +52,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
_TOPK_UPPER_BOUND = 10
|
_TOPK_UPPER_BOUND = 10
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
embedding_engine: Optional[EmbeddingEngine] = None,
|
embedding_engine: Optional[EmbeddingEngine] = None,
|
||||||
region: Optional[str] = None,
|
region: Optional[str] = None,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
aws_session_token: Optional[str] = None,
|
aws_session_token: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Neptune Analytics hybrid adapter.
|
Initialize the Neptune Analytics hybrid adapter.
|
||||||
|
|
||||||
|
|
@ -74,12 +78,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
region=region,
|
region=region,
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
aws_session_token=aws_session_token
|
aws_session_token=aws_session_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add vector-specific attributes
|
# Add vector-specific attributes
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
logger.info(f"Initialized Neptune Analytics hybrid adapter for graph: \"{graph_id}\" in region: \"{self.region}\"")
|
logger.info(
|
||||||
|
f'Initialized Neptune Analytics hybrid adapter for graph: "{graph_id}" in region: "{self.region}"'
|
||||||
|
)
|
||||||
|
|
||||||
# VectorDBInterface methods implementation
|
# VectorDBInterface methods implementation
|
||||||
|
|
||||||
|
|
@ -161,7 +167,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
|
|
||||||
# Fetch embeddings
|
# Fetch embeddings
|
||||||
texts = [DataPoint.get_embeddable_data(t) for t in data_points]
|
texts = [DataPoint.get_embeddable_data(t) for t in data_points]
|
||||||
data_vectors = (await self.embedding_engine.embed_text(texts))
|
data_vectors = await self.embedding_engine.embed_text(texts)
|
||||||
|
|
||||||
for index, data_point in enumerate(data_points):
|
for index, data_point in enumerate(data_points):
|
||||||
node_id = data_point.id
|
node_id = data_point.id
|
||||||
|
|
@ -172,23 +178,24 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
properties = self._serialize_properties(data_point.model_dump())
|
properties = self._serialize_properties(data_point.model_dump())
|
||||||
properties[self._COLLECTION_PREFIX] = collection_name
|
properties[self._COLLECTION_PREFIX] = collection_name
|
||||||
params = dict(
|
params = dict(
|
||||||
node_id = str(node_id),
|
node_id=str(node_id),
|
||||||
properties = properties,
|
properties=properties,
|
||||||
embedding = data_vector,
|
embedding=data_vector,
|
||||||
collection_name = collection_name
|
collection_name=collection_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compose the query and send
|
# Compose the query and send
|
||||||
query_string = (
|
query_string = (
|
||||||
f"MERGE (n "
|
f"MERGE (n "
|
||||||
f":{self._VECTOR_NODE_LABEL} "
|
f":{self._VECTOR_NODE_LABEL} "
|
||||||
f" {{`~id`: $node_id}}) "
|
f" {{`~id`: $node_id}}) "
|
||||||
f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
|
f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
|
||||||
f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
|
f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
|
||||||
f"WITH n, $embedding AS embedding "
|
f"WITH n, $embedding AS embedding "
|
||||||
f"CALL neptune.algo.vectors.upsert(n, embedding) "
|
f"CALL neptune.algo.vectors.upsert(n, embedding) "
|
||||||
f"YIELD success "
|
f"YIELD success "
|
||||||
f"RETURN success ")
|
f"RETURN success "
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._client.query(query_string, params)
|
self._client.query(query_string, params)
|
||||||
|
|
@ -208,10 +215,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
# Do the fetch for each node
|
# Do the fetch for each node
|
||||||
params = dict(node_ids=data_point_ids, collection_name=collection_name)
|
params = dict(node_ids=data_point_ids, collection_name=collection_name)
|
||||||
query_string = (f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
|
query_string = (
|
||||||
f"WHERE id(n) in $node_ids AND "
|
f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
|
||||||
f"n.{self._COLLECTION_PREFIX} = $collection_name "
|
f"WHERE id(n) in $node_ids AND "
|
||||||
f"RETURN n as payload ")
|
f"n.{self._COLLECTION_PREFIX} = $collection_name "
|
||||||
|
f"RETURN n as payload "
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self._client.query(query_string, params)
|
result = self._client.query(query_string, params)
|
||||||
|
|
@ -259,7 +268,8 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
|
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
|
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
|
||||||
"Defaulting to limit=10.", limit
|
"Defaulting to limit=10.",
|
||||||
|
limit,
|
||||||
)
|
)
|
||||||
limit = self._TOPK_UPPER_BOUND
|
limit = self._TOPK_UPPER_BOUND
|
||||||
|
|
||||||
|
|
@ -272,7 +282,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
elif query_vector:
|
elif query_vector:
|
||||||
embedding = query_vector
|
embedding = query_vector
|
||||||
else:
|
else:
|
||||||
data_vectors = (await self.embedding_engine.embed_text([query_text]))
|
data_vectors = await self.embedding_engine.embed_text([query_text])
|
||||||
embedding = data_vectors[0]
|
embedding = data_vectors[0]
|
||||||
|
|
||||||
# Compose the parameters map
|
# Compose the parameters map
|
||||||
|
|
@ -305,9 +315,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query_response = self._client.query(query_string, params)
|
query_response = self._client.query(query_string, params)
|
||||||
return [self._get_scored_result(
|
return [self._get_scored_result(item=item, with_score=True) for item in query_response]
|
||||||
item = item, with_score = True
|
|
||||||
) for item in query_response]
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._na_exception_handler(e, query_string)
|
self._na_exception_handler(e, query_string)
|
||||||
|
|
||||||
|
|
@ -332,11 +340,13 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
self._validate_embedding_engine()
|
self._validate_embedding_engine()
|
||||||
|
|
||||||
# Convert text to embedding array in batch
|
# Convert text to embedding array in batch
|
||||||
data_vectors = (await self.embedding_engine.embed_text(query_texts))
|
data_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
return await asyncio.gather(*[
|
return await asyncio.gather(
|
||||||
self.search(collection_name, None, vector, limit, with_vectors)
|
*[
|
||||||
for vector in data_vectors
|
self.search(collection_name, None, vector, limit, with_vectors)
|
||||||
])
|
for vector in data_vectors
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||||
"""
|
"""
|
||||||
|
|
@ -350,10 +360,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
- data_point_ids (list[str]): A list of IDs of the data points to delete.
|
- data_point_ids (list[str]): A list of IDs of the data points to delete.
|
||||||
"""
|
"""
|
||||||
params = dict(node_ids=data_point_ids, collection_name=collection_name)
|
params = dict(node_ids=data_point_ids, collection_name=collection_name)
|
||||||
query_string = (f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
|
query_string = (
|
||||||
f"WHERE id(n) IN $node_ids "
|
f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
|
||||||
f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
|
f"WHERE id(n) IN $node_ids "
|
||||||
f"DETACH DELETE n")
|
f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
|
||||||
|
f"DETACH DELETE n"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
self._client.query(query_string, params)
|
self._client.query(query_string, params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -370,7 +382,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
await self.create_collection(f"{index_name}_{index_property_name}")
|
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||||
|
|
||||||
async def index_data_points(
|
async def index_data_points(
|
||||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Indexes a list of data points into Neptune Analytics by creating them as nodes.
|
Indexes a list of data points into Neptune Analytics by creating them as nodes.
|
||||||
|
|
@ -402,29 +414,28 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
Remove obsolete or unnecessary data from the database.
|
Remove obsolete or unnecessary data from the database.
|
||||||
"""
|
"""
|
||||||
# Run actual truncate
|
# Run actual truncate
|
||||||
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
|
self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) DETACH DELETE n")
|
||||||
f"DETACH DELETE n")
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_scored_result(item: dict, with_vector: bool = False, with_score: bool = False) -> ScoredResult:
|
def _get_scored_result(
|
||||||
|
item: dict, with_vector: bool = False, with_score: bool = False
|
||||||
|
) -> ScoredResult:
|
||||||
"""
|
"""
|
||||||
Util method to simplify the object creation of ScoredResult base on incoming NX payload response.
|
Util method to simplify the object creation of ScoredResult base on incoming NX payload response.
|
||||||
"""
|
"""
|
||||||
return ScoredResult(
|
return ScoredResult(
|
||||||
id=item.get('payload').get('~id'),
|
id=item.get("payload").get("~id"),
|
||||||
payload=item.get('payload').get('~properties'),
|
payload=item.get("payload").get("~properties"),
|
||||||
score=item.get('score') if with_score else 0,
|
score=item.get("score") if with_score else 0,
|
||||||
vector=item.get('embedding') if with_vector else None
|
vector=item.get("embedding") if with_vector else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _na_exception_handler(self, ex, query_string: str):
|
def _na_exception_handler(self, ex, query_string: str):
|
||||||
"""
|
"""
|
||||||
Generic exception handler for NA langchain.
|
Generic exception handler for NA langchain.
|
||||||
"""
|
"""
|
||||||
logger.error(
|
logger.error("Neptune Analytics query failed: %s | Query: [%s]", ex, query_string)
|
||||||
"Neptune Analytics query failed: %s | Query: [%s]", ex, query_string
|
|
||||||
)
|
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
def _validate_embedding_engine(self):
|
def _validate_embedding_engine(self):
|
||||||
|
|
@ -433,4 +444,6 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
:raises: ValueError if this object does not have a valid embedding_engine
|
:raises: ValueError if this object does not have a valid embedding_engine
|
||||||
"""
|
"""
|
||||||
if self.embedding_engine is None:
|
if self.embedding_engine is None:
|
||||||
raise ValueError("Neptune Analytics requires an embedder defined to make vector operations")
|
raise ValueError(
|
||||||
|
"Neptune Analytics requires an embedder defined to make vector operations"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -125,9 +125,15 @@ def create_vector_engine(
|
||||||
if not vector_db_url:
|
if not vector_db_url:
|
||||||
raise EnvironmentError("Missing Neptune endpoint.")
|
raise EnvironmentError("Missing Neptune endpoint.")
|
||||||
|
|
||||||
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL
|
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
|
||||||
|
NeptuneAnalyticsAdapter,
|
||||||
|
NEPTUNE_ANALYTICS_ENDPOINT_URL,
|
||||||
|
)
|
||||||
|
|
||||||
if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
|
if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
|
||||||
raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")
|
raise ValueError(
|
||||||
|
f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'"
|
||||||
|
)
|
||||||
|
|
||||||
graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")
|
graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from cognee.modules.data.processing.document_types import TextDocument
|
||||||
|
|
||||||
# Set up Amazon credentials in .env file and get the values from environment variables
|
# Set up Amazon credentials in .env file and get the values from environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
graph_id = os.getenv('GRAPH_ID', "")
|
graph_id = os.getenv("GRAPH_ID", "")
|
||||||
|
|
||||||
na_adapter = NeptuneGraphDB(graph_id)
|
na_adapter = NeptuneGraphDB(graph_id)
|
||||||
|
|
||||||
|
|
@ -24,33 +24,33 @@ def setup():
|
||||||
# stored in Amazon S3.
|
# stored in Amazon S3.
|
||||||
|
|
||||||
document = TextDocument(
|
document = TextDocument(
|
||||||
name='text_test.txt',
|
name="text_test.txt",
|
||||||
raw_data_location='git/cognee/examples/database_examples/data_storage/data/text_test.txt',
|
raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt",
|
||||||
external_metadata='{}',
|
external_metadata="{}",
|
||||||
mime_type='text/plain'
|
mime_type="text/plain",
|
||||||
)
|
)
|
||||||
document_chunk = DocumentChunk(
|
document_chunk = DocumentChunk(
|
||||||
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
||||||
chunk_size=187,
|
chunk_size=187,
|
||||||
chunk_index=0,
|
chunk_index=0,
|
||||||
cut_type='paragraph_end',
|
cut_type="paragraph_end",
|
||||||
is_part_of=document,
|
is_part_of=document,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_database = EntityType(name='graph database', description='graph database')
|
graph_database = EntityType(name="graph database", description="graph database")
|
||||||
neptune_analytics_entity = Entity(
|
neptune_analytics_entity = Entity(
|
||||||
name='neptune analytics',
|
name="neptune analytics",
|
||||||
description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
|
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
|
||||||
)
|
)
|
||||||
neptune_database_entity = Entity(
|
neptune_database_entity = Entity(
|
||||||
name='amazon neptune database',
|
name="amazon neptune database",
|
||||||
description='A popular managed graph database that complements Neptune Analytics.',
|
description="A popular managed graph database that complements Neptune Analytics.",
|
||||||
)
|
)
|
||||||
|
|
||||||
storage = EntityType(name='storage', description='storage')
|
storage = EntityType(name="storage", description="storage")
|
||||||
storage_entity = Entity(
|
storage_entity = Entity(
|
||||||
name='amazon s3',
|
name="amazon s3",
|
||||||
description='A storage service provided by Amazon Web Services that allows storing graph data.',
|
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes_data = [
|
nodes_data = [
|
||||||
|
|
@ -67,37 +67,37 @@ def setup():
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(storage_entity.id),
|
str(storage_entity.id),
|
||||||
'contains',
|
"contains",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(storage_entity.id),
|
str(storage_entity.id),
|
||||||
str(storage.id),
|
str(storage.id),
|
||||||
'is_a',
|
"is_a",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(neptune_database_entity.id),
|
str(neptune_database_entity.id),
|
||||||
'contains',
|
"contains",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(neptune_database_entity.id),
|
str(neptune_database_entity.id),
|
||||||
str(graph_database.id),
|
str(graph_database.id),
|
||||||
'is_a',
|
"is_a",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(document.id),
|
str(document.id),
|
||||||
'is_part_of',
|
"is_part_of",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(neptune_analytics_entity.id),
|
str(neptune_analytics_entity.id),
|
||||||
'contains',
|
"contains",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(neptune_analytics_entity.id),
|
str(neptune_analytics_entity.id),
|
||||||
str(graph_database.id),
|
str(graph_database.id),
|
||||||
'is_a',
|
"is_a",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -155,42 +155,44 @@ async def pipeline_method():
|
||||||
print("------NEIGHBORING NODES-------")
|
print("------NEIGHBORING NODES-------")
|
||||||
center_node = nodes[2]
|
center_node = nodes[2]
|
||||||
neighbors = await na_adapter.get_neighbors(str(center_node.id))
|
neighbors = await na_adapter.get_neighbors(str(center_node.id))
|
||||||
print(f"found {len(neighbors)} neighbors for node \"{center_node.name}\"")
|
print(f'found {len(neighbors)} neighbors for node "{center_node.name}"')
|
||||||
for neighbor in neighbors:
|
for neighbor in neighbors:
|
||||||
print(neighbor)
|
print(neighbor)
|
||||||
|
|
||||||
print("------NEIGHBORING EDGES-------")
|
print("------NEIGHBORING EDGES-------")
|
||||||
center_node = nodes[2]
|
center_node = nodes[2]
|
||||||
neighbouring_edges = await na_adapter.get_edges(str(center_node.id))
|
neighbouring_edges = await na_adapter.get_edges(str(center_node.id))
|
||||||
print(f"found {len(neighbouring_edges)} edges neighbouring node \"{center_node.name}\"")
|
print(f'found {len(neighbouring_edges)} edges neighbouring node "{center_node.name}"')
|
||||||
for edge in neighbouring_edges:
|
for edge in neighbouring_edges:
|
||||||
print(edge)
|
print(edge)
|
||||||
|
|
||||||
print("------GET CONNECTIONS (SOURCE NODE)-------")
|
print("------GET CONNECTIONS (SOURCE NODE)-------")
|
||||||
document_chunk_node = nodes[0]
|
document_chunk_node = nodes[0]
|
||||||
connections = await na_adapter.get_connections(str(document_chunk_node.id))
|
connections = await na_adapter.get_connections(str(document_chunk_node.id))
|
||||||
print(f"found {len(connections)} connections for node \"{document_chunk_node.type}\"")
|
print(f'found {len(connections)} connections for node "{document_chunk_node.type}"')
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
src, relationship, tgt = connection
|
src, relationship, tgt = connection
|
||||||
src = src.get("name", src.get("type", "unknown"))
|
src = src.get("name", src.get("type", "unknown"))
|
||||||
relationship = relationship["relationship_name"]
|
relationship = relationship["relationship_name"]
|
||||||
tgt = tgt.get("name", tgt.get("type", "unknown"))
|
tgt = tgt.get("name", tgt.get("type", "unknown"))
|
||||||
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")
|
print(f'"{src}"-[{relationship}]->"{tgt}"')
|
||||||
|
|
||||||
print("------GET CONNECTIONS (TARGET NODE)-------")
|
print("------GET CONNECTIONS (TARGET NODE)-------")
|
||||||
connections = await na_adapter.get_connections(str(center_node.id))
|
connections = await na_adapter.get_connections(str(center_node.id))
|
||||||
print(f"found {len(connections)} connections for node \"{center_node.name}\"")
|
print(f'found {len(connections)} connections for node "{center_node.name}"')
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
src, relationship, tgt = connection
|
src, relationship, tgt = connection
|
||||||
src = src.get("name", src.get("type", "unknown"))
|
src = src.get("name", src.get("type", "unknown"))
|
||||||
relationship = relationship["relationship_name"]
|
relationship = relationship["relationship_name"]
|
||||||
tgt = tgt.get("name", tgt.get("type", "unknown"))
|
tgt = tgt.get("name", tgt.get("type", "unknown"))
|
||||||
print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")
|
print(f'"{src}"-[{relationship}]->"{tgt}"')
|
||||||
|
|
||||||
print("------SUBGRAPH-------")
|
print("------SUBGRAPH-------")
|
||||||
node_names = ["neptune analytics", "amazon neptune database"]
|
node_names = ["neptune analytics", "amazon neptune database"]
|
||||||
subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names)
|
subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names)
|
||||||
print(f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}")
|
print(
|
||||||
|
f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}"
|
||||||
|
)
|
||||||
for subgraph_node in subgraph_nodes:
|
for subgraph_node in subgraph_nodes:
|
||||||
print(subgraph_node)
|
print(subgraph_node)
|
||||||
for subgraph_edge in subgraph_edges:
|
for subgraph_edge in subgraph_edges:
|
||||||
|
|
@ -199,17 +201,17 @@ async def pipeline_method():
|
||||||
print("------STAT-------")
|
print("------STAT-------")
|
||||||
stat = await na_adapter.get_graph_metrics(include_optional=True)
|
stat = await na_adapter.get_graph_metrics(include_optional=True)
|
||||||
assert type(stat) is dict
|
assert type(stat) is dict
|
||||||
assert stat['num_nodes'] == 7
|
assert stat["num_nodes"] == 7
|
||||||
assert stat['num_edges'] == 7
|
assert stat["num_edges"] == 7
|
||||||
assert stat['mean_degree'] == 2.0
|
assert stat["mean_degree"] == 2.0
|
||||||
assert round(stat['edge_density'], 3) == 0.167
|
assert round(stat["edge_density"], 3) == 0.167
|
||||||
assert stat['num_connected_components'] == [7]
|
assert stat["num_connected_components"] == [7]
|
||||||
assert stat['sizes_of_connected_components'] == 1
|
assert stat["sizes_of_connected_components"] == 1
|
||||||
assert stat['num_selfloops'] == 0
|
assert stat["num_selfloops"] == 0
|
||||||
# Unsupported optional metrics
|
# Unsupported optional metrics
|
||||||
assert stat['diameter'] == -1
|
assert stat["diameter"] == -1
|
||||||
assert stat['avg_shortest_path_length'] == -1
|
assert stat["avg_shortest_path_length"] == -1
|
||||||
assert stat['avg_clustering'] == -1
|
assert stat["avg_clustering"] == -1
|
||||||
|
|
||||||
print("------DELETE-------")
|
print("------DELETE-------")
|
||||||
# delete all nodes and edges:
|
# delete all nodes and edges:
|
||||||
|
|
@ -253,7 +255,9 @@ async def misc_methods():
|
||||||
print(edge_labels)
|
print(edge_labels)
|
||||||
|
|
||||||
print("------Get Filtered Graph-------")
|
print("------Get Filtered Graph-------")
|
||||||
filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data([{'name': ['text_test.txt']}])
|
filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data(
|
||||||
|
[{"name": ["text_test.txt"]}]
|
||||||
|
)
|
||||||
print(filtered_nodes, filtered_edges)
|
print(filtered_nodes, filtered_edges)
|
||||||
|
|
||||||
print("------Get Degree one nodes-------")
|
print("------Get Degree one nodes-------")
|
||||||
|
|
@ -261,15 +265,13 @@ async def misc_methods():
|
||||||
print(degree_one_nodes)
|
print(degree_one_nodes)
|
||||||
|
|
||||||
print("------Get Doc sub-graph-------")
|
print("------Get Doc sub-graph-------")
|
||||||
doc_sub_graph = await na_adapter.get_document_subgraph('test.txt')
|
doc_sub_graph = await na_adapter.get_document_subgraph("test.txt")
|
||||||
print(doc_sub_graph)
|
print(doc_sub_graph)
|
||||||
|
|
||||||
print("------Fetch and Remove connections (Predecessors)-------")
|
print("------Fetch and Remove connections (Predecessors)-------")
|
||||||
# Fetch test edge
|
# Fetch test edge
|
||||||
(src_id, dest_id, relationship) = edges[0]
|
(src_id, dest_id, relationship) = edges[0]
|
||||||
nodes_predecessors = await na_adapter.get_predecessors(
|
nodes_predecessors = await na_adapter.get_predecessors(node_id=dest_id, edge_label=relationship)
|
||||||
node_id=dest_id, edge_label=relationship
|
|
||||||
)
|
|
||||||
assert len(nodes_predecessors) > 0
|
assert len(nodes_predecessors) > 0
|
||||||
|
|
||||||
await na_adapter.remove_connection_to_predecessors_of(
|
await na_adapter.remove_connection_to_predecessors_of(
|
||||||
|
|
@ -281,25 +283,19 @@ async def misc_methods():
|
||||||
# Return empty after relationship being deleted.
|
# Return empty after relationship being deleted.
|
||||||
assert len(nodes_predecessors_after) == 0
|
assert len(nodes_predecessors_after) == 0
|
||||||
|
|
||||||
|
|
||||||
print("------Fetch and Remove connections (Successors)-------")
|
print("------Fetch and Remove connections (Successors)-------")
|
||||||
_, edges_suc = await na_adapter.get_graph_data()
|
_, edges_suc = await na_adapter.get_graph_data()
|
||||||
(src_id, dest_id, relationship, _) = edges_suc[0]
|
(src_id, dest_id, relationship, _) = edges_suc[0]
|
||||||
|
|
||||||
nodes_successors = await na_adapter.get_successors(
|
nodes_successors = await na_adapter.get_successors(node_id=src_id, edge_label=relationship)
|
||||||
node_id=src_id, edge_label=relationship
|
|
||||||
)
|
|
||||||
assert len(nodes_successors) > 0
|
assert len(nodes_successors) > 0
|
||||||
|
|
||||||
await na_adapter.remove_connection_to_successors_of(
|
await na_adapter.remove_connection_to_successors_of(node_ids=[dest_id], edge_label=relationship)
|
||||||
node_ids=[dest_id], edge_label=relationship
|
|
||||||
)
|
|
||||||
nodes_successors_after = await na_adapter.get_successors(
|
nodes_successors_after = await na_adapter.get_successors(
|
||||||
node_id=src_id, edge_label=relationship
|
node_id=src_id, edge_label=relationship
|
||||||
)
|
)
|
||||||
assert len(nodes_successors_after) == 0
|
assert len(nodes_successors_after) == 0
|
||||||
|
|
||||||
|
|
||||||
# no-op
|
# no-op
|
||||||
await na_adapter.project_entire_graph()
|
await na_adapter.project_entire_graph()
|
||||||
await na_adapter.drop_graph()
|
await na_adapter.drop_graph()
|
||||||
|
|
|
||||||
|
|
@ -8,11 +8,13 @@ from cognee.modules.engine.models import Entity, EntityType
|
||||||
from cognee.modules.data.processing.document_types import TextDocument
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter
|
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
|
||||||
|
NeptuneAnalyticsAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
# Set up Amazon credentials in .env file and get the values from environment variables
|
# Set up Amazon credentials in .env file and get the values from environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
graph_id = os.getenv('GRAPH_ID', "")
|
graph_id = os.getenv("GRAPH_ID", "")
|
||||||
|
|
||||||
# get the default embedder
|
# get the default embedder
|
||||||
embedding_engine = get_embedding_engine()
|
embedding_engine = get_embedding_engine()
|
||||||
|
|
@ -23,6 +25,7 @@ collection = "test_collection"
|
||||||
|
|
||||||
logger = get_logger("test_neptune_analytics_hybrid")
|
logger = get_logger("test_neptune_analytics_hybrid")
|
||||||
|
|
||||||
|
|
||||||
def setup_data():
|
def setup_data():
|
||||||
# Define nodes data before the main function
|
# Define nodes data before the main function
|
||||||
# These nodes were defined using openAI from the following prompt:
|
# These nodes were defined using openAI from the following prompt:
|
||||||
|
|
@ -34,33 +37,33 @@ def setup_data():
|
||||||
# stored in Amazon S3.
|
# stored in Amazon S3.
|
||||||
|
|
||||||
document = TextDocument(
|
document = TextDocument(
|
||||||
name='text.txt',
|
name="text.txt",
|
||||||
raw_data_location='git/cognee/examples/database_examples/data_storage/data/text.txt',
|
raw_data_location="git/cognee/examples/database_examples/data_storage/data/text.txt",
|
||||||
external_metadata='{}',
|
external_metadata="{}",
|
||||||
mime_type='text/plain'
|
mime_type="text/plain",
|
||||||
)
|
)
|
||||||
document_chunk = DocumentChunk(
|
document_chunk = DocumentChunk(
|
||||||
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
||||||
chunk_size=187,
|
chunk_size=187,
|
||||||
chunk_index=0,
|
chunk_index=0,
|
||||||
cut_type='paragraph_end',
|
cut_type="paragraph_end",
|
||||||
is_part_of=document,
|
is_part_of=document,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_database = EntityType(name='graph database', description='graph database')
|
graph_database = EntityType(name="graph database", description="graph database")
|
||||||
neptune_analytics_entity = Entity(
|
neptune_analytics_entity = Entity(
|
||||||
name='neptune analytics',
|
name="neptune analytics",
|
||||||
description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
|
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
|
||||||
)
|
)
|
||||||
neptune_database_entity = Entity(
|
neptune_database_entity = Entity(
|
||||||
name='amazon neptune database',
|
name="amazon neptune database",
|
||||||
description='A popular managed graph database that complements Neptune Analytics.',
|
description="A popular managed graph database that complements Neptune Analytics.",
|
||||||
)
|
)
|
||||||
|
|
||||||
storage = EntityType(name='storage', description='storage')
|
storage = EntityType(name="storage", description="storage")
|
||||||
storage_entity = Entity(
|
storage_entity = Entity(
|
||||||
name='amazon s3',
|
name="amazon s3",
|
||||||
description='A storage service provided by Amazon Web Services that allows storing graph data.',
|
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes_data = [
|
nodes_data = [
|
||||||
|
|
@ -77,41 +80,42 @@ def setup_data():
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(storage_entity.id),
|
str(storage_entity.id),
|
||||||
'contains',
|
"contains",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(storage_entity.id),
|
str(storage_entity.id),
|
||||||
str(storage.id),
|
str(storage.id),
|
||||||
'is_a',
|
"is_a",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(neptune_database_entity.id),
|
str(neptune_database_entity.id),
|
||||||
'contains',
|
"contains",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(neptune_database_entity.id),
|
str(neptune_database_entity.id),
|
||||||
str(graph_database.id),
|
str(graph_database.id),
|
||||||
'is_a',
|
"is_a",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(document.id),
|
str(document.id),
|
||||||
'is_part_of',
|
"is_part_of",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(document_chunk.id),
|
str(document_chunk.id),
|
||||||
str(neptune_analytics_entity.id),
|
str(neptune_analytics_entity.id),
|
||||||
'contains',
|
"contains",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
str(neptune_analytics_entity.id),
|
str(neptune_analytics_entity.id),
|
||||||
str(graph_database.id),
|
str(graph_database.id),
|
||||||
'is_a',
|
"is_a",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
return nodes_data, edges_data
|
return nodes_data, edges_data
|
||||||
|
|
||||||
|
|
||||||
async def test_add_graph_then_vector_data():
|
async def test_add_graph_then_vector_data():
|
||||||
logger.info("------test_add_graph_then_vector_data-------")
|
logger.info("------test_add_graph_then_vector_data-------")
|
||||||
(nodes, edges) = setup_data()
|
(nodes, edges) = setup_data()
|
||||||
|
|
@ -134,6 +138,7 @@ async def test_add_graph_then_vector_data():
|
||||||
assert len(edges) == 0
|
assert len(edges) == 0
|
||||||
logger.info("------PASSED-------")
|
logger.info("------PASSED-------")
|
||||||
|
|
||||||
|
|
||||||
async def test_add_vector_then_node_data():
|
async def test_add_vector_then_node_data():
|
||||||
logger.info("------test_add_vector_then_node_data-------")
|
logger.info("------test_add_vector_then_node_data-------")
|
||||||
(nodes, edges) = setup_data()
|
(nodes, edges) = setup_data()
|
||||||
|
|
@ -156,6 +161,7 @@ async def test_add_vector_then_node_data():
|
||||||
assert len(edges) == 0
|
assert len(edges) == 0
|
||||||
logger.info("------PASSED-------")
|
logger.info("------PASSED-------")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
"""
|
||||||
Example script uses neptune analytics for the graph and vector (hybrid) store with small sample data
|
Example script uses neptune analytics for the graph and vector (hybrid) store with small sample data
|
||||||
|
|
@ -165,5 +171,6 @@ def main():
|
||||||
asyncio.run(test_add_graph_then_vector_data())
|
asyncio.run(test_add_graph_then_vector_data())
|
||||||
asyncio.run(test_add_vector_then_node_data())
|
asyncio.run(test_add_vector_then_node_data())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,16 @@ from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, IndexSchema
|
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
|
||||||
|
NeptuneAnalyticsAdapter,
|
||||||
|
IndexSchema,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
graph_id = os.getenv('GRAPH_ID', "")
|
graph_id = os.getenv("GRAPH_ID", "")
|
||||||
cognee.config.set_vector_db_provider("neptune_analytics")
|
cognee.config.set_vector_db_provider("neptune_analytics")
|
||||||
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
|
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
|
||||||
data_directory_path = str(
|
data_directory_path = str(
|
||||||
|
|
@ -87,6 +90,7 @@ async def main():
|
||||||
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
|
||||||
async def vector_backend_api_test():
|
async def vector_backend_api_test():
|
||||||
cognee.config.set_vector_db_provider("neptune_analytics")
|
cognee.config.set_vector_db_provider("neptune_analytics")
|
||||||
|
|
||||||
|
|
@ -101,7 +105,7 @@ async def vector_backend_api_test():
|
||||||
get_vector_engine()
|
get_vector_engine()
|
||||||
|
|
||||||
# Return a valid engine object with valid URL.
|
# Return a valid engine object with valid URL.
|
||||||
graph_id = os.getenv('GRAPH_ID', "")
|
graph_id = os.getenv("GRAPH_ID", "")
|
||||||
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
|
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
|
||||||
engine = get_vector_engine()
|
engine = get_vector_engine()
|
||||||
assert isinstance(engine, NeptuneAnalyticsAdapter)
|
assert isinstance(engine, NeptuneAnalyticsAdapter)
|
||||||
|
|
@ -133,28 +137,22 @@ async def vector_backend_api_test():
|
||||||
query_text=TEST_TEXT,
|
query_text=TEST_TEXT,
|
||||||
query_vector=None,
|
query_vector=None,
|
||||||
limit=10,
|
limit=10,
|
||||||
with_vector=True)
|
with_vector=True,
|
||||||
assert (len(result_search) == 2)
|
)
|
||||||
|
assert len(result_search) == 2
|
||||||
|
|
||||||
# # Retrieve data-points
|
# # Retrieve data-points
|
||||||
result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
|
result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
|
||||||
assert any(
|
assert any(str(r.id) == TEST_UUID and r.payload["text"] == TEST_TEXT for r in result)
|
||||||
str(r.id) == TEST_UUID and r.payload['text'] == TEST_TEXT
|
assert any(str(r.id) == TEST_UUID_2 and r.payload["text"] == TEST_TEXT_2 for r in result)
|
||||||
for r in result
|
|
||||||
)
|
|
||||||
assert any(
|
|
||||||
str(r.id) == TEST_UUID_2 and r.payload['text'] == TEST_TEXT_2
|
|
||||||
for r in result
|
|
||||||
)
|
|
||||||
# Search multiple
|
# Search multiple
|
||||||
result_search_batch = await engine.batch_search(
|
result_search_batch = await engine.batch_search(
|
||||||
collection_name=TEST_COLLECTION_NAME,
|
collection_name=TEST_COLLECTION_NAME,
|
||||||
query_texts=[TEST_TEXT, TEST_TEXT_2],
|
query_texts=[TEST_TEXT, TEST_TEXT_2],
|
||||||
limit=10,
|
limit=10,
|
||||||
with_vectors=False
|
with_vectors=False,
|
||||||
)
|
)
|
||||||
assert (len(result_search_batch) == 2 and
|
assert len(result_search_batch) == 2 and all(len(batch) == 2 for batch in result_search_batch)
|
||||||
all(len(batch) == 2 for batch in result_search_batch))
|
|
||||||
|
|
||||||
# Delete datapoint from vector store
|
# Delete datapoint from vector store
|
||||||
await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
|
await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
|
||||||
|
|
@ -163,6 +161,7 @@ async def vector_backend_api_test():
|
||||||
result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID])
|
result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID])
|
||||||
assert result_deleted == []
|
assert result_deleted == []
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""
|
"""
|
||||||
Example script demonstrating how to use Cognee with Amazon Neptune Analytics
|
Example script demonstrating how to use Cognee with Amazon Neptune Analytics
|
||||||
|
|
@ -22,7 +23,7 @@ async def main():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Set up Amazon credentials in .env file and get the values from environment variables
|
# Set up Amazon credentials in .env file and get the values from environment variables
|
||||||
graph_endpoint_url = "neptune-graph://" + os.getenv('GRAPH_ID', "")
|
graph_endpoint_url = "neptune-graph://" + os.getenv("GRAPH_ID", "")
|
||||||
|
|
||||||
# Configure Neptune Analytics as the graph & vector database provider
|
# Configure Neptune Analytics as the graph & vector database provider
|
||||||
cognee.config.set_graph_db_config(
|
cognee.config.set_graph_db_config(
|
||||||
|
|
@ -77,7 +78,9 @@ async def main():
|
||||||
|
|
||||||
# Now let's perform some searches
|
# Now let's perform some searches
|
||||||
# 1. Search for insights related to "Neptune Analytics"
|
# 1. Search for insights related to "Neptune Analytics"
|
||||||
insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neptune Analytics")
|
insights_results = await cognee.search(
|
||||||
|
query_type=SearchType.INSIGHTS, query_text="Neptune Analytics"
|
||||||
|
)
|
||||||
print("\n========Insights about Neptune Analytics========:")
|
print("\n========Insights about Neptune Analytics========:")
|
||||||
for result in insights_results:
|
for result in insights_results:
|
||||||
print(f"- {result}")
|
print(f"- {result}")
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
Loading…
Add table
Reference in a new issue