diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index d5032807..86c27adb 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -23,6 +23,7 @@ from typing import Any
 import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 load_dotenv()
@@ -79,16 +80,10 @@ def lucene_sanitize(query: str) -> str:
     return sanitized
 
 
-def normalize_l2(embedding: list[float]):
+def normalize_l2(embedding: list[float]) -> NDArray:
     embedding_array = np.array(embedding)
-    if embedding_array.ndim == 1:
-        norm = np.linalg.norm(embedding_array)
-        if norm == 0:
-            return [0.0] * len(embedding)
-        return (embedding_array / norm).tolist()
-    else:
-        norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
-        return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+    norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
+    return np.where(norm == 0, embedding_array, embedding_array / norm)
 
 
 # Use this instead of asyncio.gather() to bound coroutines
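The rewritten `normalize_l2` drops the 2-D branch: with `axis=0` it now assumes a single 1-D embedding and returns an `NDArray` rather than a Python list, so callers that relied on `.tolist()` output or passed batches will see a behavior change. Note also that the patch imports `NDArray` from the private `numpy._typing` module; the public path is `numpy.typing`. A quick sanity check of the new behavior, as a standalone sketch assuming this patch is applied:

```python
import numpy as np

from graphiti_core.helpers import normalize_l2  # patched version above

unit = normalize_l2([3.0, 4.0])
print(unit)                  # [0.6 0.8]
print(np.linalg.norm(unit))  # 1.0

# A zero vector passes through unchanged via np.where, though NumPy may
# still emit a RuntimeWarning from the masked division.
print(normalize_l2([0.0, 0.0]))  # [0. 0.]
```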
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 7460f1bb..1ed26420 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -50,6 +50,9 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_fulltext_search,
     episode_mentions_reranker,
+    get_embeddings_for_communities,
+    get_embeddings_for_edges,
+    get_embeddings_for_nodes,
     maximal_marginal_relevance,
     node_bfs_search,
     node_distance_reranker,
@@ -209,26 +212,17 @@ async def edge_search(
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
-        await semaphore_gather(
-            *[edge.load_fact_embedding(driver) for result in search_results for edge in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_edges(
+            driver, list(edge_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for edge in result
-        ]
 
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [
             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
@@ -311,30 +305,23 @@ async def node_search(
     if config.reranker == NodeReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
-        await semaphore_gather(
-            *[node.load_name_embedding(driver) for result in search_results for node in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_nodes(
+            driver, list(node_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for node in result
-        ]
+
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+        name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
 
-        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
-
-        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
         reranked_uuids = [
-            summary_to_uuid_map[fact]
-            for fact, score in reranked_summaries
+            name_to_uuid_map[name]
+            for name, score in reranked_node_names
             if score >= reranker_min_score
         ]
     elif config.reranker == NodeReranker.episode_mentions:
@@ -437,25 +424,12 @@ async def community_search(
     if config.reranker == CommunityReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
-        await semaphore_gather(
-            *[
-                community.load_name_embedding(driver)
-                for result in search_results
-                for community in result
-            ]
+        search_result_uuids_and_vectors = await get_embeddings_for_communities(
+            driver, list(community_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (
-                community.uuid,
-                community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
-            )
-            for result in search_results
-            for community in result
-        ]
+
         reranked_uuids = maximal_marginal_relevance(
-            query_vector,
-            search_result_uuids_and_vectors,
-            config.mmr_lambda,
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
         )
     elif config.reranker == CommunityReranker.cross_encoder:
         name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
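All three rerankers now share one shape: a single batched `get_embeddings_for_*` query replaces the per-result `load_*_embedding` round-trips, and `reranker_min_score` filtering moves inside `maximal_marginal_relevance`, replacing the preliminary RRF pass the cross-encoder branches used to run. A hypothetical end-to-end sketch of exercising the MMR path (connection details are placeholders, and `COMBINED_HYBRID_SEARCH_MMR` is assumed to be the recipe name exported by `graphiti_core.search.search_config_recipes`):

```python
import asyncio

from graphiti_core import Graphiti
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_MMR


async def main() -> None:
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    try:
        # With an MMR recipe selected, edge/node/community reranking flows
        # through the batched get_embeddings_for_* helpers added in this diff.
        results = await graphiti.search_(
            query='Who is the user?', config=COMBINED_HYBRID_SEARCH_MMR
        )
        print(results)
    finally:
        await graphiti.close()


asyncio.run(main())
```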
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 8619c6c8..ca24c903 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -21,6 +21,7 @@ from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
@@ -336,10 +337,10 @@ async def node_fulltext_search(
     query = (
         """
-    CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-    YIELD node AS n, score
-    WHERE n:Entity
-    """
+            CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+            YIELD node AS n, score
+            WHERE n:Entity
+            """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
@@ -899,6 +900,7 @@ async def node_distance_reranker(
         node_uuids=filtered_uuids,
         center_uuid=center_node_uuid,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in path_results:
@@ -939,6 +941,7 @@ async def episode_mentions_reranker(
         query,
         node_uuids=sorted_uuids,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in results:
@@ -952,15 +955,116 @@
 def maximal_marginal_relevance(
     query_vector: list[float],
-    candidates: list[tuple[str, list[float]]],
+    candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
-):
-    candidates_with_mmr: list[tuple[str, float]] = []
-    for candidate in candidates:
-        max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
-        mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
-        candidates_with_mmr.append((candidate[0], mmr))
+    min_score: float = -2.0,
+) -> list[str]:
+    start = time()
+    query_array = np.array(query_vector)
+    candidate_arrays: dict[str, NDArray] = {}
+    for uuid, embedding in candidates.items():
+        candidate_arrays[uuid] = normalize_l2(embedding)
 
-    candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
+    uuids: list[str] = list(candidate_arrays.keys())
 
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    similarity_matrix = np.zeros((len(uuids), len(uuids)))
+
+    for i, uuid_1 in enumerate(uuids):
+        for j, uuid_2 in enumerate(uuids[:i]):
+            u = candidate_arrays[uuid_1]
+            v = candidate_arrays[uuid_2]
+            similarity = np.dot(u, v)
+
+            similarity_matrix[i, j] = similarity
+            similarity_matrix[j, i] = similarity
+
+    mmr_scores: dict[str, float] = {}
+    for i, uuid in enumerate(uuids):
+        max_sim = np.max(similarity_matrix[i, :])
+        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
+        mmr_scores[uuid] = mmr
+
+    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
+
+    end = time()
+    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
+
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
+
+
+async def get_embeddings_for_nodes(
+    driver: AsyncDriver, nodes: list[EntityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)
+        WHERE n.uuid IN $node_uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_communities(
+    driver: AsyncDriver, communities: list[CommunityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (c:Community)
+        WHERE c.uuid IN $community_uuids
+        RETURN DISTINCT
+            c.uuid AS uuid,
+            c.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        community_uuids=[community.uuid for community in communities],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_edges(
+    driver: AsyncDriver, edges: list[EntityEdge]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+        WHERE e.uuid IN $edge_uuids
+        RETURN DISTINCT
+            e.uuid AS uuid,
+            e.fact_embedding AS fact_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        edge_uuids=[edge.uuid for edge in edges],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('fact_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
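The new `maximal_marginal_relevance` takes candidates as a `dict[str, list[float]]` (deduplicating by uuid), precomputes the pairwise similarity matrix once over normalized vectors, and scores every candidate in one pass as `mmr_lambda * sim(query, c) + (mmr_lambda - 1) * max_sim`, algebraically the same penalty as the old `- (1 - mmr_lambda) * max_sim` form. Unlike classic greedy MMR, all candidates are scored against the full set at once. A toy check with 2-D stand-in vectors (a sketch assuming this patch is applied; real embeddings are far higher-dimensional):

```python
from graphiti_core.search.search_utils import maximal_marginal_relevance

# Two near-duplicate candidates along x, one orthogonal candidate along y.
candidates = {
    'edge-a': [1.0, 0.0],
    'edge-b': [0.98, 0.02],
    'edge-c': [0.0, 1.0],
}
query = [1.0, 0.2]

# With an equal relevance/diversity trade-off (lambda = 0.5), each duplicate
# is penalized by its high similarity to the other, so the orthogonal
# candidate ranks first despite its lower raw relevance.
print(maximal_marginal_relevance(query, candidates, 0.5))
# ['edge-c', 'edge-b', 'edge-a']
```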
diff --git a/poetry.lock b/poetry.lock
index f7df89f8..7f1cd5ed 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -333,7 +333,7 @@ description = "Timeout context manager for asyncio programs"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
     {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
@@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)"
 optional = false
 python-versions = ">=3.7"
 groups = ["main", "dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
     {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -2668,7 +2668,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
 optional = false
 python-versions = ">=3.9"
 groups = ["dev"]
-markers = "platform_python_implementation != \"PyPy\""
 files = [
     {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
     {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@@ -4500,7 +4499,7 @@ description = "A lil' TOML parser"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
     {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
diff --git a/pyproject.toml b/pyproject.toml
index 788f9631..ac73adfd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.11.6pre7"
+version = "0.11.6pre8"
 authors = [
     { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
     { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 6739bf06..a0ab2877 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -65,9 +65,7 @@ async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
 
-    results = await graphiti.search_(
-        query='Who is the user?',
-    )
+    results = await graphiti.search_(query='Who is the user?')
 
     pretty_results = search_results_to_context_string(results)