move aoss to driver
This commit is contained in:
parent
ce1ae30569
commit
8ba998b9bb
5 changed files with 159 additions and 127 deletions
|
|
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -21,8 +22,12 @@ from collections.abc import Coroutine
|
|||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from opensearchpy import OpenSearch, helpers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SIZE = 10
|
||||
|
||||
|
||||
class GraphProvider(Enum):
|
||||
NEO4J = 'neo4j'
|
||||
|
|
@ -31,6 +36,83 @@ class GraphProvider(Enum):
|
|||
NEPTUNE = 'neptune'
|
||||
|
||||
|
||||
aoss_indices = [
|
||||
{
|
||||
'index_name': 'node_name_and_summary',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'summary': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'community_name',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'episode_content',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'content': {'type': 'text'},
|
||||
'source': {'type': 'text'},
|
||||
'source_description': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {
|
||||
'multi_match': {
|
||||
'query': '',
|
||||
'fields': ['content', 'source', 'source_description', 'group_id'],
|
||||
}
|
||||
},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'edge_name_and_fact',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'fact': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class GraphDriverSession(ABC):
|
||||
provider: GraphProvider
|
||||
|
||||
|
|
@ -61,6 +143,7 @@ class GraphDriver(ABC):
|
|||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||
)
|
||||
_database: str
|
||||
aoss_client: OpenSearch | None
|
||||
|
||||
@abstractmethod
|
||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||
|
|
@ -87,3 +170,49 @@ class GraphDriver(ABC):
|
|||
cloned._database = database
|
||||
|
||||
return cloned
|
||||
|
||||
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
||||
# No matter what happens above, always return True
|
||||
return self.delete_aoss_indices()
|
||||
|
||||
async def create_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if not client.indices.exists(index=index_name):
|
||||
client.indices.create(index=index_name, body=index['body'])
|
||||
# Sleep for 1 minute to let the index creation complete
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def delete_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if client.indices.exists(index=index_name):
|
||||
client.indices.delete(index=index_name)
|
||||
|
||||
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
index['query']['query']['multi_match']['query'] = query_text
|
||||
query = {'size': limit, 'query': index['query']}
|
||||
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
||||
return resp
|
||||
return {}
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
for d in data:
|
||||
item = {'_index': name}
|
||||
for p in index['body']['mappings']['properties']:
|
||||
item[p] = d[p]
|
||||
to_index.append(item)
|
||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||
if failed > 0:
|
||||
return success
|
||||
else:
|
||||
return 0
|
||||
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ import logging
|
|||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from neo4j import AsyncGraphDatabase, EagerResult
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
|
|
@ -29,7 +31,15 @@ logger = logging.getLogger(__name__)
|
|||
class Neo4jDriver(GraphDriver):
|
||||
provider = GraphProvider.NEO4J
|
||||
|
||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
user: str | None,
|
||||
password: str | None,
|
||||
database: str = 'neo4j',
|
||||
aoss_host: str | None = None,
|
||||
aoss_port: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
uri=uri,
|
||||
|
|
@ -37,6 +47,20 @@ class Neo4jDriver(GraphDriver):
|
|||
)
|
||||
self._database = database
|
||||
|
||||
self.aoss_client = None
|
||||
if aoss_host and aoss_port:
|
||||
session = boto3.Session()
|
||||
self.aoss_client = OpenSearch(
|
||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||
http_auth=Urllib3AWSV4SignerAuth(
|
||||
session.get_credentials(), session.region_name, 'aoss'
|
||||
),
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=Urllib3HttpConnection,
|
||||
pool_maxsize=20,
|
||||
)
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
# Check if database_ is provided in kwargs.
|
||||
# If not populated, set the value to retain backwards compatibility
|
||||
|
|
@ -61,6 +85,8 @@ class Neo4jDriver(GraphDriver):
|
|||
return await self.client.close()
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
|
||||
if self.aoss_client:
|
||||
self.delete_all_indexes_impl()
|
||||
return self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,86 +24,9 @@ import boto3
|
|||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_SIZE = 10
|
||||
|
||||
aoss_indices = [
|
||||
{
|
||||
'index_name': 'node_name_and_summary',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'summary': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'community_name',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'episode_content',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'content': {'type': 'text'},
|
||||
'source': {'type': 'text'},
|
||||
'source_description': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {
|
||||
'multi_match': {
|
||||
'query': '',
|
||||
'fields': ['content', 'source', 'source_description', 'group_id'],
|
||||
}
|
||||
},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': 'edge_name_and_fact',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'fact': {'type': 'text'},
|
||||
'group_id': {'type': 'text'},
|
||||
}
|
||||
}
|
||||
},
|
||||
'query': {
|
||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
|
||||
'size': DEFAULT_SIZE,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class NeptuneDriver(GraphDriver):
|
||||
|
|
@ -223,52 +146,6 @@ class NeptuneDriver(GraphDriver):
|
|||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
return self.delete_all_indexes_impl()
|
||||
|
||||
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
||||
# No matter what happens above, always return True
|
||||
return self.delete_aoss_indices()
|
||||
|
||||
async def create_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if not client.indices.exists(index=index_name):
|
||||
client.indices.create(index=index_name, body=index['body'])
|
||||
# Sleep for 1 minute to let the index creation complete
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def delete_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if client.indices.exists(index=index_name):
|
||||
client.indices.delete(index=index_name)
|
||||
|
||||
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
index['query']['query']['multi_match']['query'] = query_text
|
||||
query = {'size': limit, 'query': index['query']}
|
||||
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
||||
return resp
|
||||
return {}
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
for d in data:
|
||||
item = {'_index': name}
|
||||
for p in index['body']['mappings']['properties']:
|
||||
item[p] = d[p]
|
||||
to_index.append(item)
|
||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||
if failed > 0:
|
||||
return success
|
||||
else:
|
||||
return 0
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
class NeptuneDriverSession(GraphDriverSession):
|
||||
provider = GraphProvider.NEPTUNE
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.20.4"
|
||||
version = "0.21.0"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.20.4"
|
||||
version = "0.21.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue