remove generic aoss_client interactions for release build (#975)
* remove generic aoss_client interactions for release build
* remove unused imports
* update
* revert changes to Neptune driver
* Update graphiti_core/driver/neptune_driver.py

  Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
* default to sync OpenSearch client
* update
* aoss_client now Any type
* update stubs

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
This commit is contained in:
parent
35857fa211
commit
5a67e660dc
5 changed files with 52 additions and 289 deletions
|
|
@ -14,29 +14,16 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Coroutine
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.embedder.client import EMBEDDING_DIM
|
||||
|
||||
try:
|
||||
from opensearchpy import AsyncOpenSearch, helpers
|
||||
|
||||
_HAS_OPENSEARCH = True
|
||||
except ImportError:
|
||||
OpenSearch = None
|
||||
helpers = None
|
||||
_HAS_OPENSEARCH = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SIZE = 10
|
||||
|
|
@ -56,91 +43,6 @@ class GraphProvider(Enum):
|
|||
NEPTUNE = 'neptune'
|
||||
|
||||
|
||||
aoss_indices = [
|
||||
{
|
||||
'index_name': ENTITY_INDEX_NAME,
|
||||
'body': {
|
||||
'settings': {'index': {'knn': True}},
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'summary': {'type': 'text'},
|
||||
'group_id': {'type': 'keyword'},
|
||||
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||
'name_embedding': {
|
||||
'type': 'knn_vector',
|
||||
'dimension': EMBEDDING_DIM,
|
||||
'method': {
|
||||
'engine': 'faiss',
|
||||
'space_type': 'cosinesimil',
|
||||
'name': 'hnsw',
|
||||
'parameters': {'ef_construction': 128, 'm': 16},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': COMMUNITY_INDEX_NAME,
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'group_id': {'type': 'keyword'},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': EPISODE_INDEX_NAME,
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'content': {'type': 'text'},
|
||||
'source': {'type': 'text'},
|
||||
'source_description': {'type': 'text'},
|
||||
'group_id': {'type': 'keyword'},
|
||||
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
'index_name': ENTITY_EDGE_INDEX_NAME,
|
||||
'body': {
|
||||
'settings': {'index': {'knn': True}},
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'uuid': {'type': 'keyword'},
|
||||
'name': {'type': 'text'},
|
||||
'fact': {'type': 'text'},
|
||||
'group_id': {'type': 'keyword'},
|
||||
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||
'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||
'fact_embedding': {
|
||||
'type': 'knn_vector',
|
||||
'dimension': EMBEDDING_DIM,
|
||||
'method': {
|
||||
'engine': 'faiss',
|
||||
'space_type': 'cosinesimil',
|
||||
'name': 'hnsw',
|
||||
'parameters': {'ef_construction': 128, 'm': 16},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class GraphDriverSession(ABC):
|
||||
provider: GraphProvider
|
||||
|
||||
|
|
@ -171,7 +73,7 @@ class GraphDriver(ABC):
|
|||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||
)
|
||||
_database: str
|
||||
aoss_client: AsyncOpenSearch | None # type: ignore
|
||||
aoss_client: Any # type: ignore
|
||||
|
||||
@abstractmethod
|
||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||
|
|
@ -199,119 +101,6 @@ class GraphDriver(ABC):
|
|||
|
||||
return cloned
|
||||
|
||||
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
||||
# No matter what happens above, always return True
|
||||
return self.delete_aoss_indices()
|
||||
|
||||
async def create_aoss_indices(self):
|
||||
client = self.aoss_client
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return
|
||||
|
||||
for index in aoss_indices:
|
||||
alias_name = index['index_name']
|
||||
|
||||
# If alias already exists, skip (idempotent behavior)
|
||||
if await client.indices.exists_alias(name=alias_name):
|
||||
continue
|
||||
|
||||
# Build a physical index name with timestamp
|
||||
ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
|
||||
physical_index_name = f'{alias_name}_{ts_suffix}'
|
||||
|
||||
# Create the index
|
||||
await client.indices.create(index=physical_index_name, body=index['body'])
|
||||
|
||||
# Point alias to it
|
||||
await client.indices.put_alias(index=physical_index_name, name=alias_name)
|
||||
|
||||
# Allow some time for index creation
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def delete_aoss_indices(self):
|
||||
client = self.aoss_client
|
||||
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return
|
||||
|
||||
for entry in aoss_indices:
|
||||
alias_name = entry['index_name']
|
||||
|
||||
try:
|
||||
# Resolve alias → indices
|
||||
alias_info = await client.indices.get_alias(name=alias_name)
|
||||
indices = list(alias_info.keys())
|
||||
|
||||
if not indices:
|
||||
logger.info(f"No indices found for alias '{alias_name}'")
|
||||
continue
|
||||
|
||||
for index in indices:
|
||||
if await client.indices.exists(index=index):
|
||||
await client.indices.delete(index=index)
|
||||
logger.info(f"Deleted index '{index}' (alias: {alias_name})")
|
||||
else:
|
||||
logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
|
||||
|
||||
async def clear_aoss_indices(self):
|
||||
client = self.aoss_client
|
||||
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return
|
||||
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
|
||||
if await client.indices.exists(index=index_name):
|
||||
try:
|
||||
# Delete all documents but keep the index
|
||||
response = await client.delete_by_query(
|
||||
index=index_name,
|
||||
body={'query': {'match_all': {}}},
|
||||
)
|
||||
logger.info(f"Cleared index '{index_name}': {response}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing index '{index_name}': {e}")
|
||||
else:
|
||||
logger.warning(f"Index '{index_name}' does not exist")
|
||||
|
||||
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
client = self.aoss_client
|
||||
if not client or not helpers:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return 0
|
||||
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
for d in data:
|
||||
doc = {}
|
||||
for p in index['body']['mappings']['properties']:
|
||||
if p in d: # protect against missing fields
|
||||
doc[p] = d[p]
|
||||
|
||||
item = {
|
||||
'_index': name,
|
||||
'_id': d['uuid'],
|
||||
'_routing': d.get('group_id'),
|
||||
'_source': doc,
|
||||
}
|
||||
to_index.append(item)
|
||||
|
||||
success, failed = await helpers.async_bulk(
|
||||
client, to_index, stats_only=True, request_timeout=60
|
||||
)
|
||||
|
||||
return success if failed == 0 else success
|
||||
|
||||
return 0
|
||||
|
||||
def build_fulltext_query(
|
||||
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
||||
) -> str:
|
||||
|
|
@ -320,3 +109,9 @@ class GraphDriver(ABC):
|
|||
Only implemented by providers that need custom fulltext query building.
|
||||
"""
|
||||
raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
|
||||
|
||||
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
return 0
|
||||
|
||||
async def clear_aoss_indices(self):
|
||||
return 1
|
||||
|
|
|
|||
|
|
@ -22,28 +22,9 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from opensearchpy import (
|
||||
AIOHttpConnection,
|
||||
AsyncOpenSearch,
|
||||
AWSV4SignerAuth,
|
||||
Urllib3AWSV4SignerAuth,
|
||||
Urllib3HttpConnection,
|
||||
)
|
||||
|
||||
_HAS_OPENSEARCH = True
|
||||
except ImportError:
|
||||
boto3 = None
|
||||
OpenSearch = None
|
||||
Urllib3AWSV4SignerAuth = None
|
||||
Urllib3HttpConnection = None
|
||||
_HAS_OPENSEARCH = False
|
||||
|
||||
|
||||
class Neo4jDriver(GraphDriver):
|
||||
provider = GraphProvider.NEO4J
|
||||
|
|
@ -54,11 +35,6 @@ class Neo4jDriver(GraphDriver):
|
|||
user: str | None,
|
||||
password: str | None,
|
||||
database: str = 'neo4j',
|
||||
aoss_host: str | None = None,
|
||||
aoss_port: int | None = None,
|
||||
aws_profile_name: str | None = None,
|
||||
aws_region: str | None = None,
|
||||
aws_service: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
|
|
@ -68,24 +44,6 @@ class Neo4jDriver(GraphDriver):
|
|||
self._database = database
|
||||
|
||||
self.aoss_client = None
|
||||
if aoss_host and aoss_port and boto3 is not None:
|
||||
try:
|
||||
region = aws_region
|
||||
service = aws_service
|
||||
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
||||
auth = AWSV4SignerAuth(credentials, region or '', service or '')
|
||||
|
||||
self.aoss_client = AsyncOpenSearch(
|
||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||
auth=auth,
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=AIOHttpConnection,
|
||||
pool_maxsize=20,
|
||||
) # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to initialize OpenSearch client: {e}')
|
||||
self.aoss_client = None
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
# Check if database_ is provided in kwargs.
|
||||
|
|
@ -111,13 +69,6 @@ class Neo4jDriver(GraphDriver):
|
|||
return await self.client.close()
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine:
|
||||
if self.aoss_client:
|
||||
return semaphore_gather(
|
||||
self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
),
|
||||
self.delete_aoss_indices(),
|
||||
)
|
||||
return self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,21 +22,16 @@ from typing import Any
|
|||
|
||||
import boto3
|
||||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
|
||||
from graphiti_core.driver.driver import (
|
||||
DEFAULT_SIZE,
|
||||
GraphDriver,
|
||||
GraphDriverSession,
|
||||
GraphProvider,
|
||||
)
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_SIZE = 10
|
||||
|
||||
neptune_aoss_indices = [
|
||||
aoss_indices = [
|
||||
{
|
||||
'index_name': 'node_name_and_summary',
|
||||
'alias_name': 'entities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -54,7 +49,6 @@ neptune_aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'community_name',
|
||||
'alias_name': 'communities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -71,7 +65,6 @@ neptune_aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'episode_content',
|
||||
'alias_name': 'episodes',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -95,7 +88,6 @@ neptune_aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'edge_name_and_fact',
|
||||
'alias_name': 'facts',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -228,27 +220,52 @@ class NeptuneDriver(GraphDriver):
|
|||
async def _delete_all_data(self) -> Any:
|
||||
return await self.execute_query('MATCH (n) DETACH DELETE n')
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
return self.delete_all_indexes_impl()
|
||||
|
||||
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
||||
# No matter what happens above, always return True
|
||||
return self.delete_aoss_indices()
|
||||
|
||||
async def create_aoss_indices(self):
|
||||
for index in neptune_aoss_indices:
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if not client:
|
||||
raise ValueError(
|
||||
'You must provide an AOSS endpoint to create an OpenSearch driver.'
|
||||
)
|
||||
if not client.indices.exists(index=index_name):
|
||||
await client.indices.create(index=index_name, body=index['body'])
|
||||
|
||||
alias_name = index.get('alias_name', index_name)
|
||||
|
||||
if not client.indices.exists_alias(name=alias_name, index=index_name):
|
||||
await client.indices.put_alias(index=index_name, name=alias_name)
|
||||
|
||||
client.indices.create(index=index_name, body=index['body'])
|
||||
# Sleep for 1 minute to let the index creation complete
|
||||
await asyncio.sleep(60)
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
return self.delete_all_indexes_impl()
|
||||
async def delete_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if client.indices.exists(index=index_name):
|
||||
client.indices.delete(index=index_name)
|
||||
|
||||
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
index['query']['query']['multi_match']['query'] = query_text
|
||||
query = {'size': limit, 'query': index['query']}
|
||||
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
||||
return resp
|
||||
return {}
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
for d in data:
|
||||
item = {'_index': name, '_id': d['uuid']}
|
||||
for p in index['body']['mappings']['properties']:
|
||||
if p in d:
|
||||
item[p] = d[p]
|
||||
to_index.append(item)
|
||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||
return success
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
class NeptuneDriverSession(GraphDriverSession):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.21.0pre13"
|
||||
version = "0.21.0"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.21.0rc13"
|
||||
version = "0.21.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue