Mmr optimizations (#481)
* update mmr calculations * update search * fixes and updates * mypy
This commit is contained in:
parent
4198483993
commit
9baa9b7b8a
6 changed files with 146 additions and 76 deletions
|
|
@ -23,6 +23,7 @@ from typing import Any
|
|||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from neo4j import time as neo4j_time
|
||||
from numpy._typing import NDArray
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
load_dotenv()
|
||||
|
|
@ -79,16 +80,10 @@ def lucene_sanitize(query: str) -> str:
|
|||
return sanitized
|
||||
|
||||
|
||||
def normalize_l2(embedding: list[float]):
|
||||
def normalize_l2(embedding: list[float]) -> NDArray:
|
||||
embedding_array = np.array(embedding)
|
||||
if embedding_array.ndim == 1:
|
||||
norm = np.linalg.norm(embedding_array)
|
||||
if norm == 0:
|
||||
return [0.0] * len(embedding)
|
||||
return (embedding_array / norm).tolist()
|
||||
else:
|
||||
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
|
||||
return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
|
||||
norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
|
||||
return np.where(norm == 0, embedding_array, embedding_array / norm)
|
||||
|
||||
|
||||
# Use this instead of asyncio.gather() to bound coroutines
|
||||
|
|
|
|||
|
|
@ -50,6 +50,9 @@ from graphiti_core.search.search_utils import (
|
|||
edge_similarity_search,
|
||||
episode_fulltext_search,
|
||||
episode_mentions_reranker,
|
||||
get_embeddings_for_communities,
|
||||
get_embeddings_for_edges,
|
||||
get_embeddings_for_nodes,
|
||||
maximal_marginal_relevance,
|
||||
node_bfs_search,
|
||||
node_distance_reranker,
|
||||
|
|
@ -209,26 +212,17 @@ async def edge_search(
|
|||
|
||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
elif config.reranker == EdgeReranker.mmr:
|
||||
await semaphore_gather(
|
||||
*[edge.load_fact_embedding(driver) for result in search_results for edge in result]
|
||||
search_result_uuids_and_vectors = await get_embeddings_for_edges(
|
||||
driver, list(edge_uuid_map.values())
|
||||
)
|
||||
search_result_uuids_and_vectors = [
|
||||
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
|
||||
for result in search_results
|
||||
for edge in result
|
||||
]
|
||||
reranked_uuids = maximal_marginal_relevance(
|
||||
query_vector,
|
||||
search_result_uuids_and_vectors,
|
||||
config.mmr_lambda,
|
||||
reranker_min_score,
|
||||
)
|
||||
elif config.reranker == EdgeReranker.cross_encoder:
|
||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||
|
||||
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
||||
|
||||
fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
|
||||
fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
|
||||
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
|
||||
reranked_uuids = [
|
||||
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
||||
|
|
@ -311,30 +305,23 @@ async def node_search(
|
|||
if config.reranker == NodeReranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
elif config.reranker == NodeReranker.mmr:
|
||||
await semaphore_gather(
|
||||
*[node.load_name_embedding(driver) for result in search_results for node in result]
|
||||
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
|
||||
driver, list(node_uuid_map.values())
|
||||
)
|
||||
search_result_uuids_and_vectors = [
|
||||
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
|
||||
for result in search_results
|
||||
for node in result
|
||||
]
|
||||
|
||||
reranked_uuids = maximal_marginal_relevance(
|
||||
query_vector,
|
||||
search_result_uuids_and_vectors,
|
||||
config.mmr_lambda,
|
||||
reranker_min_score,
|
||||
)
|
||||
elif config.reranker == NodeReranker.cross_encoder:
|
||||
# use rrf as a preliminary reranker
|
||||
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
||||
name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
|
||||
|
||||
summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
|
||||
|
||||
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
||||
reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
|
||||
reranked_uuids = [
|
||||
summary_to_uuid_map[fact]
|
||||
for fact, score in reranked_summaries
|
||||
name_to_uuid_map[name]
|
||||
for name, score in reranked_node_names
|
||||
if score >= reranker_min_score
|
||||
]
|
||||
elif config.reranker == NodeReranker.episode_mentions:
|
||||
|
|
@ -437,25 +424,12 @@ async def community_search(
|
|||
if config.reranker == CommunityReranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
elif config.reranker == CommunityReranker.mmr:
|
||||
await semaphore_gather(
|
||||
*[
|
||||
community.load_name_embedding(driver)
|
||||
for result in search_results
|
||||
for community in result
|
||||
]
|
||||
search_result_uuids_and_vectors = await get_embeddings_for_communities(
|
||||
driver, list(community_uuid_map.values())
|
||||
)
|
||||
search_result_uuids_and_vectors = [
|
||||
(
|
||||
community.uuid,
|
||||
community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
|
||||
)
|
||||
for result in search_results
|
||||
for community in result
|
||||
]
|
||||
|
||||
reranked_uuids = maximal_marginal_relevance(
|
||||
query_vector,
|
||||
search_result_uuids_and_vectors,
|
||||
config.mmr_lambda,
|
||||
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
|
||||
)
|
||||
elif config.reranker == CommunityReranker.cross_encoder:
|
||||
name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from typing import Any
|
|||
|
||||
import numpy as np
|
||||
from neo4j import AsyncDriver, Query
|
||||
from numpy._typing import NDArray
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||
|
|
@ -336,10 +337,10 @@ async def node_fulltext_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
+ filter_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
|
|
@ -899,6 +900,7 @@ async def node_distance_reranker(
|
|||
node_uuids=filtered_uuids,
|
||||
center_uuid=center_node_uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
for result in path_results:
|
||||
|
|
@ -939,6 +941,7 @@ async def episode_mentions_reranker(
|
|||
query,
|
||||
node_uuids=sorted_uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
for result in results:
|
||||
|
|
@ -952,15 +955,116 @@ async def episode_mentions_reranker(
|
|||
|
||||
def maximal_marginal_relevance(
|
||||
query_vector: list[float],
|
||||
candidates: list[tuple[str, list[float]]],
|
||||
candidates: dict[str, list[float]],
|
||||
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
||||
):
|
||||
candidates_with_mmr: list[tuple[str, float]] = []
|
||||
for candidate in candidates:
|
||||
max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
|
||||
mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
|
||||
candidates_with_mmr.append((candidate[0], mmr))
|
||||
min_score: float = -2.0,
|
||||
) -> list[str]:
|
||||
start = time()
|
||||
query_array = np.array(query_vector)
|
||||
candidate_arrays: dict[str, NDArray] = {}
|
||||
for uuid, embedding in candidates.items():
|
||||
candidate_arrays[uuid] = normalize_l2(embedding)
|
||||
|
||||
candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
|
||||
uuids: list[str] = list(candidate_arrays.keys())
|
||||
|
||||
return list(set([candidate[0] for candidate in candidates_with_mmr]))
|
||||
similarity_matrix = np.zeros((len(uuids), len(uuids)))
|
||||
|
||||
for i, uuid_1 in enumerate(uuids):
|
||||
for j, uuid_2 in enumerate(uuids[:i]):
|
||||
u = candidate_arrays[uuid_1]
|
||||
v = candidate_arrays[uuid_2]
|
||||
similarity = np.dot(u, v)
|
||||
|
||||
similarity_matrix[i, j] = similarity
|
||||
similarity_matrix[j, i] = similarity
|
||||
|
||||
mmr_scores: dict[str, float] = {}
|
||||
for i, uuid in enumerate(uuids):
|
||||
max_sim = np.max(similarity_matrix[i, :])
|
||||
mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
|
||||
mmr_scores[uuid] = mmr
|
||||
|
||||
uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
|
||||
|
||||
end = time()
|
||||
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
|
||||
|
||||
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
|
||||
|
||||
|
||||
async def get_embeddings_for_nodes(
|
||||
driver: AsyncDriver, nodes: list[EntityNode]
|
||||
) -> dict[str, list[float]]:
|
||||
query: LiteralString = """MATCH (n:Entity)
|
||||
WHERE n.uuid IN $node_uuids
|
||||
RETURN DISTINCT
|
||||
n.uuid AS uuid,
|
||||
n.name_embedding AS name_embedding
|
||||
"""
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
|
||||
)
|
||||
|
||||
embeddings_dict: dict[str, list[float]] = {}
|
||||
for result in results:
|
||||
uuid: str = result.get('uuid')
|
||||
embedding: list[float] = result.get('name_embedding')
|
||||
if uuid is not None and embedding is not None:
|
||||
embeddings_dict[uuid] = embedding
|
||||
|
||||
return embeddings_dict
|
||||
|
||||
|
||||
async def get_embeddings_for_communities(
|
||||
driver: AsyncDriver, communities: list[CommunityNode]
|
||||
) -> dict[str, list[float]]:
|
||||
query: LiteralString = """MATCH (c:Community)
|
||||
WHERE c.uuid IN $community_uuids
|
||||
RETURN DISTINCT
|
||||
c.uuid AS uuid,
|
||||
c.name_embedding AS name_embedding
|
||||
"""
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
community_uuids=[community.uuid for community in communities],
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
embeddings_dict: dict[str, list[float]] = {}
|
||||
for result in results:
|
||||
uuid: str = result.get('uuid')
|
||||
embedding: list[float] = result.get('name_embedding')
|
||||
if uuid is not None and embedding is not None:
|
||||
embeddings_dict[uuid] = embedding
|
||||
|
||||
return embeddings_dict
|
||||
|
||||
|
||||
async def get_embeddings_for_edges(
|
||||
driver: AsyncDriver, edges: list[EntityEdge]
|
||||
) -> dict[str, list[float]]:
|
||||
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
||||
WHERE e.uuid IN $edge_uuids
|
||||
RETURN DISTINCT
|
||||
e.uuid AS uuid,
|
||||
e.fact_embedding AS fact_embedding
|
||||
"""
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
edge_uuids=[edge.uuid for edge in edges],
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
embeddings_dict: dict[str, list[float]] = {}
|
||||
for result in results:
|
||||
uuid: str = result.get('uuid')
|
||||
embedding: list[float] = result.get('fact_embedding')
|
||||
if uuid is not None and embedding is not None:
|
||||
embeddings_dict[uuid] = embedding
|
||||
|
||||
return embeddings_dict
|
||||
|
|
|
|||
9
poetry.lock
generated
9
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
|
|
@ -333,7 +333,7 @@ description = "Timeout context manager for asyncio programs"
|
|||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
|
||||
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
||||
|
|
@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)"
|
|||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
|
||||
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
|
||||
|
|
@ -2668,7 +2668,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
|
|||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
markers = "platform_python_implementation != \"PyPy\""
|
||||
files = [
|
||||
{file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
|
||||
{file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
|
||||
|
|
@ -4500,7 +4499,7 @@ description = "A lil' TOML parser"
|
|||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.11.6pre7"
|
||||
version = "0.11.6pre8"
|
||||
authors = [
|
||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||
|
|
|
|||
|
|
@ -65,9 +65,7 @@ async def test_graphiti_init():
|
|||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
|
||||
results = await graphiti.search_(
|
||||
query='Who is the user?',
|
||||
)
|
||||
results = await graphiti.search_(query='Who is the user?')
|
||||
|
||||
pretty_results = search_results_to_context_string(results)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue