This commit is contained in:
prestonrasmussen 2025-09-14 01:23:13 -04:00
parent e0066ff235
commit fd1c360e8c
7 changed files with 60 additions and 80 deletions

View file

@ -25,7 +25,6 @@ from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.nodes import EpisodeType from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data from graphiti_core.utils.maintenance.graph_data_operations import clear_data
@ -35,8 +34,6 @@ load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
aoss_host = os.environ.get('AOSS_HOST') or None
aoss_port = os.environ.get('AOSS_PORT') or None
def setup_logging(): def setup_logging():
@ -80,25 +77,12 @@ class IsPresidentOf(BaseModel):
async def main(use_bulk: bool = False): async def main(use_bulk: bool = False):
setup_logging() setup_logging()
graph_driver = Neo4jDriver( client = Graphiti(
neo4j_uri, neo4j_uri,
neo4j_user, neo4j_user,
neo4j_password, neo4j_password,
aoss_host=aoss_host,
aoss_port=int(aoss_port),
aws_profile_name='zep-development',
aws_region='us-west-2',
aws_service='es',
) )
# client = Graphiti(
# neo4j_uri,
# neo4j_user,
# neo4j_password,
# )
client = Graphiti(graph_driver=graph_driver)
await clear_data(client.driver) await clear_data(client.driver)
await client.driver.delete_aoss_indices()
await client.driver.create_aoss_indices()
await client.build_indices_and_constraints() await client.build_indices_and_constraints()
messages = parse_podcast_messages() messages = parse_podcast_messages()
group_id = str(uuid4()) group_id = str(uuid4())

View file

@ -274,10 +274,7 @@ class GraphDriver(ABC):
response = await client.delete_by_query( response = await client.delete_by_query(
index=index_name, index=index_name,
body={'query': {'match_all': {}}}, body={'query': {'match_all': {}}},
refresh=True, slices='auto',
conflicts='proceed',
wait_for_completion=True,
slices='auto', # improves coverage/concurrency
) )
logger.info(f"Cleared index '{index_name}': {response}") logger.info(f"Cleared index '{index_name}': {response}")
except Exception as e: except Exception as e:
@ -287,7 +284,7 @@ class GraphDriver(ABC):
async def save_to_aoss(self, name: str, data: list[dict]) -> int: async def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client client = self.aoss_client
if not client: if not client or not helpers:
logger.warning('No OpenSearch client found') logger.warning('No OpenSearch client found')
return 0 return 0

View file

@ -32,7 +32,6 @@ try:
AIOHttpConnection, AIOHttpConnection,
AsyncOpenSearch, AsyncOpenSearch,
AWSV4SignerAuth, AWSV4SignerAuth,
RequestsHttpConnection,
Urllib3AWSV4SignerAuth, Urllib3AWSV4SignerAuth,
Urllib3HttpConnection, Urllib3HttpConnection,
) )

View file

@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver):
'You must provide an AOSS endpoint to create an OpenSearch driver.' 'You must provide an AOSS endpoint to create an OpenSearch driver.'
) )
if not client.indices.exists(index=index_name): if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body']) await client.indices.create(index=index_name, body=index['body'])
alias_name = index.get('alias_name', index_name) alias_name = index.get('alias_name', index_name)
if not client.indices.exists_alias(name=alias_name, index=index_name): if not client.indices.exists_alias(name=alias_name, index=index_name):
client.indices.put_alias(index=index_name, name=alias_name) await client.indices.put_alias(index=index_name, name=alias_name)
# Sleep for 1 minute to let the index creation complete # Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60) await asyncio.sleep(60)

View file

@ -103,8 +103,9 @@ class Node(BaseModel, ABC):
case GraphProvider.NEO4J: case GraphProvider.NEO4J:
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n {uuid: $uuid})-[r]-() MATCH (n {uuid: $uuid})
WHERE n:Entity OR n:Episodic OR n:Community WHERE n:Entity OR n:Episodic OR n:Community
OPTIONAL MATCH (n)-[r]-()
WITH collect(r.uuid) AS edge_uuids, n WITH collect(r.uuid) AS edge_uuids, n
DETACH DELETE n DETACH DELETE n
RETURN edge_uuids RETURN edge_uuids

View file

@ -808,7 +808,7 @@ async def node_similarity_search(
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, params={'routing': route},
_source=['uuid'], _source=['uuid'],
size=limit, size=limit,
body={ body={

View file

@ -17,7 +17,6 @@ limitations under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
from xml.dom.minidom import Entity
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import LiteralString from typing_extensions import LiteralString