diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 4c926d08..dda3506f 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -38,7 +38,7 @@ class GraphProvider(Enum): aoss_indices = [ { - 'index_name': 'node_name_and_summary', + 'index_name': 'entities', 'body': { 'mappings': { 'properties': { @@ -68,7 +68,7 @@ aoss_indices = [ }, }, { - 'index_name': 'community_name', + 'index_name': 'communities', 'body': { 'mappings': { 'properties': { @@ -84,7 +84,7 @@ aoss_indices = [ }, }, { - 'index_name': 'episode_content', + 'index_name': 'episodes', 'body': { 'mappings': { 'properties': { @@ -109,7 +109,7 @@ aoss_indices = [ }, }, { - 'index_name': 'edge_name_and_fact', + 'index_name': 'entity_edges', 'body': { 'mappings': { 'properties': { diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index a8394bb2..9d77e4d3 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -255,6 +255,20 @@ class EntityEdge(Edge): MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding """ + elif driver.aoss_client: + resp = driver.aoss_client.search( + body={ + 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, + 'size': 1, + }, + index='entity_edges', + ) + + if resp['hits']['hits']: + self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding'] + return + else: + raise EdgeNotFoundError(self.uuid) if driver.provider == GraphProvider.KUZU: query = """ diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index b15977e7..b9d27197 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -273,20 +273,6 @@ class EpisodicNode(Node): ) async def save(self, driver: GraphDriver): - if driver.aoss_client: - driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue - 'episode_content', - [ - { - 'uuid': self.uuid, - 'group_id': self.group_id, - 'source': self.source.value, - 'content': self.content, - 'source_description': self.source_description, - } - ], - ) - episode_args = { 'uuid': self.uuid, 'name': self.name, @@ -299,6 +285,12 @@ class EpisodicNode(Node): 'source': self.source.value, } + if driver.aoss_client: + driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue + 'episode_content', + [episode_args], + ) + result = await driver.execute_query( get_episode_node_save_query(driver.provider), **episode_args ) @@ -433,6 +425,21 @@ class EntityNode(Node): MATCH (n:Entity {uuid: $uuid}) RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding """ + elif driver.aoss_client: + resp = driver.aoss_client.search( + body={ + 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, + 'size': 1, + }, + index='entities', + ) + + if resp['hits']['hits']: + self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding'] + return + else: + raise NodeNotFoundError(self.uuid) + else: query: LiteralString = """ MATCH (n:Entity {uuid: $uuid}) diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 801f816e..e8bbc86c 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False): - if driver.provider == GraphProvider.NEPTUNE: + if driver.aoss_client: await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue] return if delete_existing: @@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo range_indices: list[LiteralString] = get_range_indices(driver.provider) - fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) + # Don't create fulltext indices if OpenSearch is being used + if not driver.aoss_client: + fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) if driver.provider == GraphProvider.KUZU: # Skip creating fulltext indices if they already exist. Need to do this manually @@ -149,9 +151,9 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) - WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + query_filter + """ RETURN