From 50d2308c93e334368f6b10164c8e0d7d5d3b341e Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 22 Oct 2024 10:01:56 -0400 Subject: [PATCH] Use sessions search (#197) * use sessions for searches * correct DB name * fix typo --- graphiti_core/edges.py | 26 +-- graphiti_core/nodes.py | 26 +-- graphiti_core/search/search_utils.py | 160 ++++++++++-------- .../utils/maintenance/community_operations.py | 10 +- .../maintenance/graph_data_operations.py | 2 +- 5 files changed, 122 insertions(+), 102 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 60add09e..a61880ca 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -54,7 +54,7 @@ class Edge(BaseModel, ABC): DELETE e """, uuid=self.uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Deleted Edge: {self.uuid}') @@ -82,7 +82,7 @@ class EpisodicEdge(Edge): uuid=self.uuid, group_id=self.group_id, created_at=self.created_at, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Saved edge to neo4j: {self.uuid}') @@ -102,7 +102,7 @@ class EpisodicEdge(Edge): e.created_at AS created_at """, uuid=uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_episodic_edge_from_record(record) for record in records] @@ -125,7 +125,7 @@ class EpisodicEdge(Edge): e.created_at AS created_at """, uuids=uuids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_episodic_edge_from_record(record) for record in records] @@ -148,7 +148,7 @@ class EpisodicEdge(Edge): e.created_at AS created_at """, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_episodic_edge_from_record(record) for record in records] @@ -202,7 +202,7 @@ class EntityEdge(Edge): expired_at=self.expired_at, valid_at=self.valid_at, invalid_at=self.invalid_at, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Saved edge to neo4j: {self.uuid}') @@ -229,7 +229,7 @@ class EntityEdge(Edge): e.invalid_at AS invalid_at """, uuid=uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -259,7 +259,7 @@ class EntityEdge(Edge): e.invalid_at AS invalid_at """, uuids=uuids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -289,7 +289,7 @@ class EntityEdge(Edge): e.invalid_at AS invalid_at """, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -308,7 +308,7 @@ class CommunityEdge(Edge): uuid=self.uuid, group_id=self.group_id, created_at=self.created_at, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Saved edge to neo4j: {self.uuid}') @@ -328,7 +328,7 @@ class CommunityEdge(Edge): e.created_at AS created_at """, uuid=uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_community_edge_from_record(record) for record in records] @@ -349,7 +349,7 @@ class CommunityEdge(Edge): e.created_at AS created_at """, uuids=uuids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_community_edge_from_record(record) for record in records] @@ -370,7 +370,7 @@ class CommunityEdge(Edge): e.created_at AS created_at """, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) edges = [get_community_edge_from_record(record) for record in records] diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index ef54ccf6..439aec80 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -90,7 +90,7 @@ class Node(BaseModel, ABC): DETACH DELETE n """, uuid=self.uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Deleted Node: {self.uuid}') @@ -136,7 +136,7 @@ class EpisodicNode(Node): created_at=self.created_at, valid_at=self.valid_at, source=self.source.value, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Saved Node to neo4j: {self.uuid}') @@ -158,7 +158,7 @@ class EpisodicNode(Node): e.source AS source """, uuid=uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) episodes = [get_episodic_node_from_record(record) for record in records] @@ -184,7 +184,7 @@ class EpisodicNode(Node): e.source AS source """, uuids=uuids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) episodes = [get_episodic_node_from_record(record) for record in records] @@ -207,7 +207,7 @@ class EpisodicNode(Node): e.source AS source """, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) episodes = [get_episodic_node_from_record(record) for record in records] @@ -237,7 +237,7 @@ class EntityNode(Node): summary=self.summary, name_embedding=self.name_embedding, created_at=self.created_at, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Saved Node to neo4j: {self.uuid}') @@ -258,7 +258,7 @@ class EntityNode(Node): n.summary AS summary """, uuid=uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -282,7 +282,7 @@ class EntityNode(Node): n.summary AS summary """, uuids=uuids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -303,7 +303,7 @@ class EntityNode(Node): n.summary AS summary """, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -324,7 +324,7 @@ class CommunityNode(Node): summary=self.summary, name_embedding=self.name_embedding, created_at=self.created_at, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) logger.debug(f'Saved Node to neo4j: {self.uuid}') @@ -354,7 +354,7 @@ class CommunityNode(Node): n.summary AS summary """, uuid=uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) nodes = [get_community_node_from_record(record) for record in records] @@ -378,7 +378,7 @@ class CommunityNode(Node): n.summary AS summary """, uuids=uuids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) communities = [get_community_node_from_record(record) for record in records] @@ -399,7 +399,7 @@ class CommunityNode(Node): n.summary AS summary """, group_ids=group_ids, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) communities = [get_community_node_from_record(record) for record in records] diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index ed406c85..6d139fa2 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -79,20 +79,21 @@ async def get_mentioned_nodes( driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] - records, _, _ = await driver.execute_query( - """ - MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids - RETURN DISTINCT - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding - n.created_at AS created_at, - n.summary AS summary - """, - uuids=episode_uuids, - _database=DEFAULT_DATABASE, - ) + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + """ + MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids + RETURN DISTINCT + n.uuid As uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary + """, + {'uuids': episode_uuids}, + ) + records = [record async for record in result] nodes = [get_entity_node_from_record(record) for record in records] @@ -103,8 +104,9 @@ async def get_communities_by_nodes( driver: AsyncDriver, nodes: list[EntityNode] ) -> list[CommunityNode]: node_uuids = [node.uuid for node in nodes] - records, _, _ = await driver.execute_query( - """ + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + """ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids RETURN DISTINCT c.uuid As uuid, @@ -114,9 +116,9 @@ async def get_communities_by_nodes( c.created_at AS created_at, c.summary AS summary """, - uuids=node_uuids, - _database=DEFAULT_DATABASE, - ) + {'uuids': node_uuids}, + ) + records = [record async for record in result] communities = [get_community_node_from_record(record) for record in records] @@ -156,15 +158,18 @@ async def edge_fulltext_search( ORDER BY score DESC LIMIT $limit """) - records, _, _ = await driver.execute_query( - cypher_query, - query=fuzzy_query, - source_uuid=source_node_uuid, - target_uuid=target_node_uuid, - group_ids=group_ids, - limit=limit, - _database=DEFAULT_DATABASE, - ) + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + cypher_query, + { + 'query': fuzzy_query, + 'source_uuid': source_node_uuid, + 'target_uuid': target_node_uuid, + 'group_ids': group_ids, + 'limit': limit, + }, + ) + records = [record async for record in result] edges = [get_entity_edge_from_record(record) for record in records] @@ -206,16 +211,19 @@ async def edge_similarity_search( LIMIT $limit """) - records, _, _ = await driver.execute_query( - query, - search_vector=search_vector, - source_uuid=source_node_uuid, - target_uuid=target_node_uuid, - group_ids=group_ids, - limit=limit, - min_score=min_score, - _database=DEFAULT_DATABASE, - ) + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + query, + { + 'search_vector': search_vector, + 'source_uuid': source_node_uuid, + 'target_uuid': target_node_uuid, + 'group_ids': group_ids, + 'limit': limit, + 'min_score': min_score, + }, + ) + records = [record async for record in result] edges = [get_entity_edge_from_record(record) for record in records] @@ -233,8 +241,9 @@ async def node_fulltext_search( if fuzzy_query == '': return [] - records, _, _ = await driver.execute_query( - """ + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + """ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) YIELD node AS n, score RETURN @@ -247,11 +256,13 @@ async def node_fulltext_search( ORDER BY score DESC LIMIT $limit """, - query=fuzzy_query, - group_ids=group_ids, - limit=limit, - _database=DEFAULT_DATABASE, - ) + { + 'query': fuzzy_query, + 'group_ids': group_ids, + 'limit': limit, + }, + ) + records = [record async for record in result] nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -265,8 +276,9 @@ async def node_similarity_search( min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityNode]: # vector similarity search over entity names - records, _, _ = await driver.execute_query( - """ + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + """ CYPHER runtime = parallel parallelRuntimeSupport=all MATCH (n:Entity) WHERE $group_ids IS NULL OR n.group_id IN $group_ids @@ -282,12 +294,14 @@ async def node_similarity_search( ORDER BY score DESC LIMIT $limit """, - search_vector=search_vector, - group_ids=group_ids, - limit=limit, - min_score=min_score, - _database=DEFAULT_DATABASE, - ) + { + 'search_vector': search_vector, + 'group_ids': group_ids, + 'limit': limit, + 'min_score': min_score, + }, + ) + records = [record async for record in result] nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -304,8 +318,9 @@ async def community_fulltext_search( if fuzzy_query == '': return [] - records, _, _ = await driver.execute_query( - """ + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + """ CALL db.index.fulltext.queryNodes("community_name", $query) YIELD node AS comm, score RETURN @@ -318,11 +333,13 @@ async def community_fulltext_search( ORDER BY score DESC LIMIT $limit """, - query=fuzzy_query, - group_ids=group_ids, - limit=limit, - _database=DEFAULT_DATABASE, - ) + { + 'query': fuzzy_query, + 'group_ids': group_ids, + 'limit': limit, + }, + ) + records = [record async for record in result] communities = [get_community_node_from_record(record) for record in records] return communities @@ -336,8 +353,9 @@ async def community_similarity_search( min_score=DEFAULT_MIN_SCORE, ) -> list[CommunityNode]: # vector similarity search over entity names - records, _, _ = await driver.execute_query( - """ + async with driver.session(database=DEFAULT_DATABASE) as session: + result = await session.run( + """ CYPHER runtime = parallel parallelRuntimeSupport=all MATCH (comm:Community) WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) @@ -353,12 +371,14 @@ async def community_similarity_search( ORDER BY score DESC LIMIT $limit """, - search_vector=search_vector, - group_ids=group_ids, - limit=limit, - min_score=min_score, - _database=DEFAULT_DATABASE, - ) + { + 'search_vector': search_vector, + 'group_ids': group_ids, + 'limit': limit, + 'min_score': min_score, + }, + ) + records = [record async for record in result] communities = [get_community_node_from_record(record) for record in records] return communities @@ -549,7 +569,7 @@ async def node_distance_reranker( query, node_uuid=uuid, center_uuid=center_node_uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) for uuid in filtered_uuids ] @@ -586,7 +606,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s driver.execute_query( query, node_uuid=uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) for uuid in sorted_uuids ] diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index dfbb456d..2a779377 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -40,7 +40,7 @@ async def get_community_clusters( RETURN collect(DISTINCT n.group_id) AS group_ids """, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) group_ids = group_id_values[0]['group_ids'] @@ -59,7 +59,7 @@ async def get_community_clusters( """, uuid=node.uuid, group_id=group_id, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) projection[node.uuid] = [ @@ -223,7 +223,7 @@ async def remove_communities(driver: AsyncDriver): MATCH (c:Community) DETACH DELETE c """, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) @@ -243,7 +243,7 @@ async def determine_entity_community( c.summary AS summary """, entity_uuid=entity.uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) if len(records) > 0: @@ -262,7 +262,7 @@ async def determine_entity_community( c.summary AS summary """, entity_uuid=entity.uuid, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) communities: list[CommunityNode] = [ diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 4c5bb900..89c9fdbd 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -35,7 +35,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo """ SHOW INDEXES YIELD name """, - _database=DEFAULT_DATABASE, + database_=DEFAULT_DATABASE, ) index_names = [record['name'] for record in records] await asyncio.gather(