Add Neptune Analytics hybrid storage (#17)
* Add neptune analytics driver

Signed-off-by: Andrew Carbonetto <andrew.carbonetto@improving.com>

This commit is contained in: parent 9907e6fe5b, commit 3bf93060b7
13 changed files with 3000 additions and 0 deletions
@@ -135,6 +135,56 @@ def create_graph_engine(
            graph_database_password=graph_database_password or None,
        )

    elif graph_database_provider == "neptune":
        try:
            from langchain_aws import NeptuneAnalyticsGraph
        except ImportError:
            raise ImportError(
                "langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
            )

        if not graph_database_url:
            raise EnvironmentError("Missing Neptune endpoint.")

        from .neptune_driver.adapter import NeptuneGraphDB, NEPTUNE_ENDPOINT_URL

        if not graph_database_url.startswith(NEPTUNE_ENDPOINT_URL):
            raise ValueError(f"Neptune endpoint must have the format {NEPTUNE_ENDPOINT_URL}<GRAPH_ID>")

        graph_identifier = graph_database_url.replace(NEPTUNE_ENDPOINT_URL, "")

        return NeptuneGraphDB(
            graph_id=graph_identifier,
        )

    elif graph_database_provider == "neptune_analytics":
        """
        Creates a graph DB from the config.
        We want to use a hybrid (graph & vector) DB, and this should be updated
        to create a single instance of the hybrid configuration (with embedder)
        instead of creating the hybrid object twice.
        """
        try:
            from langchain_aws import NeptuneAnalyticsGraph
        except ImportError:
            raise ImportError(
                "langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
            )

        if not graph_database_url:
            raise EnvironmentError("Missing Neptune endpoint.")

        from ..hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL

        if not graph_database_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
            raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")

        graph_identifier = graph_database_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

        return NeptuneAnalyticsAdapter(
            graph_id=graph_identifier,
        )

    from .networkx.adapter import NetworkXAdapter

    graph_client = NetworkXAdapter(filename=graph_file_path)
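For readers wiring this up, a minimal configuration sketch (not part of the diff) showing how the new providers are selected; it relies on the cognee.config helpers used in the example file at the end of this commit, and the graph ID is a placeholder:

import cognee

# Placeholder graph ID; for the "neptune_analytics" branch above the URL must use
# the "neptune-graph://<GRAPH_ID>" prefix checked by NEPTUNE_ANALYTICS_ENDPOINT_URL.
graph_endpoint_url = "neptune-graph://g-exampleid"

cognee.config.set_graph_db_config(
    {
        "graph_database_provider": "neptune_analytics",
        "graph_database_url": graph_endpoint_url,
    }
)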
@@ -0,0 +1,15 @@
"""Neptune Analytics Driver Module

This module provides the Neptune Analytics adapter and utilities for interacting
with Amazon Neptune Analytics graph databases.
"""

from .adapter import NeptuneGraphDB
from . import neptune_utils
from . import exceptions

__all__ = [
    "NeptuneGraphDB",
    "neptune_utils",
    "exceptions",
]
cognee/infrastructure/databases/graph/neptune_driver/adapter.py: 1445 lines (new file).
File diff suppressed because it is too large.
@@ -0,0 +1,49 @@
"""Neptune Analytics Exceptions

This module defines custom exceptions for Neptune Analytics operations.
"""


class NeptuneAnalyticsError(Exception):
    """Base exception for Neptune Analytics operations."""
    pass


class NeptuneAnalyticsConnectionError(NeptuneAnalyticsError):
    """Exception raised when connection to Neptune Analytics fails."""
    pass


class NeptuneAnalyticsQueryError(NeptuneAnalyticsError):
    """Exception raised when a query execution fails."""
    pass


class NeptuneAnalyticsAuthenticationError(NeptuneAnalyticsError):
    """Exception raised when authentication with Neptune Analytics fails."""
    pass


class NeptuneAnalyticsConfigurationError(NeptuneAnalyticsError):
    """Exception raised when Neptune Analytics configuration is invalid."""
    pass


class NeptuneAnalyticsTimeoutError(NeptuneAnalyticsError):
    """Exception raised when a Neptune Analytics operation times out."""
    pass


class NeptuneAnalyticsThrottlingError(NeptuneAnalyticsError):
    """Exception raised when requests are throttled by Neptune Analytics."""
    pass


class NeptuneAnalyticsResourceNotFoundError(NeptuneAnalyticsError):
    """Exception raised when a Neptune Analytics resource is not found."""
    pass


class NeptuneAnalyticsInvalidParameterError(NeptuneAnalyticsError):
    """Exception raised when invalid parameters are provided to Neptune Analytics."""
    pass
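A minimal sketch (not taken from the suppressed adapter diff) of how these exception classes might be raised around a query call; the run_query helper and the error-string checks are assumptions for illustration only:

from cognee.infrastructure.databases.graph.neptune_driver.exceptions import (
    NeptuneAnalyticsQueryError,
    NeptuneAnalyticsThrottlingError,
)

def run_query(client, query_string: str, params: dict):
    # Hypothetical wrapper: translate raw client errors into the driver's exception types.
    try:
        return client.query(query_string, params)
    except Exception as error:
        if "ThrottlingException" in str(error):
            raise NeptuneAnalyticsThrottlingError(str(error)) from error
        raise NeptuneAnalyticsQueryError(str(error)) from error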
@@ -0,0 +1,221 @@
"""Neptune Utilities

This module provides utility functions for Neptune Analytics operations including
connection management, URL parsing, and Neptune-specific configurations.
"""

import re
from typing import Optional, Dict, Any, Tuple
from urllib.parse import urlparse

from cognee.shared.logging_utils import get_logger

logger = get_logger("NeptuneUtils")


def parse_neptune_url(url: str) -> Tuple[str, str]:
    """
    Parse a Neptune Analytics URL to extract graph ID and region.

    Expected format: neptune-graph://<GRAPH_ID>?region=<REGION>
    or neptune-graph://<GRAPH_ID> (defaults to us-east-1)

    Parameters:
    -----------
    - url (str): The Neptune Analytics URL to parse

    Returns:
    --------
    - Tuple[str, str]: A tuple containing (graph_id, region)

    Raises:
    -------
    - ValueError: If the URL format is invalid
    """
    try:
        parsed = urlparse(url)

        if parsed.scheme != "neptune-graph":
            raise ValueError(f"Invalid scheme: {parsed.scheme}. Expected 'neptune-graph'")

        graph_id = parsed.hostname or parsed.path.lstrip('/')
        if not graph_id:
            raise ValueError("Graph ID not found in URL")

        # Extract region from query parameters
        region = "us-east-1"  # default region
        if parsed.query:
            query_params = dict(param.split('=', 1) for param in parsed.query.split('&') if '=' in param)
            region = query_params.get('region', region)

        return graph_id, region

    except Exception as e:
        raise ValueError(f"Failed to parse Neptune Analytics URL '{url}': {str(e)}")
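A quick usage sketch for the parser above; the graph ID is a placeholder:

from cognee.infrastructure.databases.graph.neptune_driver.neptune_utils import parse_neptune_url

graph_id, region = parse_neptune_url("neptune-graph://g-exampleid?region=us-west-2")
# -> ("g-exampleid", "us-west-2")

graph_id, region = parse_neptune_url("neptune-graph://g-exampleid")
# -> ("g-exampleid", "us-east-1")  # falls back to the default region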
def validate_graph_id(graph_id: str) -> bool:
    """
    Validate a Neptune Analytics graph ID format.

    Graph IDs should follow AWS naming conventions.

    Parameters:
    -----------
    - graph_id (str): The graph ID to validate

    Returns:
    --------
    - bool: True if the graph ID is valid, False otherwise
    """
    if not graph_id:
        return False

    # Neptune Analytics graph IDs should be alphanumeric with hyphens
    # and between 1-63 characters
    pattern = r'^[a-zA-Z0-9][a-zA-Z0-9\-]{0,62}$'
    return bool(re.match(pattern, graph_id))


def validate_aws_region(region: str) -> bool:
    """
    Validate an AWS region format.

    Parameters:
    -----------
    - region (str): The AWS region to validate

    Returns:
    --------
    - bool: True if the region format is valid, False otherwise
    """
    if not region:
        return False

    # AWS regions follow the pattern: us-east-1, eu-west-1, etc.
    pattern = r'^[a-z]{2,3}-[a-z]+-\d+$'
    return bool(re.match(pattern, region))


def build_neptune_config(
    graph_id: str,
    region: Optional[str],
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Build a configuration dictionary for Neptune Analytics connection.

    Parameters:
    -----------
    - graph_id (str): The Neptune Analytics graph identifier
    - region (Optional[str]): AWS region where the graph is located
    - aws_access_key_id (Optional[str]): AWS access key ID
    - aws_secret_access_key (Optional[str]): AWS secret access key
    - aws_session_token (Optional[str]): AWS session token for temporary credentials
    - **kwargs: Additional configuration parameters

    Returns:
    --------
    - Dict[str, Any]: Configuration dictionary for Neptune Analytics

    Raises:
    -------
    - ValueError: If required parameters are invalid
    """
    config = {
        "graph_id": graph_id,
        "service_name": "neptune-graph",
    }

    # Add AWS credentials if provided
    if region:
        config["region"] = region

    if aws_access_key_id:
        config["aws_access_key_id"] = aws_access_key_id

    if aws_secret_access_key:
        config["aws_secret_access_key"] = aws_secret_access_key

    if aws_session_token:
        config["aws_session_token"] = aws_session_token

    # Add any additional configuration
    config.update(kwargs)

    return config
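And a short sketch of how the config builder might be called; values are placeholders, and static credentials can be omitted when the default AWS credential chain is used:

from cognee.infrastructure.databases.graph.neptune_driver.neptune_utils import build_neptune_config

config = build_neptune_config(
    graph_id="g-exampleid",
    region="us-east-1",
    query_timeout=300,  # extra kwargs are merged into the returned dict
)
# {"graph_id": "g-exampleid", "service_name": "neptune-graph",
#  "region": "us-east-1", "query_timeout": 300}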
def get_neptune_endpoint_url(graph_id: str, region: str) -> str:
    """
    Construct the Neptune Analytics endpoint URL for a given graph and region.

    Parameters:
    -----------
    - graph_id (str): The Neptune Analytics graph identifier
    - region (str): AWS region where the graph is located

    Returns:
    --------
    - str: The Neptune Analytics endpoint URL
    """
    return f"https://neptune-graph.{region}.amazonaws.com/graphs/{graph_id}"


def format_neptune_error(error: Exception) -> str:
    """
    Format Neptune Analytics specific errors for better readability.

    Parameters:
    -----------
    - error (Exception): The exception to format

    Returns:
    --------
    - str: Formatted error message
    """
    error_msg = str(error)

    # Common Neptune Analytics error patterns and their user-friendly messages
    error_mappings = {
        "AccessDenied": "Access denied. Please check your AWS credentials and permissions.",
        "GraphNotFound": "Graph not found. Please verify the graph ID and region.",
        "InvalidParameter": "Invalid parameter provided. Please check your request parameters.",
        "ThrottlingException": "Request was throttled. Please retry with exponential backoff.",
        "InternalServerError": "Internal server error occurred. Please try again later.",
    }

    for error_type, friendly_msg in error_mappings.items():
        if error_type in error_msg:
            return f"{friendly_msg} Original error: {error_msg}"

    return error_msg


def get_default_query_timeout() -> int:
    """
    Get the default query timeout for Neptune Analytics operations.

    Returns:
    --------
    - int: Default timeout in seconds
    """
    return 300  # 5 minutes


def get_default_connection_config() -> Dict[str, Any]:
    """
    Get default connection configuration for Neptune Analytics.

    Returns:
    --------
    - Dict[str, Any]: Default connection configuration
    """
    return {
        "query_timeout": get_default_query_timeout(),
        "max_retries": 3,
        "retry_delay": 1.0,
        "preferred_query_language": "openCypher",
    }
@@ -0,0 +1,436 @@
"""Neptune Analytics Hybrid Adapter combining Vector and Graph functionality"""

import asyncio
import json
from typing import List, Optional, Any, Dict, Type, Tuple
from uuid import UUID

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.graph.neptune_driver.adapter import NeptuneGraphDB
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.vector.models.PayloadSchema import PayloadSchema
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

logger = get_logger("NeptuneAnalyticsAdapter")


class IndexSchema(DataPoint):
    """
    Represents a schema for an index data point containing an ID and text.

    Attributes:
    - id: A string representing the unique identifier for the data point.
    - text: A string representing the content of the data point.
    - metadata: A dictionary with default index fields for the schema, currently configured
      to include 'text'.
    """
    id: str
    text: str
    metadata: dict = {"index_fields": ["text"]}


NEPTUNE_ANALYTICS_ENDPOINT_URL = "neptune-graph://"


class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
    """
    Hybrid adapter that combines Neptune Analytics Vector and Graph functionality.

    This adapter extends NeptuneGraphDB and implements VectorDBInterface to provide
    a unified interface for working with Neptune Analytics as both a vector store
    and a graph database.
    """

    _VECTOR_NODE_LABEL = "COGNEE_NODE"
    _COLLECTION_PREFIX = "VECTOR_COLLECTION"
    _TOPK_LOWER_BOUND = 0
    _TOPK_UPPER_BOUND = 10

    def __init__(
        self,
        graph_id: str,
        embedding_engine: Optional[EmbeddingEngine] = None,
        region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
    ):
        """
        Initialize the Neptune Analytics hybrid adapter.

        Parameters:
        -----------
        - graph_id (str): The Neptune Analytics graph identifier
        - embedding_engine (Optional[EmbeddingEngine]): The embedding engine instance used to translate text to vectors.
        - region (Optional[str]): AWS region where the graph is located (default: us-east-1)
        - aws_access_key_id (Optional[str]): AWS access key ID
        - aws_secret_access_key (Optional[str]): AWS secret access key
        - aws_session_token (Optional[str]): AWS session token for temporary credentials
        """
        # Initialize the graph database functionality
        super().__init__(
            graph_id=graph_id,
            region=region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )

        # Add vector-specific attributes
        self.embedding_engine = embedding_engine
        logger.info(f"Initialized Neptune Analytics hybrid adapter for graph: \"{graph_id}\" in region: \"{self.region}\"")

    # VectorDBInterface methods implementation

    async def get_connection(self):
        """
        This method is part of the default implementation but not defined in the interface.
        No operation is performed and None is returned here,
        because the concept of a connection is not applicable in this context.
        """
        return None

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """
        Embeds the provided textual data into vector representation.

        Uses the embedding engine to convert the list of strings into a list of float vectors.

        Parameters:
        -----------
        - data (list[str]): A list of strings representing the data to be embedded.

        Returns:
        --------
        - list[list[float]]: A list of embedded vectors corresponding to the input data.
        """
        self._validate_embedding_engine()
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """
        Neptune Analytics stores vectors at the node level, so has_collection() implements
        the interface for compliance but performs no operation when called.

        Parameters:
        -----------
        - collection_name (str): The name of the collection to check for existence.

        Returns:
        --------
        - bool: Always returns True.
        """
        return True

    async def create_collection(
        self,
        collection_name: str,
        payload_schema: Optional[PayloadSchema] = None,
    ):
        """
        Neptune Analytics stores vectors at the node level, so create_collection() implements
        the interface for compliance but performs no operation when called.
        As a result, create_collection() is a no-op.

        Parameters:
        -----------
        - collection_name (str): The name of the new collection to create.
        - payload_schema (Optional[PayloadSchema]): An optional schema for the payloads
          within this collection. (default None)
        """
        pass

    async def get_collection(self, collection_name: str):
        """
        This method is part of the default implementation but not defined in the interface.
        No operation is performed here because the concept of a collection is not applicable
        in the Neptune Analytics vector store.
        """
        return None

    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """
        Insert new data points into the specified collection, by first inserting the node itself
        into the graph, then executing neptune.algo.vectors.upsert() to insert the corresponding
        embedding.

        Parameters:
        -----------
        - collection_name (str): The name of the collection where data points will be added.
        - data_points (List[DataPoint]): A list of data points to be added to the collection.
        """
        self._validate_embedding_engine()

        # Fetch embeddings
        texts = [DataPoint.get_embeddable_data(t) for t in data_points]
        data_vectors = await self.embedding_engine.embed_text(texts)

        for index, data_point in enumerate(data_points):
            node_id = data_point.id
            # Fetch embedding from list instead
            data_vector = data_vectors[index]

            # Fetch properties
            properties = self._serialize_properties(data_point.model_dump())
            properties[self._COLLECTION_PREFIX] = collection_name
            params = dict(
                node_id=str(node_id),
                properties=properties,
                embedding=data_vector,
                collection_name=collection_name,
            )

            # Compose the query and send
            query_string = (
                f"MERGE (n "
                f":{self._VECTOR_NODE_LABEL} "
                f" {{`~id`: $node_id}}) "
                f"ON CREATE SET n = $properties, n.updated_at = timestamp() "
                f"ON MATCH SET n += $properties, n.updated_at = timestamp() "
                f"WITH n, $embedding AS embedding "
                f"CALL neptune.algo.vectors.upsert(n, embedding) "
                f"YIELD success "
                f"RETURN success ")

            try:
                self._client.query(query_string, params)
            except Exception as e:
                self._na_exception_handler(e, query_string)
            pass

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        """
        Retrieve data points from a collection using their IDs.

        Parameters:
        -----------
        - collection_name (str): The name of the collection from which to retrieve data points.
        - data_point_ids (list[str]): A list of IDs of the data points to retrieve.
        """
        # Do the fetch for each node
        params = dict(node_ids=data_point_ids, collection_name=collection_name)
        query_string = (f"MATCH( n :{self._VECTOR_NODE_LABEL}) "
                        f"WHERE id(n) in $node_ids AND "
                        f"n.{self._COLLECTION_PREFIX} = $collection_name "
                        f"RETURN n as payload ")

        try:
            result = self._client.query(query_string, params)
            return [self._get_scored_result(item) for item in result]
        except Exception as e:
            self._na_exception_handler(e, query_string)

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = None,
        with_vector: bool = False,
    ):
        """
        Perform a search in the specified collection using either a text query or a vector
        query.

        Parameters:
        -----------
        - collection_name (str): The name of the collection in which to perform the search.
        - query_text (Optional[str]): An optional text query to search for in the collection.
        - query_vector (Optional[List[float]]): An optional vector representation for
          searching the collection.
        - limit (int): The maximum number of results to return from the search.
        - with_vector (bool): Whether to return the vector representations with search
          results; this triggers an additional neptune.algo.vectors.get() lookup and can be
          resource-intensive.

        Returns:
        --------
        A list of scored results that match the query.
        """
        self._validate_embedding_engine()

        if with_vector:
            logger.warning(
                "with_vector=True will include embedding vectors in the result. "
                "This may trigger a resource-intensive query and increase response time. "
                "Use this option only when vector data is required."
            )

        # In the case of an excessive limit, or a zero / negative value, limit is set to 10.
        if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
            logger.warning(
                "Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
                "Defaulting to limit=10.", limit
            )
            limit = self._TOPK_UPPER_BOUND

        if query_vector and query_text:
            raise InvalidValueError(
                message="The search function accepts either text or embedding as input, but not both."
            )
        elif query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")
        elif query_vector:
            embedding = query_vector
        else:
            data_vectors = await self.embedding_engine.embed_text([query_text])
            embedding = data_vectors[0]

        # Compose the parameters map
        params = dict(embedding=embedding, param_topk=limit)
        # Compose the query
        query_string = f"""
            CALL neptune.algo.vectors.topKByEmbeddingWithFiltering({{
                topK: {limit},
                embedding: {embedding},
                nodeFilter: {{ equals: {{property: '{self._COLLECTION_PREFIX}', value: '{collection_name}'}} }}
            }}
            )
            YIELD node, score
        """

        if with_vector:
            query_string += """
                WITH node, score, id(node) as node_id
                MATCH (n)
                WHERE id(n) = id(node)
                CALL neptune.algo.vectors.get(n)
                YIELD embedding
                RETURN node as payload, score, embedding
            """
        else:
            query_string += """
                RETURN node as payload, score
            """

        try:
            query_response = self._client.query(query_string, params)
            return [self._get_scored_result(
                item=item, with_score=True
            ) for item in query_response]
        except Exception as e:
            self._na_exception_handler(e, query_string)
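A hedged usage sketch for search(); the adapter construction, collection name, and query text are placeholders, it must run inside an async function, and exactly one of query_text / query_vector may be supplied:

# assumes an EmbeddingEngine instance named embedding_engine and a placeholder graph ID
adapter = NeptuneAnalyticsAdapter(graph_id="g-exampleid", embedding_engine=embedding_engine)

# Text query: the adapter embeds the text, then runs topKByEmbeddingWithFiltering.
results = await adapter.search(
    collection_name="Entity_name",
    query_text="graph database",
    limit=5,            # values outside (0, 10] fall back to 10
    with_vector=False,  # set True only when the stored embeddings are needed
)
for scored in results:
    print(scored.id, scored.score)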
    async def batch_search(
        self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
    ):
        """
        Perform a batch search using multiple text queries against a collection.

        Parameters:
        -----------
        - collection_name (str): The name of the collection to conduct the batch search in.
        - query_texts (List[str]): A list of text queries to use for the search.
        - limit (int): The maximum number of results to return for each query.
        - with_vectors (bool): Whether to include vector representations with search
          results. (default False)

        Returns:
        --------
        A list of search result sets, one for each query input.
        """
        self._validate_embedding_engine()

        # Convert text to embedding array in batch
        data_vectors = await self.embedding_engine.embed_text(query_texts)
        return await asyncio.gather(*[
            self.search(collection_name, None, vector, limit, with_vectors)
            for vector in data_vectors
        ])

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """
        Delete specified data points from a collection, by executing an openCypher query
        with a matching [vector_label, collection_label, node_id] combination.

        Parameters:
        -----------
        - collection_name (str): The name of the collection from which to delete data points.
        - data_point_ids (list[str]): A list of IDs of the data points to delete.
        """
        params = dict(node_ids=data_point_ids, collection_name=collection_name)
        query_string = (f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
                        f"WHERE id(n) IN $node_ids "
                        f"AND n.{self._COLLECTION_PREFIX} = $collection_name "
                        f"DETACH DELETE n")
        try:
            self._client.query(query_string, params)
        except Exception as e:
            self._na_exception_handler(e, query_string)
        pass

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """
        Neptune Analytics stores vectors at the node level,
        so create_vector_index() implements the interface for compliance but performs no operation when called.
        As a result, create_vector_index() invokes create_collection(), which is also a no-op.
        This ensures the logic flow remains consistent, even if the concept of collections is introduced in a future release.
        """
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        """
        Indexes a list of data points into Neptune Analytics by creating them as nodes.

        This method constructs a unique collection name by combining the `index_name` and
        `index_property_name`, then delegates to `create_data_points()` to store the data.

        Args:
            index_name (str): The base name of the index.
            index_property_name (str): The property name to append to the index name for uniqueness.
            data_points (list[DataPoint]): A list of `DataPoint` instances to be indexed.

        Returns:
            None
        """
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=str(data_point.id),
                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                )
                for data_point in data_points
            ],
        )

    async def prune(self):
        """
        Remove obsolete or unnecessary data from the database.
        """
        # Run actual truncate
        self._client.query(f"MATCH (n :{self._VECTOR_NODE_LABEL}) "
                           f"DETACH DELETE n")
        pass

    @staticmethod
    def _get_scored_result(item: dict, with_vector: bool = False, with_score: bool = False) -> ScoredResult:
        """
        Utility method to simplify creation of a ScoredResult based on the incoming
        Neptune Analytics payload response.
        """
        return ScoredResult(
            id=item.get('payload').get('~id'),
            payload=item.get('payload').get('~properties'),
            score=item.get('score') if with_score else 0,
            vector=item.get('embedding') if with_vector else None
        )

    def _na_exception_handler(self, ex, query_string: str):
        """
        Generic exception handler for the Neptune Analytics langchain client.
        """
        logger.error(
            "Neptune Analytics query failed: %s | Query: [%s]", ex, query_string
        )
        raise ex

    def _validate_embedding_engine(self):
        """
        Validates that the embedding_engine is defined.

        :raises ValueError: if this object does not have a valid embedding_engine
        """
        if self.embedding_engine is None:
            raise ValueError("Neptune Analytics requires an embedding engine to be defined for vector operations")
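The tests later in this commit construct the adapter directly; a condensed sketch of that hybrid usage (graph ID is a placeholder, and the calls must run inside an async function):

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
    NeptuneAnalyticsAdapter,
)

embedding_engine = get_embedding_engine()
hybrid = NeptuneAnalyticsAdapter("g-exampleid", embedding_engine)

# The same object serves both interfaces:
await hybrid.add_nodes(nodes)                             # graph side (NeptuneGraphDB)
await hybrid.create_data_points("my_collection", nodes)   # vector side (VectorDBInterface)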
@@ -114,6 +114,28 @@ def create_vector_engine(
            embedding_engine=embedding_engine,
        )

    elif vector_db_provider == "neptune_analytics":
        try:
            from langchain_aws import NeptuneAnalyticsGraph
        except ImportError:
            raise ImportError(
                "langchain_aws is not installed. Please install it with 'pip install langchain_aws'"
            )

        if not vector_db_url:
            raise EnvironmentError("Missing Neptune endpoint.")

        from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, NEPTUNE_ANALYTICS_ENDPOINT_URL

        if not vector_db_url.startswith(NEPTUNE_ANALYTICS_ENDPOINT_URL):
            raise ValueError(f"Neptune endpoint must have the format '{NEPTUNE_ANALYTICS_ENDPOINT_URL}<GRAPH_ID>'")

        graph_identifier = vector_db_url.replace(NEPTUNE_ANALYTICS_ENDPOINT_URL, "")

        return NeptuneAnalyticsAdapter(
            graph_id=graph_identifier,
            embedding_engine=embedding_engine,
        )

    else:
        from .lancedb.LanceDBAdapter import LanceDBAdapter
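As the vector test further down does, the same graph can be selected as the vector backend through config; a minimal sketch with a placeholder graph ID:

import cognee
from cognee.infrastructure.databases.vector import get_vector_engine

cognee.config.set_vector_db_provider("neptune_analytics")
cognee.config.set_vector_db_url("neptune-graph://g-exampleid")
engine = get_vector_engine()  # resolves to a NeptuneAnalyticsAdapter via the branch above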
cognee/tests/test_neptune_analytics_graph.py: 313 lines (new file)
@@ -0,0 +1,313 @@
import os
from dotenv import load_dotenv
import asyncio
from cognee.infrastructure.databases.graph.neptune_driver import NeptuneGraphDB
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument

# Set up Amazon credentials in .env file and get the values from environment variables
load_dotenv()
graph_id = os.getenv('GRAPH_ID', "")

na_adapter = NeptuneGraphDB(graph_id)


def setup():
    # Define nodes data before the main function
    # These nodes were defined using OpenAI from the following prompt:

    # Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads
    # that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It
    # complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load
    # the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's
    # stored in Amazon S3.

    document = TextDocument(
        name='text_test.txt',
        raw_data_location='git/cognee/examples/database_examples/data_storage/data/text_test.txt',
        external_metadata='{}',
        mime_type='text/plain'
    )
    document_chunk = DocumentChunk(
        text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
        chunk_size=187,
        chunk_index=0,
        cut_type='paragraph_end',
        is_part_of=document,
    )

    graph_database = EntityType(name='graph database', description='graph database')
    neptune_analytics_entity = Entity(
        name='neptune analytics',
        description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
    )
    neptune_database_entity = Entity(
        name='amazon neptune database',
        description='A popular managed graph database that complements Neptune Analytics.',
    )

    storage = EntityType(name='storage', description='storage')
    storage_entity = Entity(
        name='amazon s3',
        description='A storage service provided by Amazon Web Services that allows storing graph data.',
    )

    nodes_data = [
        document,
        document_chunk,
        graph_database,
        neptune_analytics_entity,
        neptune_database_entity,
        storage,
        storage_entity,
    ]

    edges_data = [
        (
            str(document_chunk.id),
            str(storage_entity.id),
            'contains',
        ),
        (
            str(storage_entity.id),
            str(storage.id),
            'is_a',
        ),
        (
            str(document_chunk.id),
            str(neptune_database_entity.id),
            'contains',
        ),
        (
            str(neptune_database_entity.id),
            str(graph_database.id),
            'is_a',
        ),
        (
            str(document_chunk.id),
            str(document.id),
            'is_part_of',
        ),
        (
            str(document_chunk.id),
            str(neptune_analytics_entity.id),
            'contains',
        ),
        (
            str(neptune_analytics_entity.id),
            str(graph_database.id),
            'is_a',
        ),
    ]

    return nodes_data, edges_data


async def pipeline_method():
    """
    Example script using Neptune Analytics with small sample data.

    This example demonstrates how to add nodes to Neptune Analytics.
    """

    print("------TRUNCATE GRAPH-------")
    await na_adapter.delete_graph()

    print("------SETUP DATA-------")
    nodes, edges = setup()

    print("------ADD NODES-------")
    await na_adapter.add_node(nodes[0])
    await na_adapter.add_nodes(nodes[1:])

    print("------GET NODES FROM DATA-------")
    node_ids = [str(node.id) for node in nodes]
    db_nodes = await na_adapter.get_nodes(node_ids)

    print("------RESULTS:-------")
    for n in db_nodes:
        print(n)

    print("------ADD EDGES-------")
    await na_adapter.add_edge(edges[0][0], edges[0][1], edges[0][2])
    await na_adapter.add_edges(edges[1:])

    print("------HAS EDGES-------")
    has_edge = await na_adapter.has_edge(
        edges[0][0],
        edges[0][1],
        edges[0][2],
    )
    if has_edge:
        print(f"found edge ({edges[0][0]})-[{edges[0][2]}]->({edges[0][1]})")

    has_edges = await na_adapter.has_edges(edges)
    if len(has_edges) > 0:
        print(f"found edges: {len(has_edges)} (expected: {len(edges)})")
    else:
        print(f"no edges found (expected: {len(edges)})")

    print("------GET GRAPH-------")
    all_nodes, all_edges = await na_adapter.get_graph_data()
    print(f"found {len(all_nodes)} nodes and found {len(all_edges)} edges")

    print("------NEIGHBORING NODES-------")
    center_node = nodes[2]
    neighbors = await na_adapter.get_neighbors(str(center_node.id))
    print(f"found {len(neighbors)} neighbors for node \"{center_node.name}\"")
    for neighbor in neighbors:
        print(neighbor)

    print("------NEIGHBORING EDGES-------")
    center_node = nodes[2]
    neighbouring_edges = await na_adapter.get_edges(str(center_node.id))
    print(f"found {len(neighbouring_edges)} edges neighbouring node \"{center_node.name}\"")
    for edge in neighbouring_edges:
        print(edge)

    print("------GET CONNECTIONS (SOURCE NODE)-------")
    document_chunk_node = nodes[0]
    connections = await na_adapter.get_connections(str(document_chunk_node.id))
    print(f"found {len(connections)} connections for node \"{document_chunk_node.type}\"")
    for connection in connections:
        src, relationship, tgt = connection
        src = src.get("name", src.get("type", "unknown"))
        relationship = relationship["relationship_name"]
        tgt = tgt.get("name", tgt.get("type", "unknown"))
        print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")

    print("------GET CONNECTIONS (TARGET NODE)-------")
    connections = await na_adapter.get_connections(str(center_node.id))
    print(f"found {len(connections)} connections for node \"{center_node.name}\"")
    for connection in connections:
        src, relationship, tgt = connection
        src = src.get("name", src.get("type", "unknown"))
        relationship = relationship["relationship_name"]
        tgt = tgt.get("name", tgt.get("type", "unknown"))
        print(f"\"{src}\"-[{relationship}]->\"{tgt}\"")

    print("------SUBGRAPH-------")
    node_names = ["neptune analytics", "amazon neptune database"]
    subgraph_nodes, subgraph_edges = await na_adapter.get_nodeset_subgraph(Entity, node_names)
    print(f"found {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges in the subgraph around {node_names}")
    for subgraph_node in subgraph_nodes:
        print(subgraph_node)
    for subgraph_edge in subgraph_edges:
        print(subgraph_edge)

    print("------STAT-------")
    stat = await na_adapter.get_graph_metrics(include_optional=True)
    assert type(stat) is dict
    assert stat['num_nodes'] == 7
    assert stat['num_edges'] == 7
    assert stat['mean_degree'] == 2.0
    assert round(stat['edge_density'], 3) == 0.167
    assert stat['num_connected_components'] == [7]
    assert stat['sizes_of_connected_components'] == 1
    assert stat['num_selfloops'] == 0
    # Unsupported optional metrics
    assert stat['diameter'] == -1
    assert stat['avg_shortest_path_length'] == -1
    assert stat['avg_clustering'] == -1

    print("------DELETE-------")
    # delete all nodes and edges:
    await na_adapter.delete_graph()

    # delete all nodes by node id
    # node_ids = [str(node.id) for node in nodes]
    # await na_adapter.delete_nodes(node_ids)

    has_edges = await na_adapter.has_edges(edges)
    if len(has_edges) == 0:
        print("Delete successful")
    else:
        print("Delete failed")


async def misc_methods():
    print("------TRUNCATE GRAPH-------")
    await na_adapter.delete_graph()

    print("------SETUP TEST ENV-------")
    nodes, edges = setup()
    await na_adapter.add_nodes(nodes)
    await na_adapter.add_edges(edges)

    print("------GET GRAPH-------")
    all_nodes, all_edges = await na_adapter.get_graph_data()
    print(f"found {len(all_nodes)} nodes and found {len(all_edges)} edges")

    print("------GET DISCONNECTED-------")
    nodes_disconnected = await na_adapter.get_disconnected_nodes()
    print(nodes_disconnected)
    assert len(nodes_disconnected) == 0

    print("------Get Labels (Node)-------")
    node_labels = await na_adapter.get_node_labels_string()
    print(node_labels)

    print("------Get Labels (Edge)-------")
    edge_labels = await na_adapter.get_relationship_labels_string()
    print(edge_labels)

    print("------Get Filtered Graph-------")
    filtered_nodes, filtered_edges = await na_adapter.get_filtered_graph_data([{'name': ['text_test.txt']}])
    print(filtered_nodes, filtered_edges)

    print("------Get Degree one nodes-------")
    degree_one_nodes = await na_adapter.get_degree_one_nodes("EntityType")
    print(degree_one_nodes)

    print("------Get Doc sub-graph-------")
    doc_sub_graph = await na_adapter.get_document_subgraph('test.txt')
    print(doc_sub_graph)

    print("------Fetch and Remove connections (Predecessors)-------")
    # Fetch test edge
    (src_id, dest_id, relationship) = edges[0]
    nodes_predecessors = await na_adapter.get_predecessors(
        node_id=dest_id, edge_label=relationship
    )
    assert len(nodes_predecessors) > 0

    await na_adapter.remove_connection_to_predecessors_of(
        node_ids=[src_id], edge_label=relationship
    )
    nodes_predecessors_after = await na_adapter.get_predecessors(
        node_id=dest_id, edge_label=relationship
    )
    # Returns empty after the relationship has been deleted.
    assert len(nodes_predecessors_after) == 0

    print("------Fetch and Remove connections (Successors)-------")
    _, edges_suc = await na_adapter.get_graph_data()
    (src_id, dest_id, relationship, _) = edges_suc[0]

    nodes_successors = await na_adapter.get_successors(
        node_id=src_id, edge_label=relationship
    )
    assert len(nodes_successors) > 0

    await na_adapter.remove_connection_to_successors_of(
        node_ids=[dest_id], edge_label=relationship
    )
    nodes_successors_after = await na_adapter.get_successors(
        node_id=src_id, edge_label=relationship
    )
    assert len(nodes_successors_after) == 0

    # no-op
    await na_adapter.project_entire_graph()
    await na_adapter.drop_graph()
    await na_adapter.graph_exists()

    pass


if __name__ == "__main__":
    asyncio.run(pipeline_method())
    asyncio.run(misc_methods())
cognee/tests/test_neptune_analytics_hybrid.py: 169 lines (new file)
@@ -0,0 +1,169 @@
import os
from dotenv import load_dotenv
import asyncio
import pytest

from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter

# Set up Amazon credentials in .env file and get the values from environment variables
load_dotenv()
graph_id = os.getenv('GRAPH_ID', "")

# get the default embedder
embedding_engine = get_embedding_engine()
na_graph = NeptuneAnalyticsAdapter(graph_id)
na_vector = NeptuneAnalyticsAdapter(graph_id, embedding_engine)

collection = "test_collection"

logger = get_logger("test_neptune_analytics_hybrid")


def setup_data():
    # Define nodes data before the main function
    # These nodes were defined using OpenAI from the following prompt:
    #
    # Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads
    # that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It
    # complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load
    # the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's
    # stored in Amazon S3.

    document = TextDocument(
        name='text.txt',
        raw_data_location='git/cognee/examples/database_examples/data_storage/data/text.txt',
        external_metadata='{}',
        mime_type='text/plain'
    )
    document_chunk = DocumentChunk(
        text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
        chunk_size=187,
        chunk_index=0,
        cut_type='paragraph_end',
        is_part_of=document,
    )

    graph_database = EntityType(name='graph database', description='graph database')
    neptune_analytics_entity = Entity(
        name='neptune analytics',
        description='A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.',
    )
    neptune_database_entity = Entity(
        name='amazon neptune database',
        description='A popular managed graph database that complements Neptune Analytics.',
    )

    storage = EntityType(name='storage', description='storage')
    storage_entity = Entity(
        name='amazon s3',
        description='A storage service provided by Amazon Web Services that allows storing graph data.',
    )

    nodes_data = [
        document,
        document_chunk,
        graph_database,
        neptune_analytics_entity,
        neptune_database_entity,
        storage,
        storage_entity,
    ]

    edges_data = [
        (
            str(document_chunk.id),
            str(storage_entity.id),
            'contains',
        ),
        (
            str(storage_entity.id),
            str(storage.id),
            'is_a',
        ),
        (
            str(document_chunk.id),
            str(neptune_database_entity.id),
            'contains',
        ),
        (
            str(neptune_database_entity.id),
            str(graph_database.id),
            'is_a',
        ),
        (
            str(document_chunk.id),
            str(document.id),
            'is_part_of',
        ),
        (
            str(document_chunk.id),
            str(neptune_analytics_entity.id),
            'contains',
        ),
        (
            str(neptune_analytics_entity.id),
            str(graph_database.id),
            'is_a',
        ),
    ]
    return nodes_data, edges_data


async def test_add_graph_then_vector_data():
    logger.info("------test_add_graph_then_vector_data-------")
    (nodes, edges) = setup_data()
    await na_graph.add_nodes(nodes)
    await na_graph.add_edges(edges)
    await na_vector.create_data_points(collection, nodes)

    node_ids = [str(node.id) for node in nodes]
    retrieved_data_points = await na_vector.retrieve(collection, node_ids)
    retrieved_nodes = await na_graph.get_nodes(node_ids)

    assert len(retrieved_data_points) == len(retrieved_nodes) == len(node_ids)

    # delete all nodes and edges and vectors:
    await na_graph.delete_graph()
    await na_vector.prune()

    (nodes, edges) = await na_graph.get_graph_data()
    assert len(nodes) == 0
    assert len(edges) == 0
    logger.info("------PASSED-------")


async def test_add_vector_then_node_data():
    logger.info("------test_add_vector_then_node_data-------")
    (nodes, edges) = setup_data()
    await na_vector.create_data_points(collection, nodes)
    await na_graph.add_nodes(nodes)
    await na_graph.add_edges(edges)

    node_ids = [str(node.id) for node in nodes]
    retrieved_data_points = await na_vector.retrieve(collection, node_ids)
    retrieved_nodes = await na_graph.get_nodes(node_ids)

    assert len(retrieved_data_points) == len(retrieved_nodes) == len(node_ids)

    # delete all nodes and edges and vectors:
    await na_vector.prune()
    await na_graph.delete_graph()

    (nodes, edges) = await na_graph.get_graph_data()
    assert len(nodes) == 0
    assert len(edges) == 0
    logger.info("------PASSED-------")


def main():
    """
    Example script that uses Neptune Analytics as the graph and vector (hybrid) store with small sample data.
    This example demonstrates how to add nodes and vectors to Neptune Analytics, and ensures that
    the nodes do not conflict.
    """
    asyncio.run(test_add_graph_then_vector_data())
    asyncio.run(test_add_vector_then_node_data())


if __name__ == "__main__":
    main()
cognee/tests/test_neptune_analytics_vector.py: 172 lines (new file)
@@ -0,0 +1,172 @@
import os
import pathlib
import cognee
import uuid
import pytest
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import NeptuneAnalyticsAdapter, IndexSchema

logger = get_logger()


async def main():
    graph_id = os.getenv('GRAPH_ID', "")
    cognee.config.set_vector_db_provider("neptune_analytics")
    cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_neptune")
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_neptune")
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    dataset_name = "cs_explanations"

    explanation_file_path = os.path.join(
        pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
    )
    await cognee.add([explanation_file_path], dataset_name)

    text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""

    await cognee.add([text], dataset_name)

    await cognee.cognify([dataset_name])

    from cognee.infrastructure.databases.vector import get_vector_engine

    vector_engine = get_vector_engine()
    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
    random_node_name = random_node.payload["text"]

    search_results = await cognee.search(
        query_type=SearchType.INSIGHTS, query_text=random_node_name
    )
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted sentences are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted chunks are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(
        query_type=SearchType.SUMMARIES, query_text=random_node_name
    )
    assert len(search_results) != 0, "Query related summaries don't exist."
    print("\nExtracted summaries are:\n")
    for result in search_results:
        print(f"{result}\n")

    user = await get_default_user()
    history = await get_history(user.id)
    assert len(history) == 6, "Search history is not correct."

    await cognee.prune.prune_data()
    assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

    await cognee.prune.prune_system(metadata=True)


async def vector_backend_api_test():
    cognee.config.set_vector_db_provider("neptune_analytics")

    # When URL is absent
    cognee.config.set_vector_db_url(None)
    with pytest.raises(OSError):
        get_vector_engine()

    # Assert invalid graph ID.
    cognee.config.set_vector_db_url("invalid_url")
    with pytest.raises(ValueError):
        get_vector_engine()

    # Return a valid engine object with valid URL.
    graph_id = os.getenv('GRAPH_ID', "")
    cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
    engine = get_vector_engine()
    assert isinstance(engine, NeptuneAnalyticsAdapter)

    TEST_COLLECTION_NAME = "test"
    # Data point - 1
    TEST_UUID = str(uuid.uuid4())
    TEST_TEXT = "Hello world"
    datapoint = IndexSchema(id=TEST_UUID, text=TEST_TEXT)
    # Data point - 2
    TEST_UUID_2 = str(uuid.uuid4())
    TEST_TEXT_2 = "Cognee"
    datapoint_2 = IndexSchema(id=TEST_UUID_2, text=TEST_TEXT_2)

    # Prune all vector_db entries
    await engine.prune()

    # Always returns true
    has_collection = await engine.has_collection(TEST_COLLECTION_NAME)
    assert has_collection
    # No-op
    await engine.create_collection(TEST_COLLECTION_NAME, IndexSchema)

    # Save data-points
    await engine.create_data_points(TEST_COLLECTION_NAME, [datapoint, datapoint_2])
    # Search single text
    result_search = await engine.search(
        collection_name=TEST_COLLECTION_NAME,
        query_text=TEST_TEXT,
        query_vector=None,
        limit=10,
        with_vector=True)
    assert (len(result_search) == 2)

    # Retrieve data-points
    result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
    assert any(
        str(r.id) == TEST_UUID and r.payload['text'] == TEST_TEXT
        for r in result
    )
    assert any(
        str(r.id) == TEST_UUID_2 and r.payload['text'] == TEST_TEXT_2
        for r in result
    )
    # Search multiple
    result_search_batch = await engine.batch_search(
        collection_name=TEST_COLLECTION_NAME,
        query_texts=[TEST_TEXT, TEST_TEXT_2],
        limit=10,
        with_vectors=False
    )
    assert (len(result_search_batch) == 2 and
            all(len(batch) == 2 for batch in result_search_batch))

    # Delete datapoint from vector store
    await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])

    # Retrieve should return an empty list.
    result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID])
    assert result_deleted == []


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
    asyncio.run(vector_backend_api_test())
examples/database_examples/neptune_analytics_example.py: 107 lines (new file)
@@ -0,0 +1,107 @@
import base64
import json
import os
import pathlib
import asyncio
import cognee
from cognee.modules.search.types import SearchType
from dotenv import load_dotenv

load_dotenv()


async def main():
    """
    Example script demonstrating how to use Cognee with Amazon Neptune Analytics

    This example:
    1. Configures Cognee to use Neptune Analytics as graph database
    2. Sets up data directories
    3. Adds sample data to Cognee
    4. Processes/cognifies the data
    5. Performs different types of searches
    """

    # Set up Amazon credentials in .env file and get the values from environment variables
    graph_endpoint_url = "neptune-graph://" + os.getenv('GRAPH_ID', "")

    # Configure Neptune Analytics as the graph & vector database provider
    cognee.config.set_graph_db_config(
        {
            "graph_database_provider": "neptune_analytics",  # Specify Neptune Analytics as provider
            "graph_database_url": graph_endpoint_url,  # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>
        }
    )
    cognee.config.set_vector_db_config(
        {
            "vector_db_provider": "neptune_analytics",  # Specify Neptune Analytics as provider
            "vector_db_url": graph_endpoint_url,  # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>
        }
    )

    # Set up data directories for storing documents and system files
    # You should adjust these paths to your needs
    current_dir = pathlib.Path(__file__).parent
    data_directory_path = str(current_dir / "data_storage")
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = str(current_dir / "cognee_system")
    cognee.config.system_root_directory(cognee_directory_path)

    # Clean any existing data (optional)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Create a dataset
    dataset_name = "neptune_example"

    # Add sample text to the dataset
    sample_text_1 = """Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune
Analytics, you can get insights and find trends by processing large amounts of graph data in seconds. To analyze
graph data quickly and easily, Neptune Analytics stores large graph datasets in memory. It supports a library of
optimized graph analytic algorithms, low-latency graph queries, and vector search capabilities within graph
traversals.
"""

    sample_text_2 = """Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads
that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It
complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load
the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's
stored in Amazon S3.
"""

    # Add the sample text to the dataset
    await cognee.add([sample_text_1, sample_text_2], dataset_name)

    # Process the added document to extract knowledge
    await cognee.cognify([dataset_name])

    # Now let's perform some searches
    # 1. Search for insights related to "Neptune Analytics"
    insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neptune Analytics")
    print("\n========Insights about Neptune Analytics========:")
    for result in insights_results:
        print(f"- {result}")

    # 2. Search for text chunks related to "graph database"
    chunks_results = await cognee.search(
        query_type=SearchType.CHUNKS, query_text="graph database", datasets=[dataset_name]
    )
    print("\n========Chunks about graph database========:")
    for result in chunks_results:
        print(f"- {result}")

    # 3. Get graph completion related to databases
    graph_completion_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text="database"
    )
    print("\n========Graph completion for databases========:")
    for result in graph_completion_results:
        print(f"- {result}")

    # Clean up (optional)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)


if __name__ == "__main__":
    asyncio.run(main())
@@ -73,6 +73,7 @@ distributed = [

qdrant = ["qdrant-client>=1.14.2,<2"]
neo4j = ["neo4j>=5.28.0,<6"]
neptune = ["langchain_aws>=0.2.22"]
postgres = [
    "psycopg2>=2.9.10,<3",
    "pgvector>=0.3.5,<0.4",
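A short note on this packaging change: with the extra declared, langchain_aws is only pulled in when the optional neptune dependency group is requested (for example, installing with the neptune extra, such as pip install "cognee[neptune]"; the published package name is assumed here). The import guards in the factory functions earlier in this diff then raise a descriptive ImportError whenever the extra is missing.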