diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 70201b9b..6a4d5468 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -25,6 +25,7 @@ from pydantic import BaseModel, Field from transcript_parser import parse_podcast_messages from graphiti_core import Graphiti +from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.nodes import EpisodeType from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.maintenance.graph_data_operations import clear_data @@ -34,6 +35,8 @@ load_dotenv() neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' +aoss_host = os.environ.get('AOSS_HOST') or None +aoss_port = os.environ.get('AOSS_PORT') or None def setup_logging(): @@ -77,12 +80,25 @@ class IsPresidentOf(BaseModel): async def main(use_bulk: bool = False): setup_logging() - client = Graphiti( + graph_driver = Neo4jDriver( neo4j_uri, neo4j_user, neo4j_password, + aoss_host=aoss_host, + aoss_port=int(aoss_port), + aws_profile_name='zep-development', + aws_region='us-west-2', + aws_service='es', ) + # client = Graphiti( + # neo4j_uri, + # neo4j_user, + # neo4j_password, + # ) + client = Graphiti(graph_driver=graph_driver) await clear_data(client.driver) + await client.driver.delete_aoss_indices() + await client.driver.create_aoss_indices() await client.build_indices_and_constraints() messages = parse_podcast_messages() group_id = str(uuid4()) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 52858582..ba99513f 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -24,10 +24,12 @@ from datetime import datetime from enum import Enum from typing import Any +from dotenv import load_dotenv + from graphiti_core.embedder.client import EMBEDDING_DIM try: - from opensearchpy import OpenSearch, helpers + from opensearchpy import AsyncOpenSearch, helpers _HAS_OPENSEARCH = True except ImportError: @@ -39,6 +41,8 @@ logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 +load_dotenv() + ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities') EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes') COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities') @@ -62,7 +66,7 @@ aoss_indices = [ 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, 'summary': {'type': 'text'}, - 'group_id': {'type': 'text'}, + 'group_id': {'type': 'keyword'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'name_embedding': { 'type': 'knn_vector', @@ -85,7 +89,7 @@ aoss_indices = [ 'properties': { 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, - 'group_id': {'type': 'text'}, + 'group_id': {'type': 'keyword'}, } } }, @@ -99,7 +103,7 @@ aoss_indices = [ 'content': {'type': 'text'}, 'source': {'type': 'text'}, 'source_description': {'type': 'text'}, - 'group_id': {'type': 'text'}, + 'group_id': {'type': 'keyword'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, } @@ -115,7 +119,7 @@ aoss_indices = [ 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, 'fact': {'type': 'text'}, - 'group_id': {'type': 'text'}, + 'group_id': {'type': 'keyword'}, 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, @@ -167,7 +171,7 @@ class GraphDriver(ABC): '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) _database: str - aoss_client: OpenSearch | None # type: ignore + aoss_client: AsyncOpenSearch | None # type: ignore @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: @@ -209,7 +213,7 @@ class GraphDriver(ABC): alias_name = index['index_name'] # If alias already exists, skip (idempotent behavior) - if client.indices.exists_alias(name=alias_name): + if await client.indices.exists_alias(name=alias_name): continue # Build a physical index name with timestamp @@ -217,10 +221,10 @@ class GraphDriver(ABC): physical_index_name = f'{alias_name}_{ts_suffix}' # Create the index - client.indices.create(index=physical_index_name, body=index['body']) + await client.indices.create(index=physical_index_name, body=index['body']) # Point alias to it - client.indices.put_alias(index=physical_index_name, name=alias_name) + await client.indices.put_alias(index=physical_index_name, name=alias_name) # Allow some time for index creation await asyncio.sleep(1) @@ -237,7 +241,7 @@ class GraphDriver(ABC): try: # Resolve alias → indices - alias_info = client.indices.get_alias(name=alias_name) + alias_info = await client.indices.get_alias(name=alias_name) indices = list(alias_info.keys()) if not indices: @@ -245,8 +249,8 @@ class GraphDriver(ABC): continue for index in indices: - if client.indices.exists(index=index): - client.indices.delete(index=index) + if await client.indices.exists(index=index): + await client.indices.delete(index=index) logger.info(f"Deleted index '{index}' (alias: {alias_name})") else: logger.warning(f"Index '{index}' not found for alias '{alias_name}'") @@ -264,14 +268,16 @@ class GraphDriver(ABC): for index in aoss_indices: index_name = index['index_name'] - if client.indices.exists(index=index_name): + if await client.indices.exists(index=index_name): try: # Delete all documents but keep the index - response = client.delete_by_query( + response = await client.delete_by_query( index=index_name, body={'query': {'match_all': {}}}, refresh=True, conflicts='proceed', + wait_for_completion=True, + slices='auto', # improves coverage/concurrency ) logger.info(f"Cleared index '{index_name}': {response}") except Exception as e: @@ -281,7 +287,7 @@ class GraphDriver(ABC): async def save_to_aoss(self, name: str, data: list[dict]) -> int: client = self.aoss_client - if not client or not helpers: + if not client: logger.warning('No OpenSearch client found') return 0 @@ -289,16 +295,20 @@ class GraphDriver(ABC): if name.lower() == index['index_name']: to_index = [] for d in data: - item = { - '_index': name, - '_routing': d.get('group_id'), # shard routing - } + doc = {} for p in index['body']['mappings']['properties']: if p in d: # protect against missing fields - item[p] = d[p] + doc[p] = d[p] + + item = { + '_index': name, + '_id': d['uuid'], + '_routing': d.get('group_id'), + '_source': doc, + } to_index.append(item) - success, failed = helpers.bulk( + success, failed = await helpers.async_bulk( client, to_index, stats_only=True, request_timeout=60 ) diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 9cf9041c..2365100e 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -29,8 +29,9 @@ logger = logging.getLogger(__name__) try: import boto3 from opensearchpy import ( + AIOHttpConnection, + AsyncOpenSearch, AWSV4SignerAuth, - OpenSearch, RequestsHttpConnection, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, @@ -75,12 +76,12 @@ class Neo4jDriver(GraphDriver): credentials = boto3.Session(profile_name=aws_profile_name).get_credentials() auth = AWSV4SignerAuth(credentials, region or '', service or '') - self.aoss_client = OpenSearch( + self.aoss_client = AsyncOpenSearch( hosts=[{'host': aoss_host, 'port': aoss_port}], - http_auth=auth, + auth=auth, use_ssl=True, verify_certs=True, - connection_class=RequestsHttpConnection, + connection_class=AIOHttpConnection, pool_maxsize=20, ) # type: ignore except Exception as e: diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 90a49762..f7d26b6f 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -267,7 +267,7 @@ class EntityEdge(Edge): RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding """ elif driver.aoss_client: - resp = driver.aoss_client.search( + resp = await driver.aoss_client.search( body={ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'size': 1, diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 2275bbbc..524dbd0f 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -513,7 +513,7 @@ class EntityNode(Node): RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding """ elif driver.aoss_client: - resp = driver.aoss_client.search( + resp = await driver.aoss_client.search( body={ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'size': 1, diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 3c475a20..f35261bc 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -215,11 +215,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -254,7 +254,7 @@ async def edge_fulltext_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_edge_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( + res = await driver.aoss_client.search( index=ENTITY_EDGE_INDEX_NAME, routing=route, _source=['uuid'], @@ -353,8 +353,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -415,7 +415,7 @@ async def edge_similarity_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_edge_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( + res = await driver.aoss_client.search( index=ENTITY_EDGE_INDEX_NAME, routing=route, _source=['uuid'], @@ -637,11 +637,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -661,7 +661,7 @@ async def node_fulltext_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( + res = await driver.aoss_client.search( index=ENTITY_INDEX_NAME, routing=route, _source=['uuid'], @@ -674,7 +674,7 @@ async def node_fulltext_search( { 'multi_match': { 'query': query, - 'fields': ['name', 'summary'], # ✅ fixed key + 'fields': ['name', 'summary'], 'operator': 'or', } } @@ -751,8 +751,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -781,11 +781,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -806,7 +806,7 @@ async def node_similarity_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( + res = await driver.aoss_client.search( index=ENTITY_INDEX_NAME, routing=route, _source=['uuid'], @@ -837,8 +837,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -1011,7 +1011,7 @@ async def episode_fulltext_search( return [] elif driver.aoss_client: route = group_ids[0] if group_ids else None - res = driver.aoss_client.search( + res = await driver.aoss_client.search( EPISODE_INDEX_NAME, routing=route, _source=['uuid'], @@ -1170,8 +1170,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1230,8 +1230,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1373,9 +1373,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1420,9 +1420,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1511,9 +1511,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1583,9 +1583,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1621,9 +1621,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1696,10 +1696,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1769,10 +1769,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1808,10 +1808,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 72e3dde1..19ad339c 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -17,6 +17,7 @@ limitations under the License. import logging from datetime import datetime from time import time +from xml.dom.minidom import Entity from pydantic import BaseModel from typing_extensions import LiteralString @@ -260,7 +261,7 @@ async def resolve_extracted_edges( embedder = clients.embedder await create_entity_edge_embeddings(embedder, extracted_edges) - valid_uuids_list: list[list[str]] = await semaphore_gather( + valid_edges_list: list[list[EntityEdge]] = await semaphore_gather( *[ EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid) for edge in extracted_edges @@ -274,9 +275,9 @@ async def resolve_extracted_edges( extracted_edge.fact, group_ids=[extracted_edge.group_id], config=EDGE_HYBRID_SEARCH_RRF, - search_filter=SearchFilters(edge_uuids=valid_uuids), + search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]), ) - for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True) + for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True) ] )