move aoss to driver
This commit is contained in:
parent
ce1ae30569
commit
8ba998b9bb
5 changed files with 159 additions and 127 deletions
|
|
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -21,8 +22,12 @@ from collections.abc import Coroutine
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from opensearchpy import OpenSearch, helpers
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_SIZE = 10
|
||||||
|
|
||||||
|
|
||||||
class GraphProvider(Enum):
|
class GraphProvider(Enum):
|
||||||
NEO4J = 'neo4j'
|
NEO4J = 'neo4j'
|
||||||
|
|
@ -31,6 +36,83 @@ class GraphProvider(Enum):
|
||||||
NEPTUNE = 'neptune'
|
NEPTUNE = 'neptune'
|
||||||
|
|
||||||
|
|
||||||
|
def _aoss_index(index_name: str, text_fields: list[str]) -> dict[str, Any]:
    """Build one AOSS index definition.

    The result has three keys: ``index_name``, ``body`` (the mappings used
    when the index is created) and ``query`` (a ``multi_match`` template
    whose ``query`` string is filled in at search time).

    ``uuid`` is always mapped as ``keyword``; every entry of *text_fields*
    plus ``group_id`` is mapped as ``text`` and is searched by the template.
    """
    searchable = [*text_fields, 'group_id']
    properties: dict[str, Any] = {'uuid': {'type': 'keyword'}}
    for field in searchable:
        properties[field] = {'type': 'text'}
    return {
        'index_name': index_name,
        'body': {'mappings': {'properties': properties}},
        'query': {
            'query': {'multi_match': {'query': '', 'fields': searchable}},
            'size': DEFAULT_SIZE,
        },
    }


# Index definitions shared by every driver that talks to OpenSearch (AOSS).
aoss_indices = [
    _aoss_index('node_name_and_summary', ['name', 'summary']),
    _aoss_index('community_name', ['name']),
    _aoss_index('episode_content', ['content', 'source', 'source_description']),
    _aoss_index('edge_name_and_fact', ['name', 'fact']),
]
|
||||||
|
|
||||||
|
|
||||||
class GraphDriverSession(ABC):
|
class GraphDriverSession(ABC):
|
||||||
provider: GraphProvider
|
provider: GraphProvider
|
||||||
|
|
||||||
|
|
@ -61,6 +143,7 @@ class GraphDriver(ABC):
|
||||||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||||
)
|
)
|
||||||
_database: str
|
_database: str
|
||||||
|
aoss_client: OpenSearch | None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||||
|
|
@ -87,3 +170,49 @@ class GraphDriver(ABC):
|
||||||
cloned._database = database
|
cloned._database = database
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
|
|
||||||
|
async def delete_all_indexes_impl(self) -> None:
    """Delete every AOSS index managed by this driver.

    The inner coroutine is awaited here. Returning it un-awaited from an
    ``async def`` would hand the caller a second coroutine object that is
    never run, so the indices would silently stay in place.
    """
    await self.delete_aoss_indices()
|
||||||
|
|
||||||
|
async def create_aoss_indices(self):
    """Create every index listed in ``aoss_indices`` that does not exist yet.

    Does nothing when no OpenSearch client is configured, since
    ``aoss_client`` is declared as ``OpenSearch | None``. After creating at
    least one index, waits once to give the service time to finish.
    """
    client = self.aoss_client
    if client is None:
        return

    created_any = False
    for index in aoss_indices:
        index_name = index['index_name']
        if not client.indices.exists(index=index_name):
            client.indices.create(index=index_name, body=index['body'])
            created_any = True

    if created_any:
        # Sleep for 1 minute to let the index creation complete. Skipped
        # when every index already existed, because there is nothing to
        # wait for in that case.
        await asyncio.sleep(60)
|
||||||
|
|
||||||
|
async def delete_aoss_indices(self):
    """Delete every index listed in ``aoss_indices`` that currently exists.

    Does nothing when no OpenSearch client is configured, since
    ``aoss_client`` is declared as ``OpenSearch | None``.
    """
    client = self.aoss_client
    if client is None:
        return

    for index in aoss_indices:
        index_name = index['index_name']
        # Check first so a missing index is skipped instead of raising.
        if client.indices.exists(index=index_name):
            client.indices.delete(index=index_name)
|
||||||
|
|
||||||
|
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
    """Run the ``multi_match`` query template of the index called *name*.

    Args:
        name: Index name, matched case-insensitively against ``aoss_indices``.
        query_text: Text to search for.
        limit: Maximum number of hits to request (sent as ``size``).

    Returns:
        The response of ``aoss_client.search``, or ``{}`` when *name* does
        not match any known index.
    """
    wanted = name.lower()
    for index in aoss_indices:
        if wanted != index['index_name']:
            continue

        template = index['query']
        # Build a fresh body instead of editing the module-level template:
        # the template is shared by every driver instance and every call,
        # so writing the query text into it would leak state between calls.
        multi_match = {**template['query']['multi_match'], 'query': query_text}
        body = {
            **template,
            'query': {'multi_match': multi_match},
            # Honor the caller's limit; previously the template's default
            # size was always sent and ``limit`` had no effect.
            'size': limit,
        }
        return self.aoss_client.search(body=body, index=index['index_name'])

    return {}
|
||||||
|
|
||||||
|
def save_to_aoss(self, name: str, data: list[dict]) -> int:
    """Bulk-index *data* into the AOSS index called *name*.

    Only the properties declared in the index mapping are copied from each
    input dict; a dict missing one of them raises ``KeyError``.

    Args:
        name: Index name, matched case-insensitively against ``aoss_indices``.
        data: Documents to index.

    Returns:
        The number of documents indexed successfully, or 0 when *name* does
        not match any known index.
    """
    wanted = name.lower()
    for index in aoss_indices:
        if wanted != index['index_name']:
            continue

        index_name = index['index_name']
        properties = index['body']['mappings']['properties']
        # Target the matched (lowercase) index name so a differently-cased
        # *name* cannot address a different index than the one matched.
        to_index = [{'_index': index_name, **{p: d[p] for p in properties}} for d in data]

        success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
        if failed > 0:
            logger.warning(
                'Failed to index %d of %d documents into %s',
                failed,
                len(to_index),
                index_name,
            )
        # Always report the success count. The previous code returned it
        # only when something failed and 0 when everything succeeded.
        return success

    return 0
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,9 @@ import logging
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
from neo4j import AsyncGraphDatabase, EagerResult
|
from neo4j import AsyncGraphDatabase, EagerResult
|
||||||
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
|
@ -29,7 +31,15 @@ logger = logging.getLogger(__name__)
|
||||||
class Neo4jDriver(GraphDriver):
|
class Neo4jDriver(GraphDriver):
|
||||||
provider = GraphProvider.NEO4J
|
provider = GraphProvider.NEO4J
|
||||||
|
|
||||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
def __init__(
|
||||||
|
self,
|
||||||
|
uri: str,
|
||||||
|
user: str | None,
|
||||||
|
password: str | None,
|
||||||
|
database: str = 'neo4j',
|
||||||
|
aoss_host: str | None = None,
|
||||||
|
aoss_port: int | None = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = AsyncGraphDatabase.driver(
|
self.client = AsyncGraphDatabase.driver(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
|
|
@ -37,6 +47,20 @@ class Neo4jDriver(GraphDriver):
|
||||||
)
|
)
|
||||||
self._database = database
|
self._database = database
|
||||||
|
|
||||||
|
self.aoss_client = None
|
||||||
|
if aoss_host and aoss_port:
|
||||||
|
session = boto3.Session()
|
||||||
|
self.aoss_client = OpenSearch(
|
||||||
|
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||||
|
http_auth=Urllib3AWSV4SignerAuth(
|
||||||
|
session.get_credentials(), session.region_name, 'aoss'
|
||||||
|
),
|
||||||
|
use_ssl=True,
|
||||||
|
verify_certs=True,
|
||||||
|
connection_class=Urllib3HttpConnection,
|
||||||
|
pool_maxsize=20,
|
||||||
|
)
|
||||||
|
|
||||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||||
# Check if database_ is provided in kwargs.
|
# Check if database_ is provided in kwargs.
|
||||||
# If not populated, set the value to retain backwards compatibility
|
# If not populated, set the value to retain backwards compatibility
|
||||||
|
|
@ -61,6 +85,8 @@ class Neo4jDriver(GraphDriver):
|
||||||
return await self.client.close()
|
return await self.client.close()
|
||||||
|
|
||||||
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
    """Return a coroutine that drops the AOSS indices (if any) and all Neo4j indexes."""
    return self._delete_all_indexes()

async def _delete_all_indexes(self) -> EagerResult:
    """Delete AOSS indices when a client is configured, then drop Neo4j indexes.

    The AOSS cleanup is awaited here. Calling the coroutine function from
    the synchronous ``delete_all_indexes`` without awaiting it would create
    a coroutine that never runs.
    """
    if self.aoss_client:
        await self.delete_aoss_indices()
    return await self.client.execute_query(
        'CALL db.indexes() YIELD name DROP INDEX name',
    )
|
||||||
|
|
|
||||||
|
|
@ -24,86 +24,9 @@ import boto3
|
||||||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
DEFAULT_SIZE = 10
|
|
||||||
|
|
||||||
aoss_indices = [
|
|
||||||
{
|
|
||||||
'index_name': 'node_name_and_summary',
|
|
||||||
'body': {
|
|
||||||
'mappings': {
|
|
||||||
'properties': {
|
|
||||||
'uuid': {'type': 'keyword'},
|
|
||||||
'name': {'type': 'text'},
|
|
||||||
'summary': {'type': 'text'},
|
|
||||||
'group_id': {'type': 'text'},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'query': {
|
|
||||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'index_name': 'community_name',
|
|
||||||
'body': {
|
|
||||||
'mappings': {
|
|
||||||
'properties': {
|
|
||||||
'uuid': {'type': 'keyword'},
|
|
||||||
'name': {'type': 'text'},
|
|
||||||
'group_id': {'type': 'text'},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'query': {
|
|
||||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'index_name': 'episode_content',
|
|
||||||
'body': {
|
|
||||||
'mappings': {
|
|
||||||
'properties': {
|
|
||||||
'uuid': {'type': 'keyword'},
|
|
||||||
'content': {'type': 'text'},
|
|
||||||
'source': {'type': 'text'},
|
|
||||||
'source_description': {'type': 'text'},
|
|
||||||
'group_id': {'type': 'text'},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'query': {
|
|
||||||
'query': {
|
|
||||||
'multi_match': {
|
|
||||||
'query': '',
|
|
||||||
'fields': ['content', 'source', 'source_description', 'group_id'],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'index_name': 'edge_name_and_fact',
|
|
||||||
'body': {
|
|
||||||
'mappings': {
|
|
||||||
'properties': {
|
|
||||||
'uuid': {'type': 'keyword'},
|
|
||||||
'name': {'type': 'text'},
|
|
||||||
'fact': {'type': 'text'},
|
|
||||||
'group_id': {'type': 'text'},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'query': {
|
|
||||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class NeptuneDriver(GraphDriver):
|
class NeptuneDriver(GraphDriver):
|
||||||
|
|
@ -223,52 +146,6 @@ class NeptuneDriver(GraphDriver):
|
||||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||||
return self.delete_all_indexes_impl()
|
return self.delete_all_indexes_impl()
|
||||||
|
|
||||||
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
|
||||||
# No matter what happens above, always return True
|
|
||||||
return self.delete_aoss_indices()
|
|
||||||
|
|
||||||
async def create_aoss_indices(self):
|
|
||||||
for index in aoss_indices:
|
|
||||||
index_name = index['index_name']
|
|
||||||
client = self.aoss_client
|
|
||||||
if not client.indices.exists(index=index_name):
|
|
||||||
client.indices.create(index=index_name, body=index['body'])
|
|
||||||
# Sleep for 1 minute to let the index creation complete
|
|
||||||
await asyncio.sleep(60)
|
|
||||||
|
|
||||||
async def delete_aoss_indices(self):
|
|
||||||
for index in aoss_indices:
|
|
||||||
index_name = index['index_name']
|
|
||||||
client = self.aoss_client
|
|
||||||
if client.indices.exists(index=index_name):
|
|
||||||
client.indices.delete(index=index_name)
|
|
||||||
|
|
||||||
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
|
||||||
for index in aoss_indices:
|
|
||||||
if name.lower() == index['index_name']:
|
|
||||||
index['query']['query']['multi_match']['query'] = query_text
|
|
||||||
query = {'size': limit, 'query': index['query']}
|
|
||||||
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
|
||||||
return resp
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
|
||||||
for index in aoss_indices:
|
|
||||||
if name.lower() == index['index_name']:
|
|
||||||
to_index = []
|
|
||||||
for d in data:
|
|
||||||
item = {'_index': name}
|
|
||||||
for p in index['body']['mappings']['properties']:
|
|
||||||
item[p] = d[p]
|
|
||||||
to_index.append(item)
|
|
||||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
|
||||||
if failed > 0:
|
|
||||||
return success
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
class NeptuneDriverSession(GraphDriverSession):
|
class NeptuneDriverSession(GraphDriverSession):
|
||||||
provider = GraphProvider.NEPTUNE
|
provider = GraphProvider.NEPTUNE
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.20.4"
|
version = "0.21.0"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.20.4"
|
version = "0.21.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue