From 45b15a06f20c57d20bd29c0f4edc74df22c9d972 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 15 Apr 2025 19:27:56 -0400 Subject: [PATCH] add episode scope to search (#362) * add episode scope to search * bump version * linter * Update graphiti_core/search/search_helpers.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- graphiti_core/graphiti.py | 2 +- graphiti_core/search/search.py | 68 ++++++++++++++++++- graphiti_core/search/search_config.py | 21 +++++- graphiti_core/search/search_config_recipes.py | 21 ++++++ graphiti_core/search/search_helpers.py | 10 +++ graphiti_core/search/search_utils.py | 47 ++++++++++++- .../maintenance/graph_data_operations.py | 2 + pyproject.toml | 2 +- tests/test_graphiti_int.py | 2 +- 9 files changed, 167 insertions(+), 8 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 635200af..2b3e0dbc 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -750,7 +750,7 @@ class Graphiti: nodes = await get_mentioned_nodes(self.driver, episodes) - return SearchResults(edges=edges, nodes=nodes, communities=[]) + return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[]) async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode): if source_node.name_embedding is None: diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 709b86df..dc353211 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -25,7 +25,7 @@ from graphiti_core.edges import EntityEdge from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import SearchRerankerError from graphiti_core.helpers import semaphore_gather -from graphiti_core.nodes import CommunityNode, EntityNode +from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.search.search_config import ( DEFAULT_SEARCH_LIMIT, CommunityReranker, @@ -33,6 +33,8 @@ from graphiti_core.search.search_config import ( EdgeReranker, EdgeSearchConfig, EdgeSearchMethod, + EpisodeReranker, + EpisodeSearchConfig, NodeReranker, NodeSearchConfig, NodeSearchMethod, @@ -46,6 +48,7 @@ from graphiti_core.search.search_utils import ( edge_bfs_search, edge_fulltext_search, edge_similarity_search, + episode_fulltext_search, episode_mentions_reranker, maximal_marginal_relevance, node_bfs_search, @@ -74,13 +77,14 @@ async def search( return SearchResults( edges=[], nodes=[], + episodes=[], communities=[], ) query_vector = await embedder.create(input_data=[query.replace('\n', ' ')]) # if group_ids is empty, set it to None group_ids = group_ids if group_ids else None - edges, nodes, communities = await semaphore_gather( + edges, nodes, episodes, communities = await semaphore_gather( edge_search( driver, cross_encoder, @@ -107,6 +111,17 @@ async def search( config.limit, config.reranker_min_score, ), + episode_search( + driver, + cross_encoder, + query, + query_vector, + group_ids, + config.episode_config, + search_filter, + config.limit, + config.reranker_min_score, + ), community_search( driver, cross_encoder, @@ -122,6 +137,7 @@ async def search( results = SearchResults( edges=edges, nodes=nodes, + episodes=episodes, communities=communities, ) @@ -328,6 +344,54 @@ async def node_search( return reranked_nodes[:limit] +async def episode_search( + driver: AsyncDriver, + cross_encoder: CrossEncoderClient, + query: str, + _query_vector: list[float], + group_ids: list[str] | None, + config: EpisodeSearchConfig | None, + search_filter: SearchFilters, + limit=DEFAULT_SEARCH_LIMIT, + reranker_min_score: float = 0, +) -> list[EpisodicNode]: + if config is None: + return [] + + search_results: list[list[EpisodicNode]] = list( + await semaphore_gather( + *[ + episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), + ] + ) + ) + + search_result_uuids = [[episode.uuid for episode in result] for result in search_results] + episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result} + + reranked_uuids: list[str] = [] + if config.reranker == EpisodeReranker.rrf: + reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) + + elif config.reranker == EpisodeReranker.cross_encoder: + # use rrf as a preliminary reranker + rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score) + rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit] + + content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results} + + reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys())) + reranked_uuids = [ + content_to_uuid_map[content] + for content, score in reranked_contents + if score >= reranker_min_score + ] + + reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids] + + return reranked_episodes[:limit] + + async def community_search( driver: AsyncDriver, cross_encoder: CrossEncoderClient, diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index f0c21bde..63b1a114 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -19,7 +19,7 @@ from enum import Enum from pydantic import BaseModel, Field from graphiti_core.edges import EntityEdge -from graphiti_core.nodes import CommunityNode, EntityNode +from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.search.search_utils import ( DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA, @@ -41,6 +41,10 @@ class NodeSearchMethod(Enum): bfs = 'breadth_first_search' +class EpisodeSearchMethod(Enum): + bm25 = 'bm25' + + class CommunitySearchMethod(Enum): cosine_similarity = 'cosine_similarity' bm25 = 'bm25' @@ -62,6 +66,11 @@ class NodeReranker(Enum): cross_encoder = 'cross_encoder' +class EpisodeReranker(Enum): + rrf = 'reciprocal_rank_fusion' + cross_encoder = 'cross_encoder' + + class CommunityReranker(Enum): rrf = 'reciprocal_rank_fusion' mmr = 'mmr' @@ -84,6 +93,14 @@ class NodeSearchConfig(BaseModel): bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) +class EpisodeSearchConfig(BaseModel): + search_methods: list[EpisodeSearchMethod] + reranker: EpisodeReranker = Field(default=EpisodeReranker.rrf) + sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) + mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) + bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) + + class CommunitySearchConfig(BaseModel): search_methods: list[CommunitySearchMethod] reranker: CommunityReranker = Field(default=CommunityReranker.rrf) @@ -95,6 +112,7 @@ class CommunitySearchConfig(BaseModel): class SearchConfig(BaseModel): edge_config: EdgeSearchConfig | None = Field(default=None) node_config: NodeSearchConfig | None = Field(default=None) + episode_config: EpisodeSearchConfig | None = Field(default=None) community_config: CommunitySearchConfig | None = Field(default=None) limit: int = Field(default=DEFAULT_SEARCH_LIMIT) reranker_min_score: float = Field(default=0) @@ -103,4 +121,5 @@ class SearchConfig(BaseModel): class SearchResults(BaseModel): edges: list[EntityEdge] nodes: list[EntityNode] + episodes: list[EpisodicNode] communities: list[CommunityNode] diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py index 06b6f8cb..4b8bc5a3 100644 --- a/graphiti_core/search/search_config_recipes.py +++ b/graphiti_core/search/search_config_recipes.py @@ -21,6 +21,9 @@ from graphiti_core.search.search_config import ( EdgeReranker, EdgeSearchConfig, EdgeSearchMethod, + EpisodeReranker, + EpisodeSearchConfig, + EpisodeSearchMethod, NodeReranker, NodeSearchConfig, NodeSearchMethod, @@ -37,6 +40,12 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig( search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], reranker=NodeReranker.rrf, ), + episode_config=EpisodeSearchConfig( + search_methods=[ + EpisodeSearchMethod.bm25, + ], + reranker=EpisodeReranker.rrf, + ), community_config=CommunitySearchConfig( search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], reranker=CommunityReranker.rrf, @@ -55,6 +64,12 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig( reranker=NodeReranker.mmr, mmr_lambda=1, ), + episode_config=EpisodeSearchConfig( + search_methods=[ + EpisodeSearchMethod.bm25, + ], + reranker=EpisodeReranker.rrf, + ), community_config=CommunitySearchConfig( search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], reranker=CommunityReranker.mmr, @@ -80,6 +95,12 @@ COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig( ], reranker=NodeReranker.cross_encoder, ), + episode_config=EpisodeSearchConfig( + search_methods=[ + EpisodeSearchMethod.bm25, + ], + reranker=EpisodeReranker.cross_encoder, + ), community_config=CommunitySearchConfig( search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], reranker=CommunityReranker.cross_encoder, diff --git a/graphiti_core/search/search_helpers.py b/graphiti_core/search/search_helpers.py index a034847c..37d31b0b 100644 --- a/graphiti_core/search/search_helpers.py +++ b/graphiti_core/search/search_helpers.py @@ -38,6 +38,13 @@ def search_results_to_context_string(search_results: SearchResults) -> str: entity_json = [ {'entity_name': node.name, 'summary': node.summary} for node in search_results.nodes ] + episode_json = [ + { + 'source_description': episode.source_description, + 'content': episode.content, + } + for episode in search_results.episodes + ] community_json = [ {'community_name': community.name, 'summary': community.summary} for community in search_results.communities @@ -55,6 +62,9 @@ def search_results_to_context_string(search_results: SearchResults) -> str: {json.dumps(entity_json, indent=12)} + + {json.dumps(episode_json, indent=12)} + {json.dumps(community_json, indent=12)} diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 820fb00f..a7f3b7ad 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -37,6 +37,7 @@ from graphiti_core.nodes import ( EpisodicNode, get_community_node_from_record, get_entity_node_from_record, + get_episodic_node_from_record, ) from graphiti_core.search.search_filters import ( SearchFilters, @@ -229,8 +230,8 @@ async def edge_similarity_search( query: LiteralString = ( """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -475,6 +476,48 @@ async def node_bfs_search( return nodes +async def episode_fulltext_search( + driver: AsyncDriver, + query: str, + _search_filter: SearchFilters, + group_ids: list[str] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, +) -> list[EpisodicNode]: + # BM25 search to get top episodes + fuzzy_query = fulltext_query(query, group_ids) + if fuzzy_query == '': + return [] + + records, _, _ = await driver.execute_query( + """ + CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit}) + YIELD node AS episode, score + MATCH (e:Episodic) + WHERE e.uuid = episode.uuid + RETURN + e.content AS content, + e.created_at AS created_at, + e.valid_at AS valid_at, + e.uuid AS uuid, + e.name AS name, + e.group_id AS group_id, + e.source_description AS source_description, + e.source AS source, + e.entity_edges AS entity_edges + ORDER BY score DESC + LIMIT $limit + """, + query=fuzzy_query, + group_ids=group_ids, + limit=limit, + database_=DEFAULT_DATABASE, + routing_='r', + ) + episodes = [get_episodic_node_from_record(record) for record in records] + + return episodes + + async def community_fulltext_search( driver: AsyncDriver, query: str, diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 4f570b0d..226d8f41 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -71,6 +71,8 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo ] fulltext_indices: list[LiteralString] = [ + """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS + FOR (e:Episodic) ON EACH [e.content, e.source, e.group_id]""", """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""", """CREATE FULLTEXT INDEX community_name IF NOT EXISTS diff --git a/pyproject.toml b/pyproject.toml index 1bd9e830..9d966107 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.9.6" +version = "0.10.0" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 6739bf06..a2c7c378 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -66,7 +66,7 @@ async def test_graphiti_init(): graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) results = await graphiti.search_( - query='Who is the user?', + query='Who is the User?', ) pretty_results = search_results_to_context_string(results)