Number of Neo4j Connections Optimizations (#199)
* improve node distance performance
* update episode mentions
* format
* swap to debug log
This commit is contained in:
parent
f77ab2b002
commit
47ba11e08d
2 changed files with 88 additions and 84 deletions
|
|
@ -53,22 +53,23 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
query = query.replace('\n', ' ')
|
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
|
||||||
|
|
||||||
# if group_ids is empty, set it to None
|
# if group_ids is empty, set it to None
|
||||||
group_ids = group_ids if group_ids else None
|
group_ids = group_ids if group_ids else None
|
||||||
edges, nodes, communities = await asyncio.gather(
|
edges, nodes, communities = await asyncio.gather(
|
||||||
edge_search(
|
edge_search(
|
||||||
driver,
|
driver,
|
||||||
embedder,
|
|
||||||
query,
|
query,
|
||||||
|
query_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.edge_config,
|
config.edge_config,
|
||||||
center_node_uuid,
|
center_node_uuid,
|
||||||
|
|
@ -76,8 +77,8 @@ async def search(
|
||||||
),
|
),
|
||||||
node_search(
|
node_search(
|
||||||
driver,
|
driver,
|
||||||
embedder,
|
|
||||||
query,
|
query,
|
||||||
|
query_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.node_config,
|
config.node_config,
|
||||||
center_node_uuid,
|
center_node_uuid,
|
||||||
|
|
@ -85,8 +86,8 @@ async def search(
|
||||||
),
|
),
|
||||||
community_search(
|
community_search(
|
||||||
driver,
|
driver,
|
||||||
embedder,
|
|
||||||
query,
|
query,
|
||||||
|
query_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.community_config,
|
config.community_config,
|
||||||
config.limit,
|
config.limit,
|
||||||
|
|
@ -99,27 +100,25 @@ async def search(
|
||||||
communities=communities,
|
communities=communities,
|
||||||
)
|
)
|
||||||
|
|
||||||
end = time()
|
latency = (time() - start) * 1000
|
||||||
|
|
||||||
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
|
logger.debug(f'search returned context for query {query} in {latency} ms')
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def edge_search(
|
async def edge_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
query: str,
|
||||||
query: str,
|
query_vector: list[float],
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: EdgeSearchConfig | None,
|
config: EdgeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query_vector = await embedder.create(input=[query])
|
|
||||||
|
|
||||||
search_results: list[list[EntityEdge]] = list(
|
search_results: list[list[EntityEdge]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -176,19 +175,17 @@ async def edge_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_search(
|
async def node_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
query: str,
|
||||||
query: str,
|
query_vector: list[float],
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: NodeSearchConfig | None,
|
config: NodeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query_vector = await embedder.create(input=[query])
|
|
||||||
|
|
||||||
search_results: list[list[EntityNode]] = list(
|
search_results: list[list[EntityNode]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -230,18 +227,16 @@ async def node_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_search(
|
async def community_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
query: str,
|
||||||
query: str,
|
query_vector: list[float],
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: CommunitySearchConfig | None,
|
config: CommunitySearchConfig | None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query_vector = await embedder.create(input=[query])
|
|
||||||
|
|
||||||
search_results: list[list[CommunityNode]] = list(
|
search_results: list[list[CommunityNode]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
import neo4j
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from neo4j import AsyncDriver, Query
|
from neo4j import AsyncDriver, Query
|
||||||
|
|
||||||
|
|
@ -79,7 +80,9 @@ async def get_mentioned_nodes(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
"""
|
"""
|
||||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||||
|
|
@ -104,7 +107,9 @@ async def get_communities_by_nodes(
|
||||||
driver: AsyncDriver, nodes: list[EntityNode]
|
driver: AsyncDriver, nodes: list[EntityNode]
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
node_uuids = [node.uuid for node in nodes]
|
node_uuids = [node.uuid for node in nodes]
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
||||||
|
|
@ -141,7 +146,9 @@ async def edge_fulltext_search(
|
||||||
cypher_query = Query("""
|
cypher_query = Query("""
|
||||||
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
|
||||||
|
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
||||||
|
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -158,7 +165,9 @@ async def edge_fulltext_search(
|
||||||
ORDER BY score DESC LIMIT $limit
|
ORDER BY score DESC LIMIT $limit
|
||||||
""")
|
""")
|
||||||
|
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
cypher_query,
|
cypher_query,
|
||||||
{
|
{
|
||||||
|
|
@ -188,11 +197,11 @@ async def edge_similarity_search(
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
query = Query("""
|
query = Query("""
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||||
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
||||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
||||||
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
|
|
@ -211,7 +220,9 @@ async def edge_similarity_search(
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""")
|
""")
|
||||||
|
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
query,
|
query,
|
||||||
{
|
{
|
||||||
|
|
@ -241,7 +252,9 @@ async def node_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||||
|
|
@ -276,7 +289,9 @@ async def node_similarity_search(
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
"""
|
"""
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
|
|
@ -318,7 +333,9 @@ async def community_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||||
|
|
@ -353,7 +370,9 @@ async def community_similarity_search(
|
||||||
min_score=DEFAULT_MIN_SCORE,
|
min_score=DEFAULT_MIN_SCORE,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(
|
||||||
|
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
||||||
|
) as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
"""
|
"""
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
|
|
@ -554,32 +573,27 @@ async def node_distance_reranker(
|
||||||
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# filter the center node's own uuid out of node_uuids
|
# filter the center node's own uuid out of node_uuids
|
||||||
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
query = Query("""
|
query = Query("""
|
||||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
|
UNWIND $node_uuids AS node_uuid
|
||||||
RETURN length(p) AS score
|
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
|
||||||
|
RETURN length(p) AS score, node_uuid AS uuid
|
||||||
""")
|
""")
|
||||||
|
|
||||||
path_results = await asyncio.gather(
|
path_results, _, _ = await driver.execute_query(
|
||||||
*[
|
query,
|
||||||
driver.execute_query(
|
node_uuids=filtered_uuids,
|
||||||
query,
|
center_uuid=center_node_uuid,
|
||||||
node_uuid=uuid,
|
database_=DEFAULT_DATABASE,
|
||||||
center_uuid=center_node_uuid,
|
|
||||||
database_=DEFAULT_DATABASE,
|
|
||||||
)
|
|
||||||
for uuid in filtered_uuids
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for uuid, result in zip(filtered_uuids, path_results):
|
for result in path_results:
|
||||||
records = result[0]
|
uuid = result['uuid']
|
||||||
record = records[0] if len(records) > 0 else None
|
score = result['score'] if 'score' in result else float('inf')
|
||||||
distance: float = record['score'] if record is not None else float('inf')
|
scores[uuid] = score
|
||||||
scores[uuid] = distance
|
|
||||||
|
|
||||||
# rerank on shortest distance
|
# rerank on shortest distance
|
||||||
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
@ -596,25 +610,20 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
# Count the Episodic -[:MENTIONS]-> Entity relationships for each node
|
# Count the Episodic -[:MENTIONS]-> Entity relationships for each node
|
||||||
query = Query("""
|
query = Query("""
|
||||||
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
|
UNWIND $node_uuids AS node_uuid
|
||||||
RETURN count(*) AS score
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
||||||
|
RETURN count(*) AS score, n.uuid AS uuid
|
||||||
""")
|
""")
|
||||||
|
|
||||||
result_scores = await asyncio.gather(
|
results, _, _ = await driver.execute_query(
|
||||||
*[
|
query,
|
||||||
driver.execute_query(
|
node_uuids=sorted_uuids,
|
||||||
query,
|
database_=DEFAULT_DATABASE,
|
||||||
node_uuid=uuid,
|
|
||||||
database_=DEFAULT_DATABASE,
|
|
||||||
)
|
|
||||||
for uuid in sorted_uuids
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for uuid, result in zip(sorted_uuids, result_scores):
|
for result in results:
|
||||||
record = result[0][0]
|
scores[result['uuid']] = result['score']
|
||||||
scores[uuid] = record['score']
|
|
||||||
|
|
||||||
# rerank on episode mention count (ascending sort on score)
|
# rerank on episode mention count (ascending sort on score)
|
||||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue