diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 670a7426..6ece60bb 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import asyncio import copy import logging from abc import ABC, abstractmethod @@ -21,8 +22,12 @@ from collections.abc import Coroutine from enum import Enum from typing import Any +from opensearchpy import OpenSearch, helpers + logger = logging.getLogger(__name__) +DEFAULT_SIZE = 10 + class GraphProvider(Enum): NEO4J = 'neo4j' @@ -31,6 +36,83 @@ class GraphProvider(Enum): NEPTUNE = 'neptune' +aoss_indices = [ + { + 'index_name': 'node_name_and_summary', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'name': {'type': 'text'}, + 'summary': {'type': 'text'}, + 'group_id': {'type': 'text'}, + } + } + }, + 'query': { + 'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}}, + 'size': DEFAULT_SIZE, + }, + }, + { + 'index_name': 'community_name', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'name': {'type': 'text'}, + 'group_id': {'type': 'text'}, + } + } + }, + 'query': { + 'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}}, + 'size': DEFAULT_SIZE, + }, + }, + { + 'index_name': 'episode_content', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'content': {'type': 'text'}, + 'source': {'type': 'text'}, + 'source_description': {'type': 'text'}, + 'group_id': {'type': 'text'}, + } + } + }, + 'query': { + 'query': { + 'multi_match': { + 'query': '', + 'fields': ['content', 'source', 'source_description', 'group_id'], + } + }, + 'size': DEFAULT_SIZE, + }, + }, + { + 'index_name': 'edge_name_and_fact', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'name': {'type': 'text'}, + 'fact': {'type': 'text'}, + 'group_id': {'type': 'text'}, + } + } + }, + 'query': { + 'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}}, + 'size': DEFAULT_SIZE, + }, + }, +] + + class GraphDriverSession(ABC): provider: GraphProvider @@ -61,6 +143,7 @@ class GraphDriver(ABC): '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) _database: str + aoss_client: OpenSearch | None @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: @@ -87,3 +170,49 @@ class GraphDriver(ABC): cloned._database = database return cloned + + async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]: + # No matter what happens above, always return True + return self.delete_aoss_indices() + + async def create_aoss_indices(self): + for index in aoss_indices: + index_name = index['index_name'] + client = self.aoss_client + if not client.indices.exists(index=index_name): + client.indices.create(index=index_name, body=index['body']) + # Sleep for 1 minute to let the index creation complete + await asyncio.sleep(60) + + async def delete_aoss_indices(self): + for index in aoss_indices: + index_name = index['index_name'] + client = self.aoss_client + if client.indices.exists(index=index_name): + client.indices.delete(index=index_name) + + def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]: + for index in aoss_indices: + if name.lower() == index['index_name']: + index['query']['query']['multi_match']['query'] = query_text + query = {'size': limit, 'query': index['query']} + resp = self.aoss_client.search(body=query['query'], index=index['index_name']) + return resp + return {} + + def save_to_aoss(self, name: str, data: list[dict]) -> int: + for index in aoss_indices: + if name.lower() == index['index_name']: + to_index = [] + for d in data: + item = {'_index': name} + for p in index['body']['mappings']['properties']: + item[p] = d[p] + to_index.append(item) + success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True) + if failed > 0: + return success + else: + return 0 + + return 0 diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 7ac9a5a8..c59eefb3 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -18,7 +18,9 @@ import logging from collections.abc import Coroutine from typing import Any +import boto3 from neo4j import AsyncGraphDatabase, EagerResult +from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider @@ -29,7 +31,15 @@ logger = logging.getLogger(__name__) class Neo4jDriver(GraphDriver): provider = GraphProvider.NEO4J - def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'): + def __init__( + self, + uri: str, + user: str | None, + password: str | None, + database: str = 'neo4j', + aoss_host: str | None = None, + aoss_port: int | None = None, + ): super().__init__() self.client = AsyncGraphDatabase.driver( uri=uri, @@ -37,6 +47,20 @@ class Neo4jDriver(GraphDriver): ) self._database = database + self.aoss_client = None + if aoss_host and aoss_port: + session = boto3.Session() + self.aoss_client = OpenSearch( + hosts=[{'host': aoss_host, 'port': aoss_port}], + http_auth=Urllib3AWSV4SignerAuth( + session.get_credentials(), session.region_name, 'aoss' + ), + use_ssl=True, + verify_certs=True, + connection_class=Urllib3HttpConnection, + pool_maxsize=20, + ) + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: # Check if database_ is provided in kwargs. # If not populated, set the value to retain backwards compatibility @@ -61,6 +85,8 @@ class Neo4jDriver(GraphDriver): return await self.client.close() def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: + if self.aoss_client: + self.delete_all_indexes_impl() return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', ) diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index 25aa12c3..4d8dbbe3 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -24,86 +24,9 @@ import boto3 from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices logger = logging.getLogger(__name__) -DEFAULT_SIZE = 10 - -aoss_indices = [ - { - 'index_name': 'node_name_and_summary', - 'body': { - 'mappings': { - 'properties': { - 'uuid': {'type': 'keyword'}, - 'name': {'type': 'text'}, - 'summary': {'type': 'text'}, - 'group_id': {'type': 'text'}, - } - } - }, - 'query': { - 'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}}, - 'size': DEFAULT_SIZE, - }, - }, - { - 'index_name': 'community_name', - 'body': { - 'mappings': { - 'properties': { - 'uuid': {'type': 'keyword'}, - 'name': {'type': 'text'}, - 'group_id': {'type': 'text'}, - } - } - }, - 'query': { - 'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}}, - 'size': DEFAULT_SIZE, - }, - }, - { - 'index_name': 'episode_content', - 'body': { - 'mappings': { - 'properties': { - 'uuid': {'type': 'keyword'}, - 'content': {'type': 'text'}, - 'source': {'type': 'text'}, - 'source_description': {'type': 'text'}, - 'group_id': {'type': 'text'}, - } - } - }, - 'query': { - 'query': { - 'multi_match': { - 'query': '', - 'fields': ['content', 'source', 'source_description', 'group_id'], - } - }, - 'size': DEFAULT_SIZE, - }, - }, - { - 'index_name': 'edge_name_and_fact', - 'body': { - 'mappings': { - 'properties': { - 'uuid': {'type': 'keyword'}, - 'name': {'type': 'text'}, - 'fact': {'type': 'text'}, - 'group_id': {'type': 'text'}, - } - } - }, - 'query': { - 'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}}, - 'size': DEFAULT_SIZE, - }, - }, -] class NeptuneDriver(GraphDriver): @@ -223,52 +146,6 @@ class NeptuneDriver(GraphDriver): def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: return self.delete_all_indexes_impl() - async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]: - # No matter what happens above, always return True - return self.delete_aoss_indices() - - async def create_aoss_indices(self): - for index in aoss_indices: - index_name = index['index_name'] - client = self.aoss_client - if not client.indices.exists(index=index_name): - client.indices.create(index=index_name, body=index['body']) - # Sleep for 1 minute to let the index creation complete - await asyncio.sleep(60) - - async def delete_aoss_indices(self): - for index in aoss_indices: - index_name = index['index_name'] - client = self.aoss_client - if client.indices.exists(index=index_name): - client.indices.delete(index=index_name) - - def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]: - for index in aoss_indices: - if name.lower() == index['index_name']: - index['query']['query']['multi_match']['query'] = query_text - query = {'size': limit, 'query': index['query']} - resp = self.aoss_client.search(body=query['query'], index=index['index_name']) - return resp - return {} - - def save_to_aoss(self, name: str, data: list[dict]) -> int: - for index in aoss_indices: - if name.lower() == index['index_name']: - to_index = [] - for d in data: - item = {'_index': name} - for p in index['body']['mappings']['properties']: - item[p] = d[p] - to_index.append(item) - success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True) - if failed > 0: - return success - else: - return 0 - - return 0 - class NeptuneDriverSession(GraphDriverSession): provider = GraphProvider.NEPTUNE diff --git a/pyproject.toml b/pyproject.toml index 04061eec..715ecff6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.20.4" +version = "0.21.0" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index e2a3b855..6abdaaf9 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.20.4" +version = "0.21.0" source = { editable = "." } dependencies = [ { name = "diskcache" },