move aoss to driver

This commit is contained in:
prestonrasmussen 2025-09-05 13:02:20 -04:00
parent ce1ae30569
commit 8ba998b9bb
5 changed files with 159 additions and 127 deletions

View file

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import copy
import logging
from abc import ABC, abstractmethod
@ -21,8 +22,12 @@ from collections.abc import Coroutine
from enum import Enum
from typing import Any
from opensearchpy import OpenSearch, helpers
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
class GraphProvider(Enum):
NEO4J = 'neo4j'
@ -31,6 +36,83 @@ class GraphProvider(Enum):
NEPTUNE = 'neptune'
aoss_indices = [
{
'index_name': 'node_name_and_summary',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'summary': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'community_name',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'episode_content',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'content': {'type': 'text'},
'source': {'type': 'text'},
'source_description': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {
'multi_match': {
'query': '',
'fields': ['content', 'source', 'source_description', 'group_id'],
}
},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'edge_name_and_fact',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'fact': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
]
class GraphDriverSession(ABC):
provider: GraphProvider
@ -61,6 +143,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
aoss_client: OpenSearch | None
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -87,3 +170,49 @@ class GraphDriver(ABC):
cloned._database = database
return cloned
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
# No matter what happens above, always return True
return self.delete_aoss_indices()
async def create_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body'])
# Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60)
async def delete_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if client.indices.exists(index=index_name):
client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
for index in aoss_indices:
if name.lower() == index['index_name']:
index['query']['query']['multi_match']['query'] = query_text
query = {'size': limit, 'query': index['query']}
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
return resp
return {}
def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices:
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {'_index': name}
for p in index['body']['mappings']['properties']:
item[p] = d[p]
to_index.append(item)
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
if failed > 0:
return success
else:
return 0
return 0

View file

@ -18,7 +18,9 @@ import logging
from collections.abc import Coroutine
from typing import Any
import boto3
from neo4j import AsyncGraphDatabase, EagerResult
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
@ -29,7 +31,15 @@ logger = logging.getLogger(__name__)
class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
def __init__(
self,
uri: str,
user: str | None,
password: str | None,
database: str = 'neo4j',
aoss_host: str | None = None,
aoss_port: int | None = None,
):
super().__init__()
self.client = AsyncGraphDatabase.driver(
uri=uri,
@ -37,6 +47,20 @@ class Neo4jDriver(GraphDriver):
)
self._database = database
self.aoss_client = None
if aoss_host and aoss_port:
session = boto3.Session()
self.aoss_client = OpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth(
session.get_credentials(), session.region_name, 'aoss'
),
use_ssl=True,
verify_certs=True,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs.
# If not populated, set the value to retain backwards compatibility
@ -61,6 +85,8 @@ class Neo4jDriver(GraphDriver):
return await self.client.close()
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
if self.aoss_client:
self.delete_all_indexes_impl()
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
)

View file

@ -24,86 +24,9 @@ import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
aoss_indices = [
{
'index_name': 'node_name_and_summary',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'summary': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'community_name',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'episode_content',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'content': {'type': 'text'},
'source': {'type': 'text'},
'source_description': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {
'multi_match': {
'query': '',
'fields': ['content', 'source', 'source_description', 'group_id'],
}
},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'edge_name_and_fact',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'fact': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
]
class NeptuneDriver(GraphDriver):
@ -223,52 +146,6 @@ class NeptuneDriver(GraphDriver):
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
return self.delete_all_indexes_impl()
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
# No matter what happens above, always return True
return self.delete_aoss_indices()
async def create_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body'])
# Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60)
async def delete_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if client.indices.exists(index=index_name):
client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
for index in aoss_indices:
if name.lower() == index['index_name']:
index['query']['query']['multi_match']['query'] = query_text
query = {'size': limit, 'query': index['query']}
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
return resp
return {}
def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices:
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {'_index': name}
for p in index['body']['mappings']['properties']:
item[p] = d[p]
to_index.append(item)
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
if failed > 0:
return success
else:
return 0
return 0
class NeptuneDriverSession(GraphDriverSession):
provider = GraphProvider.NEPTUNE

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.20.4"
version = "0.21.0"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.20.4"
version = "0.21.0"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },