MMR optimizations (#481)

* update MMR calculations

* update search

* fixes and updates

* mypy
Preston Rasmussen, 2025-05-12 22:30:23 -04:00 (committed via GitHub)
parent 4198483993
commit 9baa9b7b8a
6 changed files with 146 additions and 76 deletions


@@ -23,6 +23,7 @@ from typing import Any
 import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
+from numpy._typing import NDArray
 from typing_extensions import LiteralString

 load_dotenv()
@@ -79,16 +80,10 @@ def lucene_sanitize(query: str) -> str:
     return sanitized


-def normalize_l2(embedding: list[float]):
+def normalize_l2(embedding: list[float]) -> NDArray:
     embedding_array = np.array(embedding)
-    if embedding_array.ndim == 1:
-        norm = np.linalg.norm(embedding_array)
-        if norm == 0:
-            return [0.0] * len(embedding)
-        return (embedding_array / norm).tolist()
-    else:
-        norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
-        return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+    norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
+    return np.where(norm == 0, embedding_array, embedding_array / norm)


 # Use this instead of asyncio.gather() to bound coroutines
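The helper is now 1-D only: the old 2-D branch is gone, the return type is a NumPy array instead of a Python list, and a zero vector is passed through unchanged rather than rebuilt as [0.0] * len(embedding). A minimal standalone sketch of the new behavior:

import numpy as np

def normalize_l2(embedding):
    # Mirrors the new helper above: 2-norm along axis 0 of a 1-D vector.
    embedding_array = np.array(embedding)
    norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
    # np.where keeps the original components wherever the norm is 0,
    # so a zero vector comes back as all zeros.
    return np.where(norm == 0, embedding_array, embedding_array / norm)

print(normalize_l2([3.0, 4.0]))  # [0.6 0.8] -- unit length
print(normalize_l2([0.0, 0.0]))  # [0. 0.]  -- zero vector unchanged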


@@ -50,6 +50,9 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_fulltext_search,
     episode_mentions_reranker,
+    get_embeddings_for_communities,
+    get_embeddings_for_edges,
+    get_embeddings_for_nodes,
     maximal_marginal_relevance,
     node_bfs_search,
     node_distance_reranker,
@@ -209,26 +212,17 @@ async def edge_search(
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
-        await semaphore_gather(
-            *[edge.load_fact_embedding(driver) for result in search_results for edge in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_edges(
+            driver, list(edge_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for edge in result
-        ]
-
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [
             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
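The same restructuring repeats for nodes and communities below: instead of fanning out one load-embedding coroutine per object under semaphore_gather, a single batched read returns a uuid-to-embedding dict that feeds the reranker directly. A minimal sketch of the resulting edge flow; driver, query_vector, edge_uuid_map, and the function name here are placeholders for the surrounding function's state:

from graphiti_core.search.search_utils import (
    get_embeddings_for_edges,
    maximal_marginal_relevance,
)

async def rerank_edges_with_mmr(driver, query_vector, edge_uuid_map, mmr_lambda, min_score):
    # One query fetches every fact embedding as {uuid: embedding}.
    uuids_and_vectors = await get_embeddings_for_edges(driver, list(edge_uuid_map.values()))
    # The reranker consumes the dict and returns uuids sorted by MMR score.
    return maximal_marginal_relevance(query_vector, uuids_and_vectors, mmr_lambda, min_score)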
@@ -311,30 +305,23 @@ async def node_search(
     if config.reranker == NodeReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
-        await semaphore_gather(
-            *[node.load_name_embedding(driver) for result in search_results for node in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_nodes(
+            driver, list(node_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for node in result
-        ]
-
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
-        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
+        reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
         reranked_uuids = [
-            summary_to_uuid_map[fact]
-            for fact, score in reranked_summaries
+            name_to_uuid_map[name]
+            for name, score in reranked_node_names
             if score >= reranker_min_score
         ]
     elif config.reranker == NodeReranker.episode_mentions:
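Aside from the batching, the cross-encoder branch changes behavior here: the preliminary RRF pass is gone (candidates come straight from node_uuid_map), and nodes are now ranked by name rather than by summary.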
@@ -437,25 +424,12 @@ async def community_search(
     if config.reranker == CommunityReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
-        await semaphore_gather(
-            *[
-                community.load_name_embedding(driver)
-                for result in search_results
-                for community in result
-            ]
+        search_result_uuids_and_vectors = await get_embeddings_for_communities(
+            driver, list(community_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (
-                community.uuid,
-                community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
-            )
-            for result in search_results
-            for community in result
-        ]
-
         reranked_uuids = maximal_marginal_relevance(
-            query_vector,
-            search_result_uuids_and_vectors,
-            config.mmr_lambda,
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
         )
     elif config.reranker == CommunityReranker.cross_encoder:
         name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}


@@ -21,6 +21,7 @@ from typing import Any
 import numpy as np
 from neo4j import AsyncDriver, Query
+from numpy._typing import NDArray
 from typing_extensions import LiteralString

 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
@@ -336,10 +337,10 @@ async def node_fulltext_search(
     query = (
         """
-    CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-    YIELD node AS n, score
-    WHERE n:Entity
-    """
+        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+        YIELD node AS n, score
+        WHERE n:Entity
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
@@ -899,6 +900,7 @@ async def node_distance_reranker(
             node_uuids=filtered_uuids,
             center_uuid=center_node_uuid,
             database_=DEFAULT_DATABASE,
+            routing_='r',
         )

     for result in path_results:
@@ -939,6 +941,7 @@ async def episode_mentions_reranker(
         query,
         node_uuids=sorted_uuids,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )

     for result in results:
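The routing_='r' argument added here (and used in the new helpers below) marks the query as read-only, so in a clustered Neo4j deployment the driver can route it to a read replica instead of sending every reranker query to the leader.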
@@ -952,15 +955,116 @@ async def episode_mentions_reranker(
 def maximal_marginal_relevance(
     query_vector: list[float],
-    candidates: list[tuple[str, list[float]]],
+    candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
-):
-    candidates_with_mmr: list[tuple[str, float]] = []
-    for candidate in candidates:
-        max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
-        mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
-        candidates_with_mmr.append((candidate[0], mmr))
-
-    candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
-
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    min_score: float = -2.0,
+) -> list[str]:
+    start = time()
+
+    query_array = np.array(query_vector)
+    candidate_arrays: dict[str, NDArray] = {}
+    for uuid, embedding in candidates.items():
+        candidate_arrays[uuid] = normalize_l2(embedding)
+
+    uuids: list[str] = list(candidate_arrays.keys())
+
+    similarity_matrix = np.zeros((len(uuids), len(uuids)))
+
+    for i, uuid_1 in enumerate(uuids):
+        for j, uuid_2 in enumerate(uuids[:i]):
+            u = candidate_arrays[uuid_1]
+            v = candidate_arrays[uuid_2]
+            similarity = np.dot(u, v)
+
+            similarity_matrix[i, j] = similarity
+            similarity_matrix[j, i] = similarity
+
+    mmr_scores: dict[str, float] = {}
+    for i, uuid in enumerate(uuids):
+        max_sim = np.max(similarity_matrix[i, :])
+        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
+        mmr_scores[uuid] = mmr
+
+    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
+
+    end = time()
+    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
+
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
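The rewrite normalizes each embedding once and precomputes the pairwise similarity matrix, instead of re-normalizing inside the inner max() on every pass, and it scores each candidate in one shot: mmr_lambda * sim(query, c) + (mmr_lambda - 1) * max_sim, i.e. the classic lambda * relevance - (1 - lambda) * redundancy with the subtraction folded into the second coefficient. Assuming unit-norm query embeddings, scores stay within [-1, 1], so the default min_score of -2.0 filters nothing. A toy walkthrough with hypothetical 2-D vectors:

import numpy as np

# Hypothetical 2-D candidates; real embeddings are of course much larger.
candidates = {
    'a': [1.0, 0.0],  # points the same way as the query
    'b': [1.0, 1.0],  # halfway between 'a' and 'c'
    'c': [0.0, 1.0],  # orthogonal to the query
}
query = np.array([1.0, 0.0])
mmr_lambda = 0.5

uuids = list(candidates)
unit = {u: np.array(v) / np.linalg.norm(v) for u, v in candidates.items()}

# Pairwise similarity matrix; the diagonal stays 0, as in the diff above.
sim = np.zeros((len(uuids), len(uuids)))
for i, u in enumerate(uuids):
    for j, v in enumerate(uuids[:i]):
        sim[i, j] = sim[j, i] = float(np.dot(unit[u], unit[v]))

for i, u in enumerate(uuids):
    max_sim = np.max(sim[i, :])  # similarity to the nearest other candidate
    score = mmr_lambda * np.dot(query, unit[u]) + (mmr_lambda - 1) * max_sim
    print(u, round(float(score), 2))

# a 0.15   relevant, mildly penalized for overlapping with 'b'
# b 0.0    relevance 0.71 exactly cancelled by redundancy 0.71
# c -0.35  irrelevant to the query and still close to 'b'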
+
+
+async def get_embeddings_for_nodes(
+    driver: AsyncDriver, nodes: list[EntityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)
+        WHERE n.uuid IN $node_uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_communities(
+    driver: AsyncDriver, communities: list[CommunityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (c:Community)
+        WHERE c.uuid IN $community_uuids
+        RETURN DISTINCT
+            c.uuid AS uuid,
+            c.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        community_uuids=[community.uuid for community in communities],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_edges(
+    driver: AsyncDriver, edges: list[EntityEdge]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+        WHERE e.uuid IN $edge_uuids
+        RETURN DISTINCT
+            e.uuid AS uuid,
+            e.fact_embedding AS fact_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        edge_uuids=[edge.uuid for edge in edges],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('fact_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
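Because the edge query matches (n:Entity)-[e:RELATES_TO]-(m:Entity) with no direction, each relationship is found twice (once per orientation of n and m); RETURN DISTINCT collapses those duplicates to a single uuid/embedding row.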

poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.

 [[package]]
 name = "aiohappyeyeballs"
@@ -333,7 +333,7 @@ description = "Timeout context manager for asyncio programs"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
     {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
@@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)"
 optional = false
 python-versions = ">=3.7"
 groups = ["main", "dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
     {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -2668,7 +2668,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
 optional = false
 python-versions = ">=3.9"
 groups = ["dev"]
-markers = "platform_python_implementation != \"PyPy\""
 files = [
     {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
     {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@@ -4500,7 +4499,7 @@ description = "A lil' TOML parser"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
     {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},


@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.11.6pre7"
+version = "0.11.6pre8"
 authors = [
     { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
     { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },


@@ -65,9 +65,7 @@ async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)

-    results = await graphiti.search_(
-        query='Who is the user?',
-    )
+    results = await graphiti.search_(query='Who is the user?')

     pretty_results = search_results_to_context_string(results)