add episode scope to search (#362)

* add episode scope to search

* bump version

* linter

* Update graphiti_core/search/search_helpers.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* mypy

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Preston Rasmussen 2025-04-15 19:27:56 -04:00 committed by GitHub
parent 31a4bfeeb2
commit 45b15a06f2
9 changed files with 167 additions and 8 deletions


@@ -750,7 +750,7 @@ class Graphiti:
        nodes = await get_mentioned_nodes(self.driver, episodes)

        return SearchResults(edges=edges, nodes=nodes, communities=[])
        return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])

    async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
        if source_node.name_embedding is None:
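
With this change, Graphiti.search_ populates the new episodes field on SearchResults (the episode-mention path above still returns it empty). A minimal consumption sketch; the connection details are placeholders and the active SearchConfig is assumed to include an episode_config:

import asyncio

from graphiti_core import Graphiti


async def main() -> None:
    # Placeholder connection details for a local Neo4j instance.
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    try:
        results = await graphiti.search_(query='What did the user ask about?')
        # episodes is a list[EpisodicNode]; it stays empty when the active
        # SearchConfig has no episode_config.
        for episode in results.episodes:
            print(episode.name, episode.content[:80])
    finally:
        await graphiti.close()


asyncio.run(main())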


@@ -25,7 +25,7 @@ from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
    DEFAULT_SEARCH_LIMIT,
    CommunityReranker,
@@ -33,6 +33,8 @@ from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    EpisodeReranker,
    EpisodeSearchConfig,
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
@@ -46,6 +48,7 @@ from graphiti_core.search.search_utils import (
    edge_bfs_search,
    edge_fulltext_search,
    edge_similarity_search,
    episode_fulltext_search,
    episode_mentions_reranker,
    maximal_marginal_relevance,
    node_bfs_search,
@@ -74,13 +77,14 @@ async def search(
        return SearchResults(
            edges=[],
            nodes=[],
            episodes=[],
            communities=[],
        )

    query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
    # if group_ids is empty, set it to None
    group_ids = group_ids if group_ids else None
    edges, nodes, communities = await semaphore_gather(
    edges, nodes, episodes, communities = await semaphore_gather(
        edge_search(
            driver,
            cross_encoder,
@@ -107,6 +111,17 @@ async def search(
            config.limit,
            config.reranker_min_score,
        ),
        episode_search(
            driver,
            cross_encoder,
            query,
            query_vector,
            group_ids,
            config.episode_config,
            search_filter,
            config.limit,
            config.reranker_min_score,
        ),
        community_search(
            driver,
            cross_encoder,
@@ -122,6 +137,7 @@ async def search(
    results = SearchResults(
        edges=edges,
        nodes=nodes,
        episodes=episodes,
        communities=communities,
    )
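
The gather above fans the four scoped searches out concurrently and unpacks them positionally, so the new episode_search call sits between node_search and community_search to line up with the assignment. A stripped-down sketch of the same pattern with stand-in coroutines; semaphore_gather is assumed to behave like asyncio.gather with a concurrency cap:

import asyncio


async def stand_in_search(scope: str) -> list[str]:
    # Placeholder for the real edge/node/episode/community searches.
    await asyncio.sleep(0)
    return [f'{scope}-hit']


async def run_scoped_searches() -> None:
    # Results come back in call order, so the unpack order must match.
    edges, nodes, episodes, communities = await asyncio.gather(
        stand_in_search('edge'),
        stand_in_search('node'),
        stand_in_search('episode'),
        stand_in_search('community'),
    )
    print(episodes)  # ['episode-hit']


asyncio.run(run_scoped_searches())
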
@@ -328,6 +344,54 @@ async def node_search(
    return reranked_nodes[:limit]


async def episode_search(
    driver: AsyncDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    _query_vector: list[float],
    group_ids: list[str] | None,
    config: EpisodeSearchConfig | None,
    search_filter: SearchFilters,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> list[EpisodicNode]:
    if config is None:
        return []

    search_results: list[list[EpisodicNode]] = list(
        await semaphore_gather(
            *[
                episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
            ]
        )
    )

    search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
    episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}

    reranked_uuids: list[str] = []
    if config.reranker == EpisodeReranker.rrf:
        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == EpisodeReranker.cross_encoder:
        # use rrf as a preliminary reranker
        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
        rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]

        content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}

        reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
        reranked_uuids = [
            content_to_uuid_map[content]
            for content, score in reranked_contents
            if score >= reranker_min_score
        ]

    reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_episodes[:limit]


async def community_search(
    driver: AsyncDriver,
    cross_encoder: CrossEncoderClient,


@@ -19,7 +19,7 @@ from enum import Enum
from pydantic import BaseModel, Field

from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_utils import (
    DEFAULT_MIN_SCORE,
    DEFAULT_MMR_LAMBDA,
@@ -41,6 +41,10 @@ class NodeSearchMethod(Enum):
    bfs = 'breadth_first_search'


class EpisodeSearchMethod(Enum):
    bm25 = 'bm25'


class CommunitySearchMethod(Enum):
    cosine_similarity = 'cosine_similarity'
    bm25 = 'bm25'
@@ -62,6 +66,11 @@ class NodeReranker(Enum):
    cross_encoder = 'cross_encoder'


class EpisodeReranker(Enum):
    rrf = 'reciprocal_rank_fusion'
    cross_encoder = 'cross_encoder'


class CommunityReranker(Enum):
    rrf = 'reciprocal_rank_fusion'
    mmr = 'mmr'
@@ -84,6 +93,14 @@ class NodeSearchConfig(BaseModel):
    bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)


class EpisodeSearchConfig(BaseModel):
    search_methods: list[EpisodeSearchMethod]
    reranker: EpisodeReranker = Field(default=EpisodeReranker.rrf)
    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
    bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)


class CommunitySearchConfig(BaseModel):
    search_methods: list[CommunitySearchMethod]
    reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
@@ -95,6 +112,7 @@ class CommunitySearchConfig(BaseModel):
class SearchConfig(BaseModel):
    edge_config: EdgeSearchConfig | None = Field(default=None)
    node_config: NodeSearchConfig | None = Field(default=None)
    episode_config: EpisodeSearchConfig | None = Field(default=None)
    community_config: CommunitySearchConfig | None = Field(default=None)
    limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
    reranker_min_score: float = Field(default=0)
@@ -103,4 +121,5 @@
class SearchResults(BaseModel):
    edges: list[EntityEdge]
    nodes: list[EntityNode]
    episodes: list[EpisodicNode]
    communities: list[CommunityNode]
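
Because the other scope configs default to None and their searches return early when unset, a SearchConfig can now be scoped to episodes alone. A hypothetical episode-only config built from the classes above (not one of the shipped recipes):

from graphiti_core.search.search_config import (
    EpisodeReranker,
    EpisodeSearchConfig,
    EpisodeSearchMethod,
    SearchConfig,
)

# BM25 over episode content, fused with reciprocal rank fusion,
# capped at 20 results for the scope.
EPISODE_ONLY_SEARCH = SearchConfig(
    episode_config=EpisodeSearchConfig(
        search_methods=[EpisodeSearchMethod.bm25],
        reranker=EpisodeReranker.rrf,
    ),
    limit=20,
)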


@@ -21,6 +21,9 @@ from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    EpisodeReranker,
    EpisodeSearchConfig,
    EpisodeSearchMethod,
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
@@ -37,6 +40,12 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.rrf,
    ),
    episode_config=EpisodeSearchConfig(
        search_methods=[
            EpisodeSearchMethod.bm25,
        ],
        reranker=EpisodeReranker.rrf,
    ),
    community_config=CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=CommunityReranker.rrf,
@@ -55,6 +64,12 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
        reranker=NodeReranker.mmr,
        mmr_lambda=1,
    ),
    episode_config=EpisodeSearchConfig(
        search_methods=[
            EpisodeSearchMethod.bm25,
        ],
        reranker=EpisodeReranker.rrf,
    ),
    community_config=CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=CommunityReranker.mmr,
@@ -80,6 +95,12 @@ COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
        ],
        reranker=NodeReranker.cross_encoder,
    ),
    episode_config=EpisodeSearchConfig(
        search_methods=[
            EpisodeSearchMethod.bm25,
        ],
        reranker=EpisodeReranker.cross_encoder,
    ),
    community_config=CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=CommunityReranker.cross_encoder,
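
Since every combined hybrid recipe now carries an episode_config, callers that already pass a recipe pick up episode results with no further changes. A usage sketch; the recipes module path and the config parameter name on Graphiti.search_ are assumptions, as neither is spelled out in this diff:

from graphiti_core import Graphiti
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF


async def episode_scoped_results(graphiti: Graphiti, query: str) -> SearchResults:
    # `config=` is assumed here; check the Graphiti.search_ signature for the actual keyword.
    return await graphiti.search_(query=query, config=COMBINED_HYBRID_SEARCH_RRF)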


@@ -38,6 +38,13 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
    entity_json = [
        {'entity_name': node.name, 'summary': node.summary} for node in search_results.nodes
    ]
    episode_json = [
        {
            'source_description': episode.source_description,
            'content': episode.content,
        }
        for episode in search_results.episodes
    ]
    community_json = [
        {'community_name': community.name, 'summary': community.summary}
        for community in search_results.communities
@@ -55,6 +62,9 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
<ENTITIES>
{json.dumps(entity_json, indent=12)}
</ENTITIES>
<EPISODES>
{json.dumps(episode_json, indent=12)}
</EPISODES>
<COMMUNITIES>
{json.dumps(community_json, indent=12)}
</COMMUNITIES>

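The context string now carries an <EPISODES> block whose entries expose only each episode's source_description and content. A toy rendering of that fragment with made-up values, mirroring the episode_json construction above:

import json

# Sample values only; real entries come from SearchResults.episodes.
episode_json = [
    {
        'source_description': 'support chat transcript',
        'content': 'user: how do I reset my password?',
    }
]

print(f"""
<EPISODES>
{json.dumps(episode_json, indent=12)}
</EPISODES>
""")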

@@ -37,6 +37,7 @@ from graphiti_core.nodes import (
    EpisodicNode,
    get_community_node_from_record,
    get_entity_node_from_record,
    get_episodic_node_from_record,
)
from graphiti_core.search.search_filters import (
    SearchFilters,
@@ -229,8 +230,8 @@ async def edge_similarity_search(
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -475,6 +476,48 @@ async def node_bfs_search(
    return nodes


async def episode_fulltext_search(
    driver: AsyncDriver,
    query: str,
    _search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    # BM25 search to get top episodes
    fuzzy_query = fulltext_query(query, group_ids)
    if fuzzy_query == '':
        return []

    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
        YIELD node AS episode, score
        MATCH (e:Episodic)
        WHERE e.uuid = episode.uuid
        RETURN
            e.content AS content,
            e.created_at AS created_at,
            e.valid_at AS valid_at,
            e.uuid AS uuid,
            e.name AS name,
            e.group_id AS group_id,
            e.source_description AS source_description,
            e.source AS source,
            e.entity_edges AS entity_edges
        ORDER BY score DESC
        LIMIT $limit
        """,
        query=fuzzy_query,
        group_ids=group_ids,
        limit=limit,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )

    episodes = [get_episodic_node_from_record(record) for record in records]

    return episodes


async def community_fulltext_search(
    driver: AsyncDriver,
    query: str,
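
episode_fulltext_search leaves query escaping to the existing fulltext_query helper and then queries the episode_content index created later in this change. For debugging, the same index can be hit by hand with the async Neo4j driver; the connection details and the raw query string below are placeholders:

import asyncio

from neo4j import AsyncGraphDatabase


async def main() -> None:
    async with AsyncGraphDatabase.driver(
        'bolt://localhost:7687', auth=('neo4j', 'password')
    ) as driver:
        # Same fulltext call the library makes, trimmed to a few return fields.
        records, _, _ = await driver.execute_query(
            """
            CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
            YIELD node AS episode, score
            RETURN episode.uuid AS uuid, episode.name AS name, score
            ORDER BY score DESC
            """,
            query='password reset',
            limit=10,
            routing_='r',
        )
        for record in records:
            print(record['score'], record['uuid'], record['name'])


asyncio.run(main())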

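The hunk below adds an episode_content fulltext index, which has to exist before episode_fulltext_search can return anything, so existing deployments should rebuild indices once after upgrading. A setup sketch; it assumes Graphiti exposes build_indices_and_constraints as an instance method and uses placeholder connection details:

import asyncio

from graphiti_core import Graphiti


async def main() -> None:
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    try:
        await graphiti.build_indices_and_constraints()
    finally:
        await graphiti.close()


asyncio.run(main())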

@@ -71,6 +71,8 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
    ]

    fulltext_indices: list[LiteralString] = [
        """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
        FOR (e:Episodic) ON EACH [e.content, e.source, e.group_id]""",
        """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
        FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
        """CREATE FULLTEXT INDEX community_name IF NOT EXISTS


@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.9.6"
version = "0.10.0"
authors = [
    { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
    { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },


@@ -66,7 +66,7 @@ async def test_graphiti_init():
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)

    results = await graphiti.search_(
        query='Who is the user?',
        query='Who is the User?',
    )

    pretty_results = search_results_to_context_string(results)