Number of Neo4j Connections Optimizations (#199)
* improve node distance performance * update episode mentions * format * swap to debug log
This commit is contained in:
parent
f77ab2b002
commit
47ba11e08d
2 changed files with 88 additions and 84 deletions
|
|
@ -53,22 +53,23 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
query = query.replace('\n', ' ')
|
||||
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
|
||||
|
||||
# if group_ids is empty, set it to None
|
||||
group_ids = group_ids if group_ids else None
|
||||
edges, nodes, communities = await asyncio.gather(
|
||||
edge_search(
|
||||
driver,
|
||||
embedder,
|
||||
query,
|
||||
query_vector,
|
||||
group_ids,
|
||||
config.edge_config,
|
||||
center_node_uuid,
|
||||
|
|
@ -76,8 +77,8 @@ async def search(
|
|||
),
|
||||
node_search(
|
||||
driver,
|
||||
embedder,
|
||||
query,
|
||||
query_vector,
|
||||
group_ids,
|
||||
config.node_config,
|
||||
center_node_uuid,
|
||||
|
|
@ -85,8 +86,8 @@ async def search(
|
|||
),
|
||||
community_search(
|
||||
driver,
|
||||
embedder,
|
||||
query,
|
||||
query_vector,
|
||||
group_ids,
|
||||
config.community_config,
|
||||
config.limit,
|
||||
|
|
@ -99,27 +100,25 @@ async def search(
|
|||
communities=communities,
|
||||
)
|
||||
|
||||
end = time()
|
||||
latency = (time() - start) * 1000
|
||||
|
||||
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
|
||||
logger.debug(f'search returned context for query {query} in {latency} ms')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def edge_search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
query_vector: list[float],
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
if config is None:
|
||||
return []
|
||||
|
||||
query_vector = await embedder.create(input=[query])
|
||||
|
||||
search_results: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
|
|
@ -176,19 +175,17 @@ async def edge_search(
|
|||
|
||||
|
||||
async def node_search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: NodeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
query_vector: list[float],
|
||||
group_ids: list[str] | None,
|
||||
config: NodeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
||||
query_vector = await embedder.create(input=[query])
|
||||
|
||||
search_results: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
|
|
@ -230,18 +227,16 @@ async def node_search(
|
|||
|
||||
|
||||
async def community_search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig | None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
query_vector: list[float],
|
||||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig | None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
||||
query_vector = await embedder.create(input=[query])
|
||||
|
||||
search_results: list[list[CommunityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import logging
|
|||
from collections import defaultdict
|
||||
from time import time
|
||||
|
||||
import neo4j
|
||||
import numpy as np
|
||||
from neo4j import AsyncDriver, Query
|
||||
|
||||
|
|
@ -79,7 +80,9 @@ async def get_mentioned_nodes(
|
|||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
) -> list[EntityNode]:
|
||||
episode_uuids = [episode.uuid for episode in episodes]
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||
|
|
@ -104,7 +107,9 @@ async def get_communities_by_nodes(
|
|||
driver: AsyncDriver, nodes: list[EntityNode]
|
||||
) -> list[CommunityNode]:
|
||||
node_uuids = [node.uuid for node in nodes]
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
||||
|
|
@ -141,7 +146,9 @@ async def edge_fulltext_search(
|
|||
cypher_query = Query("""
|
||||
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
|
||||
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
||||
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
|
|
@ -158,7 +165,9 @@ async def edge_fulltext_search(
|
|||
ORDER BY score DESC LIMIT $limit
|
||||
""")
|
||||
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
cypher_query,
|
||||
{
|
||||
|
|
@ -188,11 +197,11 @@ async def edge_similarity_search(
|
|||
# vector similarity search over embedded facts
|
||||
query = Query("""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
||||
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
||||
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
||||
WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
|
|
@ -211,7 +220,9 @@ async def edge_similarity_search(
|
|||
LIMIT $limit
|
||||
""")
|
||||
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
query,
|
||||
{
|
||||
|
|
@ -241,7 +252,9 @@ async def node_fulltext_search(
|
|||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||
|
|
@ -276,7 +289,9 @@ async def node_similarity_search(
|
|||
min_score: float = DEFAULT_MIN_SCORE,
|
||||
) -> list[EntityNode]:
|
||||
# vector similarity search over entity names
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
|
|
@ -318,7 +333,9 @@ async def community_fulltext_search(
|
|||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||
|
|
@ -353,7 +370,9 @@ async def community_similarity_search(
|
|||
min_score=DEFAULT_MIN_SCORE,
|
||||
) -> list[CommunityNode]:
|
||||
# vector similarity search over entity names
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session(
|
||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||
) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
|
|
@ -554,32 +573,27 @@ async def node_distance_reranker(
|
|||
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
||||
) -> list[str]:
|
||||
# filter the center node's uuid out of node_uuids
|
||||
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
||||
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
||||
scores: dict[str, float] = {}
|
||||
|
||||
# Find the shortest path to center node
|
||||
query = Query("""
|
||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
|
||||
RETURN length(p) AS score
|
||||
UNWIND $node_uuids AS node_uuid
|
||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
|
||||
RETURN length(p) AS score, node_uuid AS uuid
|
||||
""")
|
||||
|
||||
path_results = await asyncio.gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
node_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
for uuid in filtered_uuids
|
||||
]
|
||||
path_results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
node_uuids=filtered_uuids,
|
||||
center_uuid=center_node_uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
for uuid, result in zip(filtered_uuids, path_results):
|
||||
records = result[0]
|
||||
record = records[0] if len(records) > 0 else None
|
||||
distance: float = record['score'] if record is not None else float('inf')
|
||||
scores[uuid] = distance
|
||||
for result in path_results:
|
||||
uuid = result['uuid']
|
||||
score = result['score'] if 'score' in result else float('inf')
|
||||
scores[uuid] = score
|
||||
|
||||
# rerank on shortest distance
|
||||
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
|
@ -596,25 +610,20 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
|||
scores: dict[str, float] = {}
|
||||
|
||||
# Count the episodes that mention each node
|
||||
query = Query("""
|
||||
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
|
||||
RETURN count(*) AS score
|
||||
query = Query("""
|
||||
UNWIND $node_uuids AS node_uuid
|
||||
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
||||
RETURN count(*) AS score, n.uuid AS uuid
|
||||
""")
|
||||
|
||||
result_scores = await asyncio.gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
node_uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
for uuid in sorted_uuids
|
||||
]
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
node_uuids=sorted_uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
for uuid, result in zip(sorted_uuids, result_scores):
|
||||
record = result[0][0]
|
||||
scores[uuid] = record['score']
|
||||
for result in results:
|
||||
scores[result['uuid']] = result['score']
|
||||
|
||||
# rerank on number of episode mentions
|
||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue