Number of Neo4j Connections Optimizations (#199)

* improve node distance performance

* update episode mentions

* format

* swap to debug log
This commit is contained in:
Preston Rasmussen 2024-10-23 13:08:47 -04:00 committed by GitHub
parent f77ab2b002
commit 47ba11e08d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 88 additions and 84 deletions

View file

@ -53,22 +53,23 @@ logger = logging.getLogger(__name__)
async def search( async def search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, embedder: EmbedderClient,
query: str, query: str,
group_ids: list[str] | None, group_ids: list[str] | None,
config: SearchConfig, config: SearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
) -> SearchResults: ) -> SearchResults:
start = time() start = time()
query = query.replace('\n', ' ') query_vector = await embedder.create(input=[query.replace('\n', ' ')])
# if group_ids is empty, set it to None # if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None group_ids = group_ids if group_ids else None
edges, nodes, communities = await asyncio.gather( edges, nodes, communities = await asyncio.gather(
edge_search( edge_search(
driver, driver,
embedder,
query, query,
query_vector,
group_ids, group_ids,
config.edge_config, config.edge_config,
center_node_uuid, center_node_uuid,
@ -76,8 +77,8 @@ async def search(
), ),
node_search( node_search(
driver, driver,
embedder,
query, query,
query_vector,
group_ids, group_ids,
config.node_config, config.node_config,
center_node_uuid, center_node_uuid,
@ -85,8 +86,8 @@ async def search(
), ),
community_search( community_search(
driver, driver,
embedder,
query, query,
query_vector,
group_ids, group_ids,
config.community_config, config.community_config,
config.limit, config.limit,
@ -99,27 +100,25 @@ async def search(
communities=communities, communities=communities,
) )
end = time() latency = (time() - start) * 1000
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms') logger.debug(f'search returned context for query {query} in {latency} ms')
return results return results
async def edge_search( async def edge_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, query: str,
query: str, query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: EdgeSearchConfig | None, config: EdgeSearchConfig | None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
if config is None: if config is None:
return [] return []
query_vector = await embedder.create(input=[query])
search_results: list[list[EntityEdge]] = list( search_results: list[list[EntityEdge]] = list(
await asyncio.gather( await asyncio.gather(
*[ *[
@ -176,19 +175,17 @@ async def edge_search(
async def node_search( async def node_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, query: str,
query: str, query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: NodeSearchConfig | None, config: NodeSearchConfig | None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
if config is None: if config is None:
return [] return []
query_vector = await embedder.create(input=[query])
search_results: list[list[EntityNode]] = list( search_results: list[list[EntityNode]] = list(
await asyncio.gather( await asyncio.gather(
*[ *[
@ -230,18 +227,16 @@ async def node_search(
async def community_search( async def community_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, query: str,
query: str, query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: CommunitySearchConfig | None, config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
if config is None: if config is None:
return [] return []
query_vector = await embedder.create(input=[query])
search_results: list[list[CommunityNode]] = list( search_results: list[list[CommunityNode]] = list(
await asyncio.gather( await asyncio.gather(
*[ *[

View file

@ -19,6 +19,7 @@ import logging
from collections import defaultdict from collections import defaultdict
from time import time from time import time
import neo4j
import numpy as np import numpy as np
from neo4j import AsyncDriver, Query from neo4j import AsyncDriver, Query
@ -79,7 +80,9 @@ async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode] driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]: ) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes] episode_uuids = [episode.uuid for episode in episodes]
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
""" """
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
@ -104,7 +107,9 @@ async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode] driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]: ) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes] node_uuids = [node.uuid for node in nodes]
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
""" """
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
@ -141,7 +146,9 @@ async def edge_fulltext_search(
cypher_query = Query(""" cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query) CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -158,7 +165,9 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit ORDER BY score DESC LIMIT $limit
""") """)
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
cypher_query, cypher_query,
{ {
@ -188,11 +197,11 @@ async def edge_similarity_search(
# vector similarity search over embedded facts # vector similarity search over embedded facts
query = Query(""" query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid = $source_uuid) AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid = $target_uuid) AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score WHERE score > $min_score
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
@ -211,7 +220,9 @@ async def edge_similarity_search(
LIMIT $limit LIMIT $limit
""") """)
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
query, query,
{ {
@ -241,7 +252,9 @@ async def node_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
""" """
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
@ -276,7 +289,9 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE, min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]: ) -> list[EntityNode]:
# vector similarity search over entity names # vector similarity search over entity names
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
""" """
CYPHER runtime = parallel parallelRuntimeSupport=all CYPHER runtime = parallel parallelRuntimeSupport=all
@ -318,7 +333,9 @@ async def community_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
""" """
CALL db.index.fulltext.queryNodes("community_name", $query) CALL db.index.fulltext.queryNodes("community_name", $query)
@ -353,7 +370,9 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE, min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
# vector similarity search over entity names # vector similarity search over entity names
async with driver.session(database=DEFAULT_DATABASE) as session: async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run( result = await session.run(
""" """
CYPHER runtime = parallel parallelRuntimeSupport=all CYPHER runtime = parallel parallelRuntimeSupport=all
@ -554,32 +573,27 @@ async def node_distance_reranker(
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
) -> list[str]: ) -> list[str]:
# filter out node_uuid center node node uuid # filter out node_uuid center node node uuid
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids)) filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {} scores: dict[str, float] = {}
# Find the shortest path to center node # Find the shortest path to center node
query = Query(""" query = Query("""
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid}) UNWIND $node_uuids AS node_uuid
RETURN length(p) AS score MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
RETURN length(p) AS score, node_uuid AS uuid
""") """)
path_results = await asyncio.gather( path_results, _, _ = await driver.execute_query(
*[ query,
driver.execute_query( node_uuids=filtered_uuids,
query, center_uuid=center_node_uuid,
node_uuid=uuid, database_=DEFAULT_DATABASE,
center_uuid=center_node_uuid,
database_=DEFAULT_DATABASE,
)
for uuid in filtered_uuids
]
) )
for uuid, result in zip(filtered_uuids, path_results): for result in path_results:
records = result[0] uuid = result['uuid']
record = records[0] if len(records) > 0 else None score = result['score'] if 'score' in result else float('inf')
distance: float = record['score'] if record is not None else float('inf') scores[uuid] = score
scores[uuid] = distance
# rerank on shortest distance # rerank on shortest distance
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
@ -596,25 +610,20 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
scores: dict[str, float] = {} scores: dict[str, float] = {}
# Find the shortest path to center node # Find the shortest path to center node
query = Query(""" query = Query("""
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid}) UNWIND $node_uuids AS node_uuid
RETURN count(*) AS score MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
RETURN count(*) AS score, n.uuid AS uuid
""") """)
result_scores = await asyncio.gather( results, _, _ = await driver.execute_query(
*[ query,
driver.execute_query( node_uuids=sorted_uuids,
query, database_=DEFAULT_DATABASE,
node_uuid=uuid,
database_=DEFAULT_DATABASE,
)
for uuid in sorted_uuids
]
) )
for uuid, result in zip(sorted_uuids, result_scores): for result in results:
record = result[0][0] scores[result['uuid']] = result['score']
scores[uuid] = record['score']
# rerank on shortest distance # rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])