add episode scope to search (#362)
* add episode scope to search * bump version * linter * Update graphiti_core/search/search_helpers.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
31a4bfeeb2
commit
45b15a06f2
9 changed files with 167 additions and 8 deletions
|
|
@ -750,7 +750,7 @@ class Graphiti:
|
|||
|
||||
nodes = await get_mentioned_nodes(self.driver, episodes)
|
||||
|
||||
return SearchResults(edges=edges, nodes=nodes, communities=[])
|
||||
return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
|
||||
|
||||
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
|
||||
if source_node.name_embedding is None:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from graphiti_core.edges import EntityEdge
|
|||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import SearchRerankerError
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_config import (
|
||||
DEFAULT_SEARCH_LIMIT,
|
||||
CommunityReranker,
|
||||
|
|
@ -33,6 +33,8 @@ from graphiti_core.search.search_config import (
|
|||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
EpisodeReranker,
|
||||
EpisodeSearchConfig,
|
||||
NodeReranker,
|
||||
NodeSearchConfig,
|
||||
NodeSearchMethod,
|
||||
|
|
@ -46,6 +48,7 @@ from graphiti_core.search.search_utils import (
|
|||
edge_bfs_search,
|
||||
edge_fulltext_search,
|
||||
edge_similarity_search,
|
||||
episode_fulltext_search,
|
||||
episode_mentions_reranker,
|
||||
maximal_marginal_relevance,
|
||||
node_bfs_search,
|
||||
|
|
@ -74,13 +77,14 @@ async def search(
|
|||
return SearchResults(
|
||||
edges=[],
|
||||
nodes=[],
|
||||
episodes=[],
|
||||
communities=[],
|
||||
)
|
||||
query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
|
||||
|
||||
# if group_ids is empty, set it to None
|
||||
group_ids = group_ids if group_ids else None
|
||||
edges, nodes, communities = await semaphore_gather(
|
||||
edges, nodes, episodes, communities = await semaphore_gather(
|
||||
edge_search(
|
||||
driver,
|
||||
cross_encoder,
|
||||
|
|
@ -107,6 +111,17 @@ async def search(
|
|||
config.limit,
|
||||
config.reranker_min_score,
|
||||
),
|
||||
episode_search(
|
||||
driver,
|
||||
cross_encoder,
|
||||
query,
|
||||
query_vector,
|
||||
group_ids,
|
||||
config.episode_config,
|
||||
search_filter,
|
||||
config.limit,
|
||||
config.reranker_min_score,
|
||||
),
|
||||
community_search(
|
||||
driver,
|
||||
cross_encoder,
|
||||
|
|
@ -122,6 +137,7 @@ async def search(
|
|||
results = SearchResults(
|
||||
edges=edges,
|
||||
nodes=nodes,
|
||||
episodes=episodes,
|
||||
communities=communities,
|
||||
)
|
||||
|
||||
|
|
@ -328,6 +344,54 @@ async def node_search(
|
|||
return reranked_nodes[:limit]
|
||||
|
||||
|
||||
async def episode_search(
|
||||
driver: AsyncDriver,
|
||||
cross_encoder: CrossEncoderClient,
|
||||
query: str,
|
||||
_query_vector: list[float],
|
||||
group_ids: list[str] | None,
|
||||
config: EpisodeSearchConfig | None,
|
||||
search_filter: SearchFilters,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
reranker_min_score: float = 0,
|
||||
) -> list[EpisodicNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
||||
search_results: list[list[EpisodicNode]] = list(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
|
||||
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
|
||||
|
||||
reranked_uuids: list[str] = []
|
||||
if config.reranker == EpisodeReranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
|
||||
elif config.reranker == EpisodeReranker.cross_encoder:
|
||||
# use rrf as a preliminary reranker
|
||||
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
||||
|
||||
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
|
||||
|
||||
reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
|
||||
reranked_uuids = [
|
||||
content_to_uuid_map[content]
|
||||
for content, score in reranked_contents
|
||||
if score >= reranker_min_score
|
||||
]
|
||||
|
||||
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
|
||||
return reranked_episodes[:limit]
|
||||
|
||||
|
||||
async def community_search(
|
||||
driver: AsyncDriver,
|
||||
cross_encoder: CrossEncoderClient,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from enum import Enum
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_utils import (
|
||||
DEFAULT_MIN_SCORE,
|
||||
DEFAULT_MMR_LAMBDA,
|
||||
|
|
@ -41,6 +41,10 @@ class NodeSearchMethod(Enum):
|
|||
bfs = 'breadth_first_search'
|
||||
|
||||
|
||||
class EpisodeSearchMethod(Enum):
|
||||
bm25 = 'bm25'
|
||||
|
||||
|
||||
class CommunitySearchMethod(Enum):
|
||||
cosine_similarity = 'cosine_similarity'
|
||||
bm25 = 'bm25'
|
||||
|
|
@ -62,6 +66,11 @@ class NodeReranker(Enum):
|
|||
cross_encoder = 'cross_encoder'
|
||||
|
||||
|
||||
class EpisodeReranker(Enum):
|
||||
rrf = 'reciprocal_rank_fusion'
|
||||
cross_encoder = 'cross_encoder'
|
||||
|
||||
|
||||
class CommunityReranker(Enum):
|
||||
rrf = 'reciprocal_rank_fusion'
|
||||
mmr = 'mmr'
|
||||
|
|
@ -84,6 +93,14 @@ class NodeSearchConfig(BaseModel):
|
|||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
|
||||
|
||||
class EpisodeSearchConfig(BaseModel):
|
||||
search_methods: list[EpisodeSearchMethod]
|
||||
reranker: EpisodeReranker = Field(default=EpisodeReranker.rrf)
|
||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
|
||||
|
||||
class CommunitySearchConfig(BaseModel):
|
||||
search_methods: list[CommunitySearchMethod]
|
||||
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
||||
|
|
@ -95,6 +112,7 @@ class CommunitySearchConfig(BaseModel):
|
|||
class SearchConfig(BaseModel):
|
||||
edge_config: EdgeSearchConfig | None = Field(default=None)
|
||||
node_config: NodeSearchConfig | None = Field(default=None)
|
||||
episode_config: EpisodeSearchConfig | None = Field(default=None)
|
||||
community_config: CommunitySearchConfig | None = Field(default=None)
|
||||
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
|
||||
reranker_min_score: float = Field(default=0)
|
||||
|
|
@ -103,4 +121,5 @@ class SearchConfig(BaseModel):
|
|||
class SearchResults(BaseModel):
|
||||
edges: list[EntityEdge]
|
||||
nodes: list[EntityNode]
|
||||
episodes: list[EpisodicNode]
|
||||
communities: list[CommunityNode]
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ from graphiti_core.search.search_config import (
|
|||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
EpisodeReranker,
|
||||
EpisodeSearchConfig,
|
||||
EpisodeSearchMethod,
|
||||
NodeReranker,
|
||||
NodeSearchConfig,
|
||||
NodeSearchMethod,
|
||||
|
|
@ -37,6 +40,12 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
|
|||
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
||||
reranker=NodeReranker.rrf,
|
||||
),
|
||||
episode_config=EpisodeSearchConfig(
|
||||
search_methods=[
|
||||
EpisodeSearchMethod.bm25,
|
||||
],
|
||||
reranker=EpisodeReranker.rrf,
|
||||
),
|
||||
community_config=CommunitySearchConfig(
|
||||
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
||||
reranker=CommunityReranker.rrf,
|
||||
|
|
@ -55,6 +64,12 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
|
|||
reranker=NodeReranker.mmr,
|
||||
mmr_lambda=1,
|
||||
),
|
||||
episode_config=EpisodeSearchConfig(
|
||||
search_methods=[
|
||||
EpisodeSearchMethod.bm25,
|
||||
],
|
||||
reranker=EpisodeReranker.rrf,
|
||||
),
|
||||
community_config=CommunitySearchConfig(
|
||||
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
||||
reranker=CommunityReranker.mmr,
|
||||
|
|
@ -80,6 +95,12 @@ COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
|
|||
],
|
||||
reranker=NodeReranker.cross_encoder,
|
||||
),
|
||||
episode_config=EpisodeSearchConfig(
|
||||
search_methods=[
|
||||
EpisodeSearchMethod.bm25,
|
||||
],
|
||||
reranker=EpisodeReranker.cross_encoder,
|
||||
),
|
||||
community_config=CommunitySearchConfig(
|
||||
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
||||
reranker=CommunityReranker.cross_encoder,
|
||||
|
|
|
|||
|
|
@ -38,6 +38,13 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
|
|||
entity_json = [
|
||||
{'entity_name': node.name, 'summary': node.summary} for node in search_results.nodes
|
||||
]
|
||||
episode_json = [
|
||||
{
|
||||
'source_description': episode.source_description,
|
||||
'content': episode.content,
|
||||
}
|
||||
for episode in search_results.episodes
|
||||
]
|
||||
community_json = [
|
||||
{'community_name': community.name, 'summary': community.summary}
|
||||
for community in search_results.communities
|
||||
|
|
@ -55,6 +62,9 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
|
|||
<ENTITIES>
|
||||
{json.dumps(entity_json, indent=12)}
|
||||
</ENTITIES>
|
||||
<EPISODES>
|
||||
{json.dumps(episode_json, indent=12)}
|
||||
</EPISODES>
|
||||
<COMMUNITIES>
|
||||
{json.dumps(community_json, indent=12)}
|
||||
</COMMUNITIES>
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from graphiti_core.nodes import (
|
|||
EpisodicNode,
|
||||
get_community_node_from_record,
|
||||
get_entity_node_from_record,
|
||||
get_episodic_node_from_record,
|
||||
)
|
||||
from graphiti_core.search.search_filters import (
|
||||
SearchFilters,
|
||||
|
|
@ -229,8 +230,8 @@ async def edge_similarity_search(
|
|||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ filter_query
|
||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
|
|
@ -475,6 +476,48 @@ async def node_bfs_search(
|
|||
return nodes
|
||||
|
||||
|
||||
async def episode_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
_search_filter: SearchFilters,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EpisodicNode]:
|
||||
# BM25 search to get top episodes
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
|
||||
YIELD node AS episode, score
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.uuid = episode.uuid
|
||||
RETURN
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
|
||||
return episodes
|
||||
|
||||
|
||||
async def community_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
|
|
|
|||
|
|
@ -71,6 +71,8 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
|||
]
|
||||
|
||||
fulltext_indices: list[LiteralString] = [
|
||||
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
||||
FOR (e:Episodic) ON EACH [e.content, e.source, e.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
||||
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.9.6"
|
||||
version = "0.10.0"
|
||||
authors = [
|
||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ async def test_graphiti_init():
|
|||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
|
||||
results = await graphiti.search_(
|
||||
query='Who is the user?',
|
||||
query='Who is the User?',
|
||||
)
|
||||
|
||||
pretty_results = search_results_to_context_string(results)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue