diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index d6661d24..47d6986c 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -53,22 +53,23 @@ logger = logging.getLogger(__name__) async def search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: SearchConfig, - center_node_uuid: str | None = None, + driver: AsyncDriver, + embedder: EmbedderClient, + query: str, + group_ids: list[str] | None, + config: SearchConfig, + center_node_uuid: str | None = None, ) -> SearchResults: start = time() - query = query.replace('\n', ' ') + query_vector = await embedder.create(input=[query.replace('\n', ' ')]) + # if group_ids is empty, set it to None group_ids = group_ids if group_ids else None edges, nodes, communities = await asyncio.gather( edge_search( driver, - embedder, query, + query_vector, group_ids, config.edge_config, center_node_uuid, @@ -76,8 +77,8 @@ async def search( ), node_search( driver, - embedder, query, + query_vector, group_ids, config.node_config, center_node_uuid, @@ -85,8 +86,8 @@ async def search( ), community_search( driver, - embedder, query, + query_vector, group_ids, config.community_config, config.limit, @@ -99,27 +100,25 @@ async def search( communities=communities, ) - end = time() + latency = (time() - start) * 1000 - logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms') + logger.debug(f'search returned context for query {query} in {latency} ms') return results async def edge_search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: EdgeSearchConfig | None, - center_node_uuid: str | None = None, - limit=DEFAULT_SEARCH_LIMIT, + driver: AsyncDriver, + query: str, + query_vector: list[float], + group_ids: list[str] | None, + config: EdgeSearchConfig | None, + center_node_uuid: str | None = None, + limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityEdge]: if config is None: return [] - query_vector = await embedder.create(input=[query]) - search_results: list[list[EntityEdge]] = list( await asyncio.gather( *[ @@ -176,19 +175,17 @@ async def edge_search( async def node_search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: NodeSearchConfig | None, - center_node_uuid: str | None = None, - limit=DEFAULT_SEARCH_LIMIT, + driver: AsyncDriver, + query: str, + query_vector: list[float], + group_ids: list[str] | None, + config: NodeSearchConfig | None, + center_node_uuid: str | None = None, + limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityNode]: if config is None: return [] - query_vector = await embedder.create(input=[query]) - search_results: list[list[EntityNode]] = list( await asyncio.gather( *[ @@ -230,18 +227,16 @@ async def node_search( async def community_search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: CommunitySearchConfig | None, - limit=DEFAULT_SEARCH_LIMIT, + driver: AsyncDriver, + query: str, + query_vector: list[float], + group_ids: list[str] | None, + config: CommunitySearchConfig | None, + limit=DEFAULT_SEARCH_LIMIT, ) -> list[CommunityNode]: if config is None: return [] - query_vector = await embedder.create(input=[query]) - search_results: list[list[CommunityNode]] = list( await asyncio.gather( *[ diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 6d139fa2..57d8b58b 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -19,6 +19,7 @@ import logging from collections import defaultdict from time import time +import neo4j import numpy as np from neo4j import AsyncDriver, Query @@ -79,7 +80,9 @@ async def get_mentioned_nodes( driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( """ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids @@ -104,7 +107,9 @@ async def get_communities_by_nodes( driver: AsyncDriver, nodes: list[EntityNode] ) -> list[CommunityNode]: node_uuids = [node.uuid for node in nodes] - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( """ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids @@ -141,7 +146,9 @@ async def edge_fulltext_search( cypher_query = Query(""" CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query) YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) + MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity) + WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) + AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -158,7 +165,9 @@ async def edge_fulltext_search( ORDER BY score DESC LIMIT $limit """) - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( cypher_query, { @@ -188,11 +197,11 @@ async def edge_similarity_search( # vector similarity search over embedded facts query = Query(""" CYPHER runtime = parallel parallelRuntimeSupport=all - MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) - AND ($source_uuid IS NULL OR n.uuid = $source_uuid) - AND ($target_uuid IS NULL OR m.uuid = $target_uuid) - WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score + AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) + AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) + WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score WHERE score > $min_score RETURN r.uuid AS uuid, @@ -211,7 +220,9 @@ async def edge_similarity_search( LIMIT $limit """) - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( query, { @@ -241,7 +252,9 @@ async def node_fulltext_search( if fuzzy_query == '': return [] - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( """ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) @@ -276,7 +289,9 @@ async def node_similarity_search( min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityNode]: # vector similarity search over entity names - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( """ CYPHER runtime = parallel parallelRuntimeSupport=all @@ -318,7 +333,9 @@ async def community_fulltext_search( if fuzzy_query == '': return [] - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( """ CALL db.index.fulltext.queryNodes("community_name", $query) @@ -353,7 +370,9 @@ async def community_similarity_search( min_score=DEFAULT_MIN_SCORE, ) -> list[CommunityNode]: # vector similarity search over entity names - async with driver.session(database=DEFAULT_DATABASE) as session: + async with driver.session( + database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS + ) as session: result = await session.run( """ CYPHER runtime = parallel parallelRuntimeSupport=all @@ -554,32 +573,27 @@ async def node_distance_reranker( driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str ) -> list[str]: # filter out node_uuid center node node uuid - filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids)) + filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids)) scores: dict[str, float] = {} # Find the shortest path to center node query = Query(""" - MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid}) - RETURN length(p) AS score + UNWIND $node_uuids AS node_uuid + MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid}) + RETURN length(p) AS score, node_uuid AS uuid """) - path_results = await asyncio.gather( - *[ - driver.execute_query( - query, - node_uuid=uuid, - center_uuid=center_node_uuid, - database_=DEFAULT_DATABASE, - ) - for uuid in filtered_uuids - ] + path_results, _, _ = await driver.execute_query( + query, + node_uuids=filtered_uuids, + center_uuid=center_node_uuid, + database_=DEFAULT_DATABASE, ) - for uuid, result in zip(filtered_uuids, path_results): - records = result[0] - record = records[0] if len(records) > 0 else None - distance: float = record['score'] if record is not None else float('inf') - scores[uuid] = distance + for result in path_results: + uuid = result['uuid'] + score = result['score'] if 'score' in result else float('inf') + scores[uuid] = score # rerank on shortest distance filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) @@ -596,25 +610,20 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s scores: dict[str, float] = {} # Find the shortest path to center node - query = Query(""" - MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid}) - RETURN count(*) AS score + query = Query(""" + UNWIND $node_uuids AS node_uuid + MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid}) + RETURN count(*) AS score, n.uuid AS uuid """) - result_scores = await asyncio.gather( - *[ - driver.execute_query( - query, - node_uuid=uuid, - database_=DEFAULT_DATABASE, - ) - for uuid in sorted_uuids - ] + results, _, _ = await driver.execute_query( + query, + node_uuids=sorted_uuids, + database_=DEFAULT_DATABASE, ) - for uuid, result in zip(sorted_uuids, result_scores): - record = result[0][0] - scores[uuid] = record['score'] + for result in results: + scores[result['uuid']] = result['score'] # rerank on shortest distance sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])