async update

This commit is contained in:
prestonrasmussen 2025-09-14 00:53:38 -04:00
parent 42bbd93c38
commit e0066ff235
7 changed files with 117 additions and 89 deletions

View file

@ -25,6 +25,7 @@ from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.nodes import EpisodeType from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data from graphiti_core.utils.maintenance.graph_data_operations import clear_data
@ -34,6 +35,8 @@ load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
aoss_host = os.environ.get('AOSS_HOST') or None
aoss_port = os.environ.get('AOSS_PORT') or None
def setup_logging(): def setup_logging():
@ -77,12 +80,25 @@ class IsPresidentOf(BaseModel):
async def main(use_bulk: bool = False): async def main(use_bulk: bool = False):
setup_logging() setup_logging()
client = Graphiti( graph_driver = Neo4jDriver(
neo4j_uri, neo4j_uri,
neo4j_user, neo4j_user,
neo4j_password, neo4j_password,
aoss_host=aoss_host,
aoss_port=int(aoss_port),
aws_profile_name='zep-development',
aws_region='us-west-2',
aws_service='es',
) )
# client = Graphiti(
# neo4j_uri,
# neo4j_user,
# neo4j_password,
# )
client = Graphiti(graph_driver=graph_driver)
await clear_data(client.driver) await clear_data(client.driver)
await client.driver.delete_aoss_indices()
await client.driver.create_aoss_indices()
await client.build_indices_and_constraints() await client.build_indices_and_constraints()
messages = parse_podcast_messages() messages = parse_podcast_messages()
group_id = str(uuid4()) group_id = str(uuid4())

View file

@ -24,10 +24,12 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from dotenv import load_dotenv
from graphiti_core.embedder.client import EMBEDDING_DIM from graphiti_core.embedder.client import EMBEDDING_DIM
try: try:
from opensearchpy import OpenSearch, helpers from opensearchpy import AsyncOpenSearch, helpers
_HAS_OPENSEARCH = True _HAS_OPENSEARCH = True
except ImportError: except ImportError:
@ -39,6 +41,8 @@ logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10 DEFAULT_SIZE = 10
load_dotenv()
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities') ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes') EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities') COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
@ -62,7 +66,7 @@ aoss_indices = [
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'summary': {'type': 'text'}, 'summary': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'name_embedding': { 'name_embedding': {
'type': 'knn_vector', 'type': 'knn_vector',
@ -85,7 +89,7 @@ aoss_indices = [
'properties': { 'properties': {
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
} }
} }
}, },
@ -99,7 +103,7 @@ aoss_indices = [
'content': {'type': 'text'}, 'content': {'type': 'text'},
'source': {'type': 'text'}, 'source': {'type': 'text'},
'source_description': {'type': 'text'}, 'source_description': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
} }
@ -115,7 +119,7 @@ aoss_indices = [
'uuid': {'type': 'keyword'}, 'uuid': {'type': 'keyword'},
'name': {'type': 'text'}, 'name': {'type': 'text'},
'fact': {'type': 'text'}, 'fact': {'type': 'text'},
'group_id': {'type': 'text'}, 'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
@ -167,7 +171,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries '' # Neo4j (default) syntax does not require a prefix for fulltext queries
) )
_database: str _database: str
aoss_client: OpenSearch | None # type: ignore aoss_client: AsyncOpenSearch | None # type: ignore
@abstractmethod @abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -209,7 +213,7 @@ class GraphDriver(ABC):
alias_name = index['index_name'] alias_name = index['index_name']
# If alias already exists, skip (idempotent behavior) # If alias already exists, skip (idempotent behavior)
if client.indices.exists_alias(name=alias_name): if await client.indices.exists_alias(name=alias_name):
continue continue
# Build a physical index name with timestamp # Build a physical index name with timestamp
@ -217,10 +221,10 @@ class GraphDriver(ABC):
physical_index_name = f'{alias_name}_{ts_suffix}' physical_index_name = f'{alias_name}_{ts_suffix}'
# Create the index # Create the index
client.indices.create(index=physical_index_name, body=index['body']) await client.indices.create(index=physical_index_name, body=index['body'])
# Point alias to it # Point alias to it
client.indices.put_alias(index=physical_index_name, name=alias_name) await client.indices.put_alias(index=physical_index_name, name=alias_name)
# Allow some time for index creation # Allow some time for index creation
await asyncio.sleep(1) await asyncio.sleep(1)
@ -237,7 +241,7 @@ class GraphDriver(ABC):
try: try:
# Resolve alias → indices # Resolve alias → indices
alias_info = client.indices.get_alias(name=alias_name) alias_info = await client.indices.get_alias(name=alias_name)
indices = list(alias_info.keys()) indices = list(alias_info.keys())
if not indices: if not indices:
@ -245,8 +249,8 @@ class GraphDriver(ABC):
continue continue
for index in indices: for index in indices:
if client.indices.exists(index=index): if await client.indices.exists(index=index):
client.indices.delete(index=index) await client.indices.delete(index=index)
logger.info(f"Deleted index '{index}' (alias: {alias_name})") logger.info(f"Deleted index '{index}' (alias: {alias_name})")
else: else:
logger.warning(f"Index '{index}' not found for alias '{alias_name}'") logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
@ -264,14 +268,16 @@ class GraphDriver(ABC):
for index in aoss_indices: for index in aoss_indices:
index_name = index['index_name'] index_name = index['index_name']
if client.indices.exists(index=index_name): if await client.indices.exists(index=index_name):
try: try:
# Delete all documents but keep the index # Delete all documents but keep the index
response = client.delete_by_query( response = await client.delete_by_query(
index=index_name, index=index_name,
body={'query': {'match_all': {}}}, body={'query': {'match_all': {}}},
refresh=True, refresh=True,
conflicts='proceed', conflicts='proceed',
wait_for_completion=True,
slices='auto', # improves coverage/concurrency
) )
logger.info(f"Cleared index '{index_name}': {response}") logger.info(f"Cleared index '{index_name}': {response}")
except Exception as e: except Exception as e:
@ -281,7 +287,7 @@ class GraphDriver(ABC):
async def save_to_aoss(self, name: str, data: list[dict]) -> int: async def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client client = self.aoss_client
if not client or not helpers: if not client:
logger.warning('No OpenSearch client found') logger.warning('No OpenSearch client found')
return 0 return 0
@ -289,16 +295,20 @@ class GraphDriver(ABC):
if name.lower() == index['index_name']: if name.lower() == index['index_name']:
to_index = [] to_index = []
for d in data: for d in data:
item = { doc = {}
'_index': name,
'_routing': d.get('group_id'), # shard routing
}
for p in index['body']['mappings']['properties']: for p in index['body']['mappings']['properties']:
if p in d: # protect against missing fields if p in d: # protect against missing fields
item[p] = d[p] doc[p] = d[p]
item = {
'_index': name,
'_id': d['uuid'],
'_routing': d.get('group_id'),
'_source': doc,
}
to_index.append(item) to_index.append(item)
success, failed = helpers.bulk( success, failed = await helpers.async_bulk(
client, to_index, stats_only=True, request_timeout=60 client, to_index, stats_only=True, request_timeout=60
) )

View file

@ -29,8 +29,9 @@ logger = logging.getLogger(__name__)
try: try:
import boto3 import boto3
from opensearchpy import ( from opensearchpy import (
AIOHttpConnection,
AsyncOpenSearch,
AWSV4SignerAuth, AWSV4SignerAuth,
OpenSearch,
RequestsHttpConnection, RequestsHttpConnection,
Urllib3AWSV4SignerAuth, Urllib3AWSV4SignerAuth,
Urllib3HttpConnection, Urllib3HttpConnection,
@ -75,12 +76,12 @@ class Neo4jDriver(GraphDriver):
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials() credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
auth = AWSV4SignerAuth(credentials, region or '', service or '') auth = AWSV4SignerAuth(credentials, region or '', service or '')
self.aoss_client = OpenSearch( self.aoss_client = AsyncOpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}], hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=auth, auth=auth,
use_ssl=True, use_ssl=True,
verify_certs=True, verify_certs=True,
connection_class=RequestsHttpConnection, connection_class=AIOHttpConnection,
pool_maxsize=20, pool_maxsize=20,
) # type: ignore ) # type: ignore
except Exception as e: except Exception as e:

View file

@ -267,7 +267,7 @@ class EntityEdge(Edge):
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
""" """
elif driver.aoss_client: elif driver.aoss_client:
resp = driver.aoss_client.search( resp = await driver.aoss_client.search(
body={ body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1, 'size': 1,

View file

@ -513,7 +513,7 @@ class EntityNode(Node):
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
""" """
elif driver.aoss_client: elif driver.aoss_client:
resp = driver.aoss_client.search( resp = await driver.aoss_client.search(
body={ body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1, 'size': 1,

View file

@ -254,7 +254,7 @@ async def edge_fulltext_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -415,7 +415,7 @@ async def edge_similarity_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter) filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME, index=ENTITY_EDGE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -661,7 +661,7 @@ async def node_fulltext_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -674,7 +674,7 @@ async def node_fulltext_search(
{ {
'multi_match': { 'multi_match': {
'query': query, 'query': query,
'fields': ['name', 'summary'], # ✅ fixed key 'fields': ['name', 'summary'],
'operator': 'or', 'operator': 'or',
} }
} }
@ -806,7 +806,7 @@ async def node_similarity_search(
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter) filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search( res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME, index=ENTITY_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],
@ -1011,7 +1011,7 @@ async def episode_fulltext_search(
return [] return []
elif driver.aoss_client: elif driver.aoss_client:
route = group_ids[0] if group_ids else None route = group_ids[0] if group_ids else None
res = driver.aoss_client.search( res = await driver.aoss_client.search(
EPISODE_INDEX_NAME, EPISODE_INDEX_NAME,
routing=route, routing=route,
_source=['uuid'], _source=['uuid'],

View file

@ -17,6 +17,7 @@ limitations under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
from xml.dom.minidom import Entity
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import LiteralString from typing_extensions import LiteralString
@ -260,7 +261,7 @@ async def resolve_extracted_edges(
embedder = clients.embedder embedder = clients.embedder
await create_entity_edge_embeddings(embedder, extracted_edges) await create_entity_edge_embeddings(embedder, extracted_edges)
valid_uuids_list: list[list[str]] = await semaphore_gather( valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
*[ *[
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid) EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
for edge in extracted_edges for edge in extracted_edges
@ -274,9 +275,9 @@ async def resolve_extracted_edges(
extracted_edge.fact, extracted_edge.fact,
group_ids=[extracted_edge.group_id], group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF, config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(edge_uuids=valid_uuids), search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
) )
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True) for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
] ]
) )