diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index a77cc5fc..e3448682 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -81,7 +81,13 @@ class IsPresidentOf(BaseModel): async def main(use_bulk: bool = False): setup_logging() graph_driver = Neo4jDriver( - neo4j_uri, neo4j_user, neo4j_password, aoss_host=aoss_host, aoss_port=aoss_port + neo4j_uri, + neo4j_user, + neo4j_password, + aoss_host=aoss_host, + aoss_port=aoss_port, + region='us-west-2', + service='es', ) # client = Graphiti( # neo4j_uri, diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index efeadcc3..c8be811d 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -28,7 +28,13 @@ logger = logging.getLogger(__name__) try: import boto3 - from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection + from opensearchpy import ( + AWSV4SignerAuth, + OpenSearch, + RequestsHttpConnection, + Urllib3AWSV4SignerAuth, + Urllib3HttpConnection, + ) _HAS_OPENSEARCH = True except ImportError: @@ -50,6 +56,8 @@ class Neo4jDriver(GraphDriver): database: str = 'neo4j', aoss_host: str | None = None, aoss_port: int | None = None, + region: str | None = None, + service: str | None = None, ): super().__init__() self.client = AsyncGraphDatabase.driver( @@ -61,15 +69,17 @@ class Neo4jDriver(GraphDriver): self.aoss_client = None if aoss_host and aoss_port and boto3 is not None: try: - session = boto3.Session() - self.aoss_client = OpenSearch( # type: ignore - hosts=[{'host': aoss_host, 'port': aoss_port, 'scheme': 'https'}], - http_auth=Urllib3AWSV4SignerAuth( - session.get_credentials(), session.region_name, 'aoss' - ), + region = region + service = service + credentials = boto3.Session(profile_name='zep-development').get_credentials() + auth = AWSV4SignerAuth(credentials, region, service) + + self.aoss_client = OpenSearch( + hosts=[{'host': aoss_host, 'port': aoss_port}], + http_auth=auth, use_ssl=True, verify_certs=True, - connection_class=Urllib3HttpConnection, + connection_class=RequestsHttpConnection, pool_maxsize=20, ) # type: ignore except Exception as e: