From d7c20c1f59de9edb345e9a8523efe0ec8e6a949b Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Mon, 16 Sep 2024 14:03:05 -0400 Subject: [PATCH] Search refactor + Community search (#111) * WIP * WIP * WIP * community search * WIP * WIP * integration tested * tests * tests * mypy * mypy * format --- examples/podcast/podcast_runner.py | 2 +- graphiti_core/errors.py | 25 + graphiti_core/graphiti.py | 58 ++- graphiti_core/helpers.py | 16 + graphiti_core/llm_client/errors.py | 17 + graphiti_core/llm_client/utils.py | 18 +- graphiti_core/prompts/extract_edge_dates.py | 16 + graphiti_core/search/search.py | 284 +++++++---- graphiti_core/search/search_config.py | 81 +++ graphiti_core/search/search_config_recipes.py | 84 ++++ graphiti_core/search/search_utils.py | 466 ++++++++++-------- tests/test_graphiti_int.py | 12 + tests/utils/search/search_utils_test.py | 30 +- 13 files changed, 780 insertions(+), 329 deletions(-) create mode 100644 graphiti_core/search/search_config.py create mode 100644 graphiti_core/search/search_config_recipes.py diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 792d5720..37ec9345 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True): messages = parse_podcast_messages() if not use_bulk: - for i, message in enumerate(messages[3:130]): + for i, message in enumerate(messages[3:20]): await client.add_episode( name=f'Message {i}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}', diff --git a/graphiti_core/errors.py b/graphiti_core/errors.py index e6da50d0..84737419 100644 --- a/graphiti_core/errors.py +++ b/graphiti_core/errors.py @@ -1,3 +1,20 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + class GraphitiError(Exception): """Base exception class for Graphiti Core.""" @@ -16,3 +33,11 @@ class NodeNotFoundError(GraphitiError): def __init__(self, uuid: str): self.message = f'node {uuid} not found' super().__init__(self.message) + + +class SearchRerankerError(GraphitiError): + """Raised when a node is not found.""" + + def __init__(self, text: str): + self.message = text + super().__init__(self.message) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 4536f75e..f388dccc 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.llm_client import LLMClient, OpenAIClient -from graphiti_core.llm_client.utils import generate_embedding from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode -from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search +from graphiti_core.search.search import SearchConfig, search +from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults +from graphiti_core.search.search_config_recipes import ( + EDGE_HYBRID_SEARCH_NODE_DISTANCE, + EDGE_HYBRID_SEARCH_RRF, + NODE_HYBRID_SEARCH_NODE_DISTANCE, + NODE_HYBRID_SEARCH_RRF, +) from graphiti_core.search.search_utils import ( RELEVANT_SCHEMA_LIMIT, get_relevant_edges, get_relevant_nodes, - hybrid_node_search, ) from graphiti_core.utils import ( build_episodic_edges, @@ -548,7 +553,7 @@ class Graphiti: query: str, center_node_uuid: str | None = None, 
group_ids: list[str | None] | None = None, - num_results=10, + num_results=DEFAULT_SEARCH_LIMIT, ): """ Perform a hybrid search on the knowledge graph. @@ -564,7 +569,7 @@ class Graphiti: Facts will be reranked based on proximity to this node group_ids : list[str | None] | None, optional The graph partitions to return data from. - num_results : int, optional + limit : int, optional The maximum number of results to return. Defaults to 10. Returns @@ -581,21 +586,17 @@ class Graphiti: The search is performed using the current date and time as the reference point for temporal relevance. """ - reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance - search_config = SearchConfig( - num_episodes=0, - num_edges=num_results, - num_nodes=0, - group_ids=group_ids, - search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity], - reranker=reranker, + search_config = ( + EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE ) + search_config.limit = num_results + edges = ( - await hybrid_search( + await search( self.driver, self.llm_client.get_embedder(), query, - datetime.now(), + group_ids, search_config, center_node_uuid, ) @@ -606,19 +607,20 @@ class Graphiti: async def _search( self, query: str, - timestamp: datetime, config: SearchConfig, + group_ids: list[str | None] | None = None, center_node_uuid: str | None = None, - ): - return await hybrid_search( - self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid + ) -> SearchResults: + return await search( + self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid ) async def get_nodes_by_query( self, query: str, + center_node_uuid: str | None = None, group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + limit: int = DEFAULT_SEARCH_LIMIT, ) -> list[EntityNode]: """ Retrieve nodes from the graph database based on a text query. 
@@ -629,7 +631,9 @@ class Graphiti: Parameters ---------- query : str - The text query to search for in the graph. + The text query to search for in the graph + center_node_uuid: str, optional + Facts will be reranked based on proximity to this node. group_ids : list[str | None] | None, optional The graph partitions to return data from. limit : int | None, optional @@ -655,8 +659,12 @@ class Graphiti: If not specified, a default limit (defined in the search functions) will be used. """ embedder = self.llm_client.get_embedder() - query_embedding = await generate_embedding(embedder, query) - relevant_nodes = await hybrid_node_search( - [query], [query_embedding], self.driver, group_ids, limit + search_config = ( + NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE ) - return relevant_nodes + search_config.limit = limit + + nodes = ( + await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid) + ).nodes + return nodes diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 6233d274..9471058e 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + from datetime import datetime from neo4j import time as neo4j_time diff --git a/graphiti_core/llm_client/errors.py b/graphiti_core/llm_client/errors.py index 13f9b479..0c0f5dd1 100644 --- a/graphiti_core/llm_client/errors.py +++ b/graphiti_core/llm_client/errors.py @@ -1,3 +1,20 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + class RateLimitError(Exception): """Exception raised when the rate limit is exceeded.""" diff --git a/graphiti_core/llm_client/utils.py b/graphiti_core/llm_client/utils.py index d8740137..d98b49b9 100644 --- a/graphiti_core/llm_client/utils.py +++ b/graphiti_core/llm_client/utils.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + import logging import typing from time import time @@ -17,6 +33,6 @@ async def generate_embedding( embedding = embedding[:EMBEDDING_DIM] end = time() - logger.debug(f'embedded text of length {len(text)} in {end-start} ms') + logger.debug(f'embedded text of length {len(text)} in {end - start} ms') return embedding diff --git a/graphiti_core/prompts/extract_edge_dates.py b/graphiti_core/prompts/extract_edge_dates.py index ef6ffc12..4d6ab851 100644 --- a/graphiti_core/prompts/extract_edge_dates.py +++ b/graphiti_core/prompts/extract_edge_dates.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + from typing import Any, Protocol, TypedDict from .models import Message, PromptFunction, PromptVersion diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 3e4c59f1..210586e5 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -15,131 +15,227 @@ limitations under the License. 
""" import logging -from datetime import datetime -from enum import Enum from time import time from neo4j import AsyncDriver -from pydantic import BaseModel, Field from graphiti_core.edges import EntityEdge +from graphiti_core.errors import SearchRerankerError from graphiti_core.llm_client.config import EMBEDDING_DIM -from graphiti_core.nodes import EntityNode, EpisodicNode +from graphiti_core.nodes import CommunityNode, EntityNode +from graphiti_core.search.search_config import ( + DEFAULT_SEARCH_LIMIT, + CommunityReranker, + CommunitySearchConfig, + CommunitySearchMethod, + EdgeReranker, + EdgeSearchConfig, + EdgeSearchMethod, + NodeReranker, + NodeSearchConfig, + NodeSearchMethod, + SearchConfig, + SearchResults, +) from graphiti_core.search.search_utils import ( + community_fulltext_search, + community_similarity_search, edge_fulltext_search, edge_similarity_search, - get_mentioned_nodes, node_distance_reranker, + node_fulltext_search, + node_similarity_search, rrf, ) -from graphiti_core.utils import retrieve_episodes -from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN logger = logging.getLogger(__name__) -class SearchMethod(Enum): - cosine_similarity = 'cosine_similarity' - bm25 = 'bm25' - - -class Reranker(Enum): - rrf = 'reciprocal_rank_fusion' - node_distance = 'node_distance' - - -class SearchConfig(BaseModel): - num_edges: int = Field(default=10) - num_nodes: int = Field(default=10) - num_episodes: int = EPISODE_WINDOW_LEN - group_ids: list[str | None] | None - search_methods: list[SearchMethod] - reranker: Reranker | None - - -class SearchResults(BaseModel): - episodes: list[EpisodicNode] - nodes: list[EntityNode] - edges: list[EntityEdge] - - -async def hybrid_search( +async def search( driver: AsyncDriver, embedder, query: str, - timestamp: datetime, + group_ids: list[str | None] | None, config: SearchConfig, center_node_uuid: str | None = None, ) -> SearchResults: start = time() + query = query.replace('\n', ' ') - 
episodes = [] - nodes = [] - edges = [] - - search_results = [] - - if config.num_episodes > 0: - episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes)) - nodes.extend(await get_mentioned_nodes(driver, episodes)) - - if SearchMethod.bm25 in config.search_methods: - text_search = await edge_fulltext_search( - driver, query, None, None, config.group_ids, 2 * config.num_edges + edges = ( + await edge_search( + driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit ) - search_results.append(text_search) - - if SearchMethod.cosine_similarity in config.search_methods: - query_text = query.replace('\n', ' ') - search_vector = ( - (await embedder.create(input=[query_text], model='text-embedding-3-small')) - .data[0] - .embedding[:EMBEDDING_DIM] + if config.edge_config is not None + else [] + ) + nodes = ( + await node_search( + driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit ) - - similarity_search = await edge_similarity_search( - driver, search_vector, None, None, config.group_ids, 2 * config.num_edges + if config.node_config is not None + else [] + ) + communities = ( + await community_search( + driver, embedder, query, group_ids, config.community_config, config.limit ) - search_results.append(similarity_search) + if config.community_config is not None + else [] + ) - if len(search_results) > 1 and config.reranker is None: - logger.exception('Multiple searches enabled without a reranker') - raise Exception('Multiple searches enabled without a reranker') - - else: - edge_uuid_map = {} - search_result_uuids = [] - - for result in search_results: - result_uuids = [] - for edge in result: - result_uuids.append(edge.uuid) - edge_uuid_map[edge.uuid] = edge - - search_result_uuids.append(result_uuids) - - search_result_uuids = [[edge.uuid for edge in result] for result in search_results] - - reranked_uuids: list[str] = [] - if config.reranker == Reranker.rrf: - reranked_uuids = 
rrf(search_result_uuids) - elif config.reranker == Reranker.node_distance: - if center_node_uuid is None: - logger.exception('No center node provided for Node Distance reranker') - raise Exception('No center node provided for Node Distance reranker') - reranked_uuids = await node_distance_reranker( - driver, search_result_uuids, center_node_uuid - ) - - reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] - edges.extend(reranked_edges) - - context = SearchResults( - episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges] + results = SearchResults( + edges=edges[: config.limit], + nodes=nodes[: config.limit], + communities=communities[: config.limit], ) end = time() logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms') - return context + return results + + +async def edge_search( + driver: AsyncDriver, + embedder, + query: str, + group_ids: list[str | None] | None, + config: EdgeSearchConfig, + center_node_uuid: str | None = None, + limit=DEFAULT_SEARCH_LIMIT, +) -> list[EntityEdge]: + search_results: list[list[EntityEdge]] = [] + + if EdgeSearchMethod.bm25 in config.search_methods: + text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit) + search_results.append(text_search) + + if EdgeSearchMethod.cosine_similarity in config.search_methods: + search_vector = ( + (await embedder.create(input=[query], model='text-embedding-3-small')) + .data[0] + .embedding[:EMBEDDING_DIM] + ) + + similarity_search = await edge_similarity_search( + driver, search_vector, None, None, group_ids, 2 * limit + ) + search_results.append(similarity_search) + + if len(search_results) > 1 and config.reranker is None: + raise SearchRerankerError('Multiple edge searches enabled without a reranker') + + edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result} + + reranked_uuids: list[str] = [] + if config.reranker == EdgeReranker.rrf: + search_result_uuids = 
[[edge.uuid for edge in result] for result in search_results] + + reranked_uuids = rrf(search_result_uuids) + elif config.reranker == EdgeReranker.node_distance: + if center_node_uuid is None: + raise SearchRerankerError('No center node provided for Node Distance reranker') + + source_to_edge_uuid_map = { + edge.source_node_uuid: edge.uuid for result in search_results for edge in result + } + source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results] + + reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid) + + reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids] + + reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] + + return reranked_edges + + +async def node_search( + driver: AsyncDriver, + embedder, + query: str, + group_ids: list[str | None] | None, + config: NodeSearchConfig, + center_node_uuid: str | None = None, + limit=DEFAULT_SEARCH_LIMIT, +) -> list[EntityNode]: + search_results: list[list[EntityNode]] = [] + + if NodeSearchMethod.bm25 in config.search_methods: + text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit) + search_results.append(text_search) + + if NodeSearchMethod.cosine_similarity in config.search_methods: + search_vector = ( + (await embedder.create(input=[query], model='text-embedding-3-small')) + .data[0] + .embedding[:EMBEDDING_DIM] + ) + + similarity_search = await node_similarity_search( + driver, search_vector, group_ids, 2 * limit + ) + search_results.append(similarity_search) + + if len(search_results) > 1 and config.reranker is None: + raise SearchRerankerError('Multiple node searches enabled without a reranker') + + search_result_uuids = [[node.uuid for node in result] for result in search_results] + node_uuid_map = {node.uuid: node for result in search_results for node in result} + + reranked_uuids: list[str] = [] + if config.reranker == NodeReranker.rrf: + reranked_uuids = 
rrf(search_result_uuids) + elif config.reranker == NodeReranker.node_distance: + if center_node_uuid is None: + raise SearchRerankerError('No center node provided for Node Distance reranker') + reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid) + + reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids] + + return reranked_nodes + + +async def community_search( + driver: AsyncDriver, + embedder, + query: str, + group_ids: list[str | None] | None, + config: CommunitySearchConfig, + limit=DEFAULT_SEARCH_LIMIT, +) -> list[CommunityNode]: + search_results: list[list[CommunityNode]] = [] + + if CommunitySearchMethod.bm25 in config.search_methods: + text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit) + search_results.append(text_search) + + if CommunitySearchMethod.cosine_similarity in config.search_methods: + search_vector = ( + (await embedder.create(input=[query], model='text-embedding-3-small')) + .data[0] + .embedding[:EMBEDDING_DIM] + ) + + similarity_search = await community_similarity_search( + driver, search_vector, group_ids, 2 * limit + ) + search_results.append(similarity_search) + + if len(search_results) > 1 and config.reranker is None: + raise SearchRerankerError('Multiple node searches enabled without a reranker') + + search_result_uuids = [[community.uuid for community in result] for result in search_results] + community_uuid_map = { + community.uuid: community for result in search_results for community in result + } + + reranked_uuids: list[str] = [] + if config.reranker == CommunityReranker.rrf: + reranked_uuids = rrf(search_result_uuids) + + reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] + + return reranked_communities diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py new file mode 100644 index 00000000..3bd6b6cb --- /dev/null +++ b/graphiti_core/search/search_config.py @@ -0,0 +1,81 @@ +""" 
+Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from enum import Enum + +from pydantic import BaseModel, Field + +from graphiti_core.edges import EntityEdge +from graphiti_core.nodes import CommunityNode, EntityNode + +DEFAULT_SEARCH_LIMIT = 10 + + +class EdgeSearchMethod(Enum): + cosine_similarity = 'cosine_similarity' + bm25 = 'bm25' + + +class NodeSearchMethod(Enum): + cosine_similarity = 'cosine_similarity' + bm25 = 'bm25' + + +class CommunitySearchMethod(Enum): + cosine_similarity = 'cosine_similarity' + bm25 = 'bm25' + + +class EdgeReranker(Enum): + rrf = 'reciprocal_rank_fusion' + node_distance = 'node_distance' + + +class NodeReranker(Enum): + rrf = 'reciprocal_rank_fusion' + node_distance = 'node_distance' + + +class CommunityReranker(Enum): + rrf = 'reciprocal_rank_fusion' + + +class EdgeSearchConfig(BaseModel): + search_methods: list[EdgeSearchMethod] + reranker: EdgeReranker | None + + +class NodeSearchConfig(BaseModel): + search_methods: list[NodeSearchMethod] + reranker: NodeReranker | None + + +class CommunitySearchConfig(BaseModel): + search_methods: list[CommunitySearchMethod] + reranker: CommunityReranker | None + + +class SearchConfig(BaseModel): + edge_config: EdgeSearchConfig | None = Field(default=None) + node_config: NodeSearchConfig | None = Field(default=None) + community_config: CommunitySearchConfig | None = Field(default=None) + limit: int = Field(default=DEFAULT_SEARCH_LIMIT) + + +class SearchResults(BaseModel): + 
edges: list[EntityEdge] + nodes: list[EntityNode] + communities: list[CommunityNode] diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py new file mode 100644 index 00000000..5aa30198 --- /dev/null +++ b/graphiti_core/search/search_config_recipes.py @@ -0,0 +1,84 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from graphiti_core.search.search_config import ( + CommunityReranker, + CommunitySearchConfig, + CommunitySearchMethod, + EdgeReranker, + EdgeSearchConfig, + EdgeSearchMethod, + NodeReranker, + NodeSearchConfig, + NodeSearchMethod, + SearchConfig, +) + +# Performs a hybrid search with rrf reranking over edges, nodes, and communities +COMBINED_HYBRID_SEARCH_RRF = SearchConfig( + edge_config=EdgeSearchConfig( + search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], + reranker=EdgeReranker.rrf, + ), + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.rrf, + ), + community_config=CommunitySearchConfig( + search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], + reranker=CommunityReranker.rrf, + ), +) + +# performs a hybrid search over edges with rrf reranking +EDGE_HYBRID_SEARCH_RRF = SearchConfig( + edge_config=EdgeSearchConfig( + search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], + reranker=EdgeReranker.rrf, + ) +) + +# 
performs a hybrid search over edges with node distance reranking +EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( + edge_config=EdgeSearchConfig( + search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], + reranker=EdgeReranker.node_distance, + ) +) + +# performs a hybrid search over nodes with rrf reranking +NODE_HYBRID_SEARCH_RRF = SearchConfig( + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.rrf, + ) +) + +# performs a hybrid search over nodes with node distance reranking +NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.node_distance, + ) +) + +# performs a hybrid search over communities with rrf reranking +COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig( + community_config=CommunitySearchConfig( + search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], + reranker=CommunityReranker.rrf, + ) +) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 38f3bd6e..702bd6e5 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + import asyncio import logging import re @@ -7,7 +23,13 @@ from time import time from neo4j import AsyncDriver, Query from graphiti_core.edges import EntityEdge, get_entity_edge_from_record -from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record +from graphiti_core.nodes import ( + CommunityNode, + EntityNode, + EpisodicNode, + get_community_node_from_record, + get_entity_node_from_record, +) logger = logging.getLogger(__name__) @@ -35,181 +57,6 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) return nodes -async def edge_similarity_search( - driver: AsyncDriver, - search_vector: list[float], - source_node_uuid: str | None, - target_node_uuid: str | None, - group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, -) -> list[EntityEdge]: - group_ids = group_ids if group_ids is not None else [None] - # vector similarity search over embedded facts - query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) - WHERE r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC - """) - - if source_node_uuid is None and target_node_uuid is None: - query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) - WHERE r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS 
created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC - """) - elif source_node_uuid is None: - query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) - WHERE r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC - """) - elif target_node_uuid is None: - query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) - WHERE r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC - """) - - records, _, _ = await driver.execute_query( - query, - search_vector=search_vector, - source_uuid=source_node_uuid, - target_uuid=target_node_uuid, - group_ids=group_ids, - limit=limit, - ) - - edges = [get_entity_edge_from_record(record) for record in records] - - return edges - - -async def entity_similarity_search( - search_vector: list[float], - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, -) -> list[EntityNode]: - group_ids = group_ids if 
group_ids is not None else [None] - - # vector similarity search over entity names - records, _, _ = await driver.execute_query( - """ - CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) - YIELD node AS n, score - MATCH (n WHERE n.group_id IN $group_ids) - RETURN - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary - ORDER BY score DESC - """, - search_vector=search_vector, - group_ids=group_ids, - limit=limit, - ) - nodes = [get_entity_node_from_record(record) for record in records] - - return nodes - - -async def entity_fulltext_search( - query: str, - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, -) -> list[EntityNode]: - group_ids = group_ids if group_ids is not None else [None] - - # BM25 search to get top nodes - fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' - records, _, _ = await driver.execute_query( - """ - CALL db.index.fulltext.queryNodes("name_and_summary", $query) - YIELD node AS n, score - MATCH (n WHERE n.group_id in $group_ids) - RETURN - n.uuid AS uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary - ORDER BY score DESC - LIMIT $limit - """, - query=fuzzy_query, - group_ids=group_ids, - limit=limit, - ) - nodes = [get_entity_node_from_record(record) for record in records] - - return nodes - - async def edge_fulltext_search( driver: AsyncDriver, query: str, @@ -322,6 +169,247 @@ async def edge_fulltext_search( return edges +async def edge_similarity_search( + driver: AsyncDriver, + search_vector: list[float], + source_node_uuid: str | None, + target_node_uuid: str | None, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, +) -> list[EntityEdge]: + group_ids = group_ids if group_ids is not None else [None] + # vector similarity search over embedded facts 
+ query = Query(""" + CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) + YIELD relationship AS rel, score + MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids + RETURN + r.uuid AS uuid, + r.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + r.created_at AS created_at, + r.name AS name, + r.fact AS fact, + r.fact_embedding AS fact_embedding, + r.episodes AS episodes, + r.expired_at AS expired_at, + r.valid_at AS valid_at, + r.invalid_at AS invalid_at + ORDER BY score DESC + """) + + if source_node_uuid is None and target_node_uuid is None: + query = Query(""" + CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) + YIELD relationship AS rel, score + MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids + RETURN + r.uuid AS uuid, + r.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + r.created_at AS created_at, + r.name AS name, + r.fact AS fact, + r.fact_embedding AS fact_embedding, + r.episodes AS episodes, + r.expired_at AS expired_at, + r.valid_at AS valid_at, + r.invalid_at AS invalid_at + ORDER BY score DESC + """) + elif source_node_uuid is None: + query = Query(""" + CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) + YIELD relationship AS rel, score + MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids + RETURN + r.uuid AS uuid, + r.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + r.created_at AS created_at, + r.name AS name, + r.fact AS fact, + r.fact_embedding AS fact_embedding, + r.episodes AS episodes, + r.expired_at AS expired_at, + r.valid_at AS valid_at, + r.invalid_at AS invalid_at + ORDER BY score DESC + """) + elif target_node_uuid is None: + query = Query(""" + CALL 
db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) + YIELD relationship AS rel, score + MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids + RETURN + r.uuid AS uuid, + r.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + r.created_at AS created_at, + r.name AS name, + r.fact AS fact, + r.fact_embedding AS fact_embedding, + r.episodes AS episodes, + r.expired_at AS expired_at, + r.valid_at AS valid_at, + r.invalid_at AS invalid_at + ORDER BY score DESC + """) + + records, _, _ = await driver.execute_query( + query, + search_vector=search_vector, + source_uuid=source_node_uuid, + target_uuid=target_node_uuid, + group_ids=group_ids, + limit=limit, + ) + + edges = [get_entity_edge_from_record(record) for record in records] + + return edges + + +async def node_fulltext_search( + driver: AsyncDriver, + query: str, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, +) -> list[EntityNode]: + group_ids = group_ids if group_ids is not None else [None] + + # BM25 search to get top nodes + fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' + records, _, _ = await driver.execute_query( + """ + CALL db.index.fulltext.queryNodes("name_and_summary", $query) + YIELD node AS n, score + MATCH (n WHERE n.group_id in $group_ids) + RETURN + n.uuid AS uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary + ORDER BY score DESC + LIMIT $limit + """, + query=fuzzy_query, + group_ids=group_ids, + limit=limit, + ) + nodes = [get_entity_node_from_record(record) for record in records] + + return nodes + + +async def node_similarity_search( + driver: AsyncDriver, + search_vector: list[float], + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, +) -> list[EntityNode]: + group_ids = group_ids if group_ids is not None else [None] + + # vector similarity 
search over entity names + records, _, _ = await driver.execute_query( + """ + CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) + YIELD node AS n, score + MATCH (n WHERE n.group_id IN $group_ids) + RETURN + n.uuid As uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary + ORDER BY score DESC + """, + search_vector=search_vector, + group_ids=group_ids, + limit=limit, + ) + nodes = [get_entity_node_from_record(record) for record in records] + + return nodes + + +async def community_fulltext_search( + driver: AsyncDriver, + query: str, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, +) -> list[CommunityNode]: + group_ids = group_ids if group_ids is not None else [None] + + # BM25 search to get top communities + fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' + records, _, _ = await driver.execute_query( + """ + CALL db.index.fulltext.queryNodes("community_name", $query) + YIELD node AS comm, score + MATCH (comm WHERE comm.group_id in $group_ids) + RETURN + comm.uuid AS uuid, + comm.group_id AS group_id, + comm.name AS name, + comm.name_embedding AS name_embedding, + comm.created_at AS created_at, + comm.summary AS summary + ORDER BY score DESC + LIMIT $limit + """, + query=fuzzy_query, + group_ids=group_ids, + limit=limit, + ) + communities = [get_community_node_from_record(record) for record in records] + + return communities + + +async def community_similarity_search( + driver: AsyncDriver, + search_vector: list[float], + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, +) -> list[CommunityNode]: + group_ids = group_ids if group_ids is not None else [None] + + # vector similarity search over community names + records, _, _ = await driver.execute_query( + """ + CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector) + YIELD node AS comm, score + MATCH (comm WHERE comm.group_id 
IN $group_ids) + RETURN + comm.uuid As uuid, + comm.group_id AS group_id, + comm.name AS name, + comm.name_embedding AS name_embedding, + comm.created_at AS created_at, + comm.summary AS summary + ORDER BY score DESC + """, + search_vector=search_vector, + group_ids=group_ids, + limit=limit, + ) + communities = [get_community_node_from_record(record) for record in records] + + return communities + + async def hybrid_node_search( queries: list[str], embeddings: list[list[float]], @@ -371,8 +459,8 @@ async def hybrid_node_search( results: list[list[EntityNode]] = list( await asyncio.gather( - *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries], - *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings], + *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries], + *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings], ) ) @@ -490,24 +578,23 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( - driver: AsyncDriver, results: list[list[str]], center_node_uuid: str + driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str ) -> list[str]: # use rrf as a preliminary ranker - sorted_uuids = rrf(results) + sorted_uuids = rrf(node_uuids) scores: dict[str, float] = {} # Find the shortest path to center node query = Query(""" - MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) - MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid}) - RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid + MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid}) + RETURN length(p) AS score """) path_results = await asyncio.gather( *[ driver.execute_query( query, - edge_uuid=uuid, + node_uuid=uuid, center_uuid=center_node_uuid, ) for uuid in sorted_uuids @@ -518,15 +605,8 @@ async def 
node_distance_reranker( records = result[0] record = records[0] if len(records) > 0 else None distance: float = record['score'] if record is not None else float('inf') - if record is not None and ( - record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid - ): - distance = 0 - - if uuid in scores: - scores[uuid] = min(distance, scores[uuid]) - else: - scores[uuid] = distance + distance = 0 if uuid == center_node_uuid else distance + scores[uuid] = distance # rerank on shortest distance sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 682f9d50..9c9af450 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -26,6 +26,7 @@ from dotenv import load_dotenv from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.graphiti import Graphiti from graphiti_core.nodes import EntityNode, EpisodicNode +from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF pytestmark = pytest.mark.integration @@ -81,6 +82,17 @@ async def test_graphiti_init(): edges = await graphiti.search('issues with higher ed', group_ids=['1']) logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) + + results = await graphiti._search( + 'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1'] + ) + pretty_results = { + 'edges': [edge.fact for edge in results.edges], + 'nodes': [node.name for node in results.nodes], + 'communities': [community.name for community in results.communities], + } + + logger.info(pretty_results) graphiti.close() diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index 38837f0d..0a260919 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -11,11 +11,11 @@ async def test_hybrid_node_search_deduplication(): # Mock the database driver mock_driver = AsyncMock() - # 
Mock the entity_fulltext_search and entity_similarity_search functions + # Mock the node_fulltext_search and node_similarity_search functions with patch( - 'graphiti_core.search.search_utils.entity_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( - 'graphiti_core.search.search_utils.entity_similarity_search' + 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: # Set up mock return values mock_fulltext_search.side_effect = [ @@ -47,9 +47,9 @@ async def test_hybrid_node_search_empty_results(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.entity_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( - 'graphiti_core.search.search_utils.entity_similarity_search' + 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: mock_fulltext_search.return_value = [] mock_similarity_search.return_value = [] @@ -66,9 +66,9 @@ async def test_hybrid_node_search_only_fulltext(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.entity_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( - 'graphiti_core.search.search_utils.entity_similarity_search' + 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: mock_fulltext_search.return_value = [ EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1') ] mock_similarity_search.return_value = [] @@ -90,9 +90,9 @@ async def test_hybrid_node_search_with_limit(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.entity_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( - 'graphiti_core.search.search_utils.entity_similarity_search' + 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: mock_fulltext_search.return_value = [ 
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), @@ -120,8 +120,8 @@ async def test_hybrid_node_search_with_limit(): assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 # Verify that the limit was passed to the search functions - mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2) + mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2) + mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2) @pytest.mark.asyncio @@ -129,9 +129,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.entity_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( - 'graphiti_core.search.search_utils.entity_similarity_search' + 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: mock_fulltext_search.return_value = [ EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), @@ -155,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 - mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4) + mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4) + mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4)