diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 635200af..2b3e0dbc 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -750,7 +750,7 @@ class Graphiti:
nodes = await get_mentioned_nodes(self.driver, episodes)
- return SearchResults(edges=edges, nodes=nodes, communities=[])
+ return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
if source_node.name_embedding is None:
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 709b86df..dc353211 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -25,7 +25,7 @@ from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
from graphiti_core.helpers import semaphore_gather
-from graphiti_core.nodes import CommunityNode, EntityNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
CommunityReranker,
@@ -33,6 +33,8 @@ from graphiti_core.search.search_config import (
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
+ EpisodeReranker,
+ EpisodeSearchConfig,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
@@ -46,6 +48,7 @@ from graphiti_core.search.search_utils import (
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
+ episode_fulltext_search,
episode_mentions_reranker,
maximal_marginal_relevance,
node_bfs_search,
@@ -74,13 +77,14 @@ async def search(
return SearchResults(
edges=[],
nodes=[],
+ episodes=[],
communities=[],
)
query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None
- edges, nodes, communities = await semaphore_gather(
+ edges, nodes, episodes, communities = await semaphore_gather(
edge_search(
driver,
cross_encoder,
@@ -107,6 +111,17 @@ async def search(
config.limit,
config.reranker_min_score,
),
+ episode_search(
+ driver,
+ cross_encoder,
+ query,
+ query_vector,
+ group_ids,
+ config.episode_config,
+ search_filter,
+ config.limit,
+ config.reranker_min_score,
+ ),
community_search(
driver,
cross_encoder,
@@ -122,6 +137,7 @@ async def search(
results = SearchResults(
edges=edges,
nodes=nodes,
+ episodes=episodes,
communities=communities,
)
@@ -328,6 +344,54 @@ async def node_search(
return reranked_nodes[:limit]
+async def episode_search(
+    driver: AsyncDriver,
+    cross_encoder: CrossEncoderClient,
+    query: str,
+    _query_vector: list[float],
+    group_ids: list[str] | None,
+    config: EpisodeSearchConfig | None,
+    search_filter: SearchFilters,
+    limit: int = DEFAULT_SEARCH_LIMIT,
+    reranker_min_score: float = 0,
+) -> list[EpisodicNode]:
+    """Retrieve episodes via full-text search and rerank them according to ``config``.
+
+    Args:
+        driver: Database driver forwarded to ``episode_fulltext_search``.
+        cross_encoder: Client whose ``rank`` method scores episode contents against
+            ``query``; only used when ``config.reranker`` is ``EpisodeReranker.cross_encoder``.
+        query: Search text.
+        _query_vector: Accepted for signature parity with the other ``*_search``
+            functions; not read here.
+        group_ids: Forwarded to ``episode_fulltext_search``.
+        config: Episode search settings. ``None`` disables episode search.
+        search_filter: Forwarded to ``episode_fulltext_search``.
+        limit: Maximum number of episodes returned. ``2 * limit`` candidates are
+            requested from the full-text search before reranking.
+        reranker_min_score: Minimum score passed to ``rrf`` and applied to
+            cross-encoder scores.
+
+    Returns:
+        At most ``limit`` episodes in reranked order; an empty list when ``config``
+        is ``None``.
+    """
+    if config is None:
+        return []
+
+    # Only one retrieval method exists today, so ``config.search_methods`` is not
+    # consulted; the list/gather shape is kept so further methods can be added.
+    search_results: list[list[EpisodicNode]] = list(
+        await semaphore_gather(
+            *[
+                episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
+            ]
+        )
+    )
+
+    search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
+    episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
+
+    reranked_uuids: list[str] = []
+    if config.reranker == EpisodeReranker.rrf:
+        reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+
+    elif config.reranker == EpisodeReranker.cross_encoder:
+        # use rrf as a preliminary reranker
+        # NOTE(review): the same ``reranker_min_score`` is applied to the rrf pre-pass
+        # and to the cross-encoder scores below -- confirm both are on comparable scales.
+        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+        rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+
+        # Several episodes may carry identical content. Group their uuids per content
+        # string so that none of them is lost when mapping ranked passages back.
+        content_to_uuids: dict[str, list[str]] = {}
+        for episode in rrf_results:
+            content_to_uuids.setdefault(episode.content, []).append(episode.uuid)
+
+        # Nothing to rank: skip the cross-encoder call entirely.
+        if content_to_uuids:
+            reranked_contents = await cross_encoder.rank(query, list(content_to_uuids))
+            reranked_uuids = [
+                uuid
+                for content, score in reranked_contents
+                if score >= reranker_min_score
+                for uuid in content_to_uuids[content]
+            ]
+
+    reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
+
+    return reranked_episodes[:limit]
+
+
async def community_search(
driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py
index f0c21bde..63b1a114 100644
--- a/graphiti_core/search/search_config.py
+++ b/graphiti_core/search/search_config.py
@@ -19,7 +19,7 @@ from enum import Enum
from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge
-from graphiti_core.nodes import CommunityNode, EntityNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_utils import (
DEFAULT_MIN_SCORE,
DEFAULT_MMR_LAMBDA,
@@ -41,6 +41,10 @@ class NodeSearchMethod(Enum):
bfs = 'breadth_first_search'
+class EpisodeSearchMethod(Enum):
+    """Retrieval methods that can be listed in ``EpisodeSearchConfig.search_methods``."""
+
+    # Currently the only episode retrieval method defined.
+    bm25 = 'bm25'
+
+
class CommunitySearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'
@@ -62,6 +66,11 @@ class NodeReranker(Enum):
cross_encoder = 'cross_encoder'
+class EpisodeReranker(Enum):
+    """Reranking strategies selectable through ``EpisodeSearchConfig.reranker``."""
+
+    rrf = 'reciprocal_rank_fusion'
+    cross_encoder = 'cross_encoder'
+
+
class CommunityReranker(Enum):
rrf = 'reciprocal_rank_fusion'
mmr = 'mmr'
@@ -84,6 +93,14 @@ class NodeSearchConfig(BaseModel):
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
+class EpisodeSearchConfig(BaseModel):
+    """Settings for the episode portion of a search (see ``SearchConfig.episode_config``)."""
+
+    # Retrieval methods to run; required, no default.
+    search_methods: list[EpisodeSearchMethod]
+    # Reranking strategy applied to retrieved episodes; defaults to rrf.
+    reranker: EpisodeReranker = Field(default=EpisodeReranker.rrf)
+    # NOTE(review): the three fields below mirror the node/community configs but are
+    # not read by episode_search as introduced in this change -- confirm they are needed.
+    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+    bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
+
+
class CommunitySearchConfig(BaseModel):
search_methods: list[CommunitySearchMethod]
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
@@ -95,6 +112,7 @@ class CommunitySearchConfig(BaseModel):
class SearchConfig(BaseModel):
edge_config: EdgeSearchConfig | None = Field(default=None)
node_config: NodeSearchConfig | None = Field(default=None)
+ episode_config: EpisodeSearchConfig | None = Field(default=None)
community_config: CommunitySearchConfig | None = Field(default=None)
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
reranker_min_score: float = Field(default=0)
@@ -103,4 +121,5 @@ class SearchConfig(BaseModel):
class SearchResults(BaseModel):
edges: list[EntityEdge]
nodes: list[EntityNode]
+ episodes: list[EpisodicNode]
communities: list[CommunityNode]
diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py
index 06b6f8cb..4b8bc5a3 100644
--- a/graphiti_core/search/search_config_recipes.py
+++ b/graphiti_core/search/search_config_recipes.py
@@ -21,6 +21,9 @@ from graphiti_core.search.search_config import (
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
+ EpisodeReranker,
+ EpisodeSearchConfig,
+ EpisodeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
@@ -37,6 +40,12 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.rrf,
),
+ episode_config=EpisodeSearchConfig(
+ search_methods=[
+ EpisodeSearchMethod.bm25,
+ ],
+ reranker=EpisodeReranker.rrf,
+ ),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.rrf,
@@ -55,6 +64,12 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
reranker=NodeReranker.mmr,
mmr_lambda=1,
),
+ episode_config=EpisodeSearchConfig(
+ search_methods=[
+ EpisodeSearchMethod.bm25,
+ ],
+ reranker=EpisodeReranker.rrf,
+ ),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
@@ -80,6 +95,12 @@ COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
],
reranker=NodeReranker.cross_encoder,
),
+ episode_config=EpisodeSearchConfig(
+ search_methods=[
+ EpisodeSearchMethod.bm25,
+ ],
+ reranker=EpisodeReranker.cross_encoder,
+ ),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.cross_encoder,
diff --git a/graphiti_core/search/search_helpers.py b/graphiti_core/search/search_helpers.py
index a034847c..37d31b0b 100644
--- a/graphiti_core/search/search_helpers.py
+++ b/graphiti_core/search/search_helpers.py
@@ -38,6 +38,13 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
entity_json = [
{'entity_name': node.name, 'summary': node.summary} for node in search_results.nodes
]
+ episode_json = [
+ {
+ 'source_description': episode.source_description,
+ 'content': episode.content,
+ }
+ for episode in search_results.episodes
+ ]
community_json = [
{'community_name': community.name, 'summary': community.summary}
for community in search_results.communities
@@ -55,6 +62,9 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
{json.dumps(entity_json, indent=12)}
+
+ {json.dumps(episode_json, indent=12)}
+
{json.dumps(community_json, indent=12)}
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 820fb00f..a7f3b7ad 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -37,6 +37,7 @@ from graphiti_core.nodes import (
EpisodicNode,
get_community_node_from_record,
get_entity_node_from_record,
+ get_episodic_node_from_record,
)
from graphiti_core.search.search_filters import (
SearchFilters,
@@ -229,8 +230,8 @@ async def edge_similarity_search(
query: LiteralString = (
"""
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
- """
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+ """
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -475,6 +476,48 @@ async def node_bfs_search(
return nodes
+async def episode_fulltext_search(
+    driver: AsyncDriver,
+    query: str,
+    _search_filter: SearchFilters,
+    group_ids: list[str] | None = None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
+) -> list[EpisodicNode]:
+    """Query the ``episode_content`` full-text index and return the matching episodes.
+
+    Args:
+        driver: Driver used to execute the read query.
+        query: Raw search text; converted to an index query by ``fulltext_query``.
+        _search_filter: Accepted for signature parity with the other search
+            functions; not read here.
+        group_ids: Passed to ``fulltext_query`` when building the index query.
+        limit: Maximum number of episodes to return.
+
+    Returns:
+        Episodes ordered by descending index score, or an empty list when
+        ``fulltext_query`` yields an empty query string.
+    """
+    # BM25 search to get top episodes
+    fuzzy_query = fulltext_query(query, group_ids)
+    if fuzzy_query == '':
+        return []
+
+    # The procedure already yields the Episodic node, so its properties are read
+    # directly instead of re-matching the node by uuid. ``group_ids`` is not sent as
+    # a query parameter because the Cypher text never references it.
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
+        YIELD node AS e, score
+        RETURN
+            e.content AS content,
+            e.created_at AS created_at,
+            e.valid_at AS valid_at,
+            e.uuid AS uuid,
+            e.name AS name,
+            e.group_id AS group_id,
+            e.source_description AS source_description,
+            e.source AS source,
+            e.entity_edges AS entity_edges
+        ORDER BY score DESC
+        LIMIT $limit
+        """,
+        query=fuzzy_query,
+        limit=limit,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+    episodes = [get_episodic_node_from_record(record) for record in records]
+
+    return episodes
+
+
async def community_fulltext_search(
driver: AsyncDriver,
query: str,
diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py
index 4f570b0d..226d8f41 100644
--- a/graphiti_core/utils/maintenance/graph_data_operations.py
+++ b/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -71,6 +71,8 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
]
fulltext_indices: list[LiteralString] = [
+ """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
+ FOR (e:Episodic) ON EACH [e.content, e.source, e.group_id]""",
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
diff --git a/pyproject.toml b/pyproject.toml
index 1bd9e830..9d966107 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
-version = "0.9.6"
+version = "0.10.0"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 6739bf06..a2c7c378 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -66,7 +66,7 @@ async def test_graphiti_init():
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
results = await graphiti.search_(
- query='Who is the user?',
+ query='Who is the User?',
)
pretty_results = search_results_to_context_string(results)