Number of Neo4j Connections Optimizations (#199)

* improve node distance performance

* update episode mentions

* format

* swap to debug log
This commit is contained in:
Preston Rasmussen 2024-10-23 13:08:47 -04:00 committed by GitHub
parent f77ab2b002
commit 47ba11e08d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 88 additions and 84 deletions

View file

@ -53,22 +53,23 @@ logger = logging.getLogger(__name__)
async def search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()
query = query.replace('\n', ' ')
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None
edges, nodes, communities = await asyncio.gather(
edge_search(
driver,
embedder,
query,
query_vector,
group_ids,
config.edge_config,
center_node_uuid,
@ -76,8 +77,8 @@ async def search(
),
node_search(
driver,
embedder,
query,
query_vector,
group_ids,
config.node_config,
center_node_uuid,
@ -85,8 +86,8 @@ async def search(
),
community_search(
driver,
embedder,
query,
query_vector,
group_ids,
config.community_config,
config.limit,
@ -99,27 +100,25 @@ async def search(
communities=communities,
)
end = time()
latency = (time() - start) * 1000
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
logger.debug(f'search returned context for query {query} in {latency} ms')
return results
async def edge_search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
driver: AsyncDriver,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
if config is None:
return []
query_vector = await embedder.create(input=[query])
search_results: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
@ -176,19 +175,17 @@ async def edge_search(
async def node_search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
driver: AsyncDriver,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
if config is None:
return []
query_vector = await embedder.create(input=[query])
search_results: list[list[EntityNode]] = list(
await asyncio.gather(
*[
@ -230,18 +227,16 @@ async def node_search(
async def community_search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
driver: AsyncDriver,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
if config is None:
return []
query_vector = await embedder.create(input=[query])
search_results: list[list[CommunityNode]] = list(
await asyncio.gather(
*[

View file

@ -19,6 +19,7 @@ import logging
from collections import defaultdict
from time import time
import neo4j
import numpy as np
from neo4j import AsyncDriver, Query
@ -79,7 +80,9 @@ async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
@ -104,7 +107,9 @@ async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
@ -141,7 +146,9 @@ async def edge_fulltext_search(
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -158,7 +165,9 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""")
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
cypher_query,
{
@ -188,11 +197,11 @@ async def edge_similarity_search(
# vector similarity search over embedded facts
query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
r.uuid AS uuid,
@ -211,7 +220,9 @@ async def edge_similarity_search(
LIMIT $limit
""")
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
query,
{
@ -241,7 +252,9 @@ async def node_fulltext_search(
if fuzzy_query == '':
return []
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
@ -276,7 +289,9 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
@ -318,7 +333,9 @@ async def community_fulltext_search(
if fuzzy_query == '':
return []
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
@ -353,7 +370,9 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
@ -554,32 +573,27 @@ async def node_distance_reranker(
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
) -> list[str]:
# exclude the center node's own uuid from the list of node uuids to rank
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {}
# Find the shortest path to center node
query = Query("""
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
RETURN length(p) AS score
UNWIND $node_uuids AS node_uuid
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
RETURN length(p) AS score, node_uuid AS uuid
""")
path_results = await asyncio.gather(
*[
driver.execute_query(
query,
node_uuid=uuid,
center_uuid=center_node_uuid,
database_=DEFAULT_DATABASE,
)
for uuid in filtered_uuids
]
path_results, _, _ = await driver.execute_query(
query,
node_uuids=filtered_uuids,
center_uuid=center_node_uuid,
database_=DEFAULT_DATABASE,
)
for uuid, result in zip(filtered_uuids, path_results):
records = result[0]
record = records[0] if len(records) > 0 else None
distance: float = record['score'] if record is not None else float('inf')
scores[uuid] = distance
for result in path_results:
uuid = result['uuid']
score = result['score'] if 'score' in result else float('inf')
scores[uuid] = score
# rerank on shortest distance
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
@ -596,25 +610,20 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
scores: dict[str, float] = {}
# Count the episodes that mention each node
query = Query("""
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
RETURN count(*) AS score
query = Query("""
UNWIND $node_uuids AS node_uuid
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
RETURN count(*) AS score, n.uuid AS uuid
""")
result_scores = await asyncio.gather(
*[
driver.execute_query(
query,
node_uuid=uuid,
database_=DEFAULT_DATABASE,
)
for uuid in sorted_uuids
]
results, _, _ = await driver.execute_query(
query,
node_uuids=sorted_uuids,
database_=DEFAULT_DATABASE,
)
for uuid, result in zip(sorted_uuids, result_scores):
record = result[0][0]
scores[uuid] = record['score']
for result in results:
scores[result['uuid']] = result['score']
# rerank on episode mention count
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])