Mentions reranker (#124)
* documentation update * update communities * mentions reranker * fix episode edge mentions * get episode mentions * add communities to mentions endpoint * rebase * defaults episodes to empty list * update
This commit is contained in:
parent
d133c39313
commit
e398f95612
9 changed files with 184 additions and 7 deletions
|
|
@ -83,6 +83,7 @@ async def main(use_bulk: bool = True):
|
||||||
reference_time=message.actual_timestamp,
|
reference_time=message.actual_timestamp,
|
||||||
source_description='Podcast Transcript',
|
source_description='Podcast Transcript',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
|
update_communities=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -109,13 +109,36 @@ class EpisodicEdge(Edge):
|
||||||
raise EdgeNotFoundError(uuid)
|
raise EdgeNotFoundError(uuid)
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
|
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Load the MENTIONS edges whose uuid appears in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    uuids : list[str]
        Uuids of the episodic edges to fetch.

    Returns
    -------
    list[EpisodicEdge]
        One edge per matching record; an empty list when ``uuids`` is empty.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty and no edge matches any of them.
    """
    # Nothing was requested, so there is nothing to look up. Returning early
    # also keeps the not-found branch below from indexing an empty list.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
        uuids=uuids,
    )

    edges = [get_episodic_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        # Report the first requested uuid as the missing one.
        raise EdgeNotFoundError(uuids[0])
    return edges
|
||||||
|
|
||||||
|
|
||||||
class EntityEdge(Edge):
|
class EntityEdge(Edge):
|
||||||
name: str = Field(description='name of the edge, relation name')
|
name: str = Field(description='name of the edge, relation name')
|
||||||
fact: str = Field(description='fact representing the edge and nodes that it connects')
|
fact: str = Field(description='fact representing the edge and nodes that it connects')
|
||||||
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
|
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
|
||||||
episodes: list[str] | None = Field(
|
episodes: list[str] = Field(
|
||||||
default=None,
|
default=[],
|
||||||
description='list of episode ids that reference these entity edges',
|
description='list of episode ids that reference these entity edges',
|
||||||
)
|
)
|
||||||
expired_at: datetime | None = Field(
|
expired_at: datetime | None = Field(
|
||||||
|
|
@ -197,6 +220,36 @@ class EntityEdge(Edge):
|
||||||
raise EdgeNotFoundError(uuid)
|
raise EdgeNotFoundError(uuid)
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
|
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Load the RELATES_TO edges whose uuid appears in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    uuids : list[str]
        Uuids of the entity edges to fetch.

    Returns
    -------
    list
        One edge (built by ``get_entity_edge_from_record``) per matching
        record; an empty list when ``uuids`` is empty.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty and no edge matches any of them.
    """
    # Nothing was requested, so there is nothing to look up. Returning early
    # also keeps the not-found branch below from indexing an empty list.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at,
            e.name AS name,
            e.group_id AS group_id,
            e.fact AS fact,
            e.fact_embedding AS fact_embedding,
            e.episodes AS episodes,
            e.expired_at AS expired_at,
            e.valid_at AS valid_at,
            e.invalid_at AS invalid_at
        """,
        uuids=uuids,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        # Report the first requested uuid as the missing one.
        raise EdgeNotFoundError(uuids[0])
    return edges
|
||||||
|
|
||||||
|
|
||||||
class CommunityEdge(Edge):
|
class CommunityEdge(Edge):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
|
|
@ -239,6 +292,28 @@ class CommunityEdge(Edge):
|
||||||
|
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
|
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Load the HAS_MEMBER edges whose uuid appears in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    uuids : list[str]
        Uuids of the community edges to fetch.

    Returns
    -------
    list
        One edge (built by ``get_community_edge_from_record``) per matching
        record. No error is raised when nothing matches; the list is
        simply empty.
    """
    cypher = """
        MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """
    rows, _, _ = await driver.execute_query(cypher, uuids=uuids)

    # Convert before logging so the log line is only emitted once every
    # record has been turned into an edge.
    found = list(map(get_community_edge_from_record, rows))

    logger.info(f'Found Edges: {uuids}')

    return found
|
||||||
|
|
||||||
|
|
||||||
# Edge helpers
|
# Edge helpers
|
||||||
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,8 @@ from graphiti_core.search.search_config_recipes import (
|
||||||
)
|
)
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
|
get_communities_by_nodes,
|
||||||
|
get_mentioned_nodes,
|
||||||
get_relevant_edges,
|
get_relevant_edges,
|
||||||
get_relevant_nodes,
|
get_relevant_nodes,
|
||||||
)
|
)
|
||||||
|
|
@ -249,8 +251,6 @@ class Graphiti:
|
||||||
An id for the graph partition the episode is a part of.
|
An id for the graph partition the episode is a part of.
|
||||||
uuid : str | None
|
uuid : str | None
|
||||||
Optional uuid of the episode.
|
Optional uuid of the episode.
|
||||||
update_communities: bool
|
|
||||||
Optional. Determines if we should update communities
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -413,6 +413,8 @@ class Graphiti:
|
||||||
|
|
||||||
logger.info(f'Built episodic edges: {episodic_edges}')
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
||||||
|
|
||||||
|
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
||||||
|
|
||||||
# Future optimization would be using batch operations to save nodes and edges
|
# Future optimization would be using batch operations to save nodes and edges
|
||||||
await episode.save(self.driver)
|
await episode.save(self.driver)
|
||||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||||
|
|
@ -680,3 +682,19 @@ class Graphiti:
|
||||||
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
||||||
).nodes
|
).nodes
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
    """Collect the edges, nodes and communities tied to the given episodes.

    Parameters
    ----------
    episode_uuids : list[str]
        Uuids of the episodes to inspect.

    Returns
    -------
    SearchResults
        ``edges`` holds the entity edges listed in each episode's
        ``entity_edges``, ``nodes`` the nodes returned by
        ``get_mentioned_nodes`` for those episodes, and ``communities``
        the communities returned by ``get_communities_by_nodes`` for
        those nodes.
    """
    episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

    # An episode that recorded no entity edges has nothing to fetch, so no
    # lookup is issued for it (an empty uuid list is not a valid lookup).
    edges_list = await asyncio.gather(
        *[
            EntityEdge.get_by_uuids(self.driver, episode.entity_edges)
            for episode in episodes
            if episode.entity_edges
        ]
    )

    # Flatten the per-episode results into a single list.
    edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]

    nodes = await get_mentioned_nodes(self.driver, episodes)

    communities = await get_communities_by_nodes(self.driver, nodes)

    return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,8 @@ class EpisodicNode(Node):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
||||||
RETURN e.content AS content,
|
RETURN DISTINCT
|
||||||
|
e.content AS content,
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
e.valid_at AS valid_at,
|
e.valid_at AS valid_at,
|
||||||
e.uuid AS uuid,
|
e.uuid AS uuid,
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ from graphiti_core.search.search_utils import (
|
||||||
community_similarity_search,
|
community_similarity_search,
|
||||||
edge_fulltext_search,
|
edge_fulltext_search,
|
||||||
edge_similarity_search,
|
edge_similarity_search,
|
||||||
|
episode_mentions_reranker,
|
||||||
node_distance_reranker,
|
node_distance_reranker,
|
||||||
node_fulltext_search,
|
node_fulltext_search,
|
||||||
node_similarity_search,
|
node_similarity_search,
|
||||||
|
|
@ -131,7 +132,7 @@ async def edge_search(
|
||||||
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
||||||
|
|
||||||
reranked_uuids: list[str] = []
|
reranked_uuids: list[str] = []
|
||||||
if config.reranker == EdgeReranker.rrf:
|
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
||||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||||
|
|
||||||
reranked_uuids = rrf(search_result_uuids)
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
|
@ -150,6 +151,9 @@ async def edge_search(
|
||||||
|
|
||||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
if config.reranker == EdgeReranker.episode_mentions:
|
||||||
|
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
||||||
|
|
||||||
return reranked_edges
|
return reranked_edges
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -189,6 +193,8 @@ async def node_search(
|
||||||
reranked_uuids: list[str] = []
|
reranked_uuids: list[str] = []
|
||||||
if config.reranker == NodeReranker.rrf:
|
if config.reranker == NodeReranker.rrf:
|
||||||
reranked_uuids = rrf(search_result_uuids)
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
elif config.reranker == NodeReranker.episode_mentions:
|
||||||
|
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
||||||
elif config.reranker == NodeReranker.node_distance:
|
elif config.reranker == NodeReranker.node_distance:
|
||||||
if center_node_uuid is None:
|
if center_node_uuid is None:
|
||||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||||
|
|
|
||||||
|
|
@ -42,11 +42,13 @@ class CommunitySearchMethod(Enum):
|
||||||
class EdgeReranker(Enum):
|
class EdgeReranker(Enum):
|
||||||
rrf = 'reciprocal_rank_fusion'
|
rrf = 'reciprocal_rank_fusion'
|
||||||
node_distance = 'node_distance'
|
node_distance = 'node_distance'
|
||||||
|
episode_mentions = 'episode_mentions'
|
||||||
|
|
||||||
|
|
||||||
class NodeReranker(Enum):
|
class NodeReranker(Enum):
|
||||||
rrf = 'reciprocal_rank_fusion'
|
rrf = 'reciprocal_rank_fusion'
|
||||||
node_distance = 'node_distance'
|
node_distance = 'node_distance'
|
||||||
|
episode_mentions = 'episode_mentions'
|
||||||
|
|
||||||
|
|
||||||
class CommunityReranker(Enum):
|
class CommunityReranker(Enum):
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,14 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over edges with episode mention reranking
|
||||||
|
EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
|
||||||
|
edge_config=EdgeSearchConfig(
|
||||||
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
|
reranker=EdgeReranker.episode_mentions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# performs a hybrid search over nodes with rrf reranking
|
# performs a hybrid search over nodes with rrf reranking
|
||||||
NODE_HYBRID_SEARCH_RRF = SearchConfig(
|
NODE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
node_config=NodeSearchConfig(
|
node_config=NodeSearchConfig(
|
||||||
|
|
@ -75,6 +83,14 @@ NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over nodes with episode mentions reranking
|
||||||
|
NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
|
||||||
|
node_config=NodeSearchConfig(
|
||||||
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
||||||
|
reranker=NodeReranker.episode_mentions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# performs a hybrid search over communities with rrf reranking
|
# performs a hybrid search over communities with rrf reranking
|
||||||
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
|
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
community_config=CommunitySearchConfig(
|
community_config=CommunitySearchConfig(
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
RELEVANT_SCHEMA_LIMIT = 3
|
||||||
|
|
||||||
|
|
||||||
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
|
async def get_mentioned_nodes(
|
||||||
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
|
|
@ -57,6 +59,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def get_communities_by_nodes(
    driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """Fetch the distinct communities that have any of ``nodes`` as a member.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    nodes : list[EntityNode]
        Entity nodes whose communities should be returned.

    Returns
    -------
    list[CommunityNode]
        One community per distinct record returned by the query.
    """
    node_uuids = [node.uuid for node in nodes]

    # Every projected column must be comma-separated; the comma after
    # name_embedding was missing, which made the query invalid Cypher.
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
        RETURN DISTINCT
            c.uuid As uuid,
            c.group_id AS group_id,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        uuids=node_uuids,
    )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
async def edge_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -634,3 +659,34 @@ async def node_distance_reranker(
|
||||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
||||||
return sorted_uuids
|
return sorted_uuids
|
||||||
|
|
||||||
|
|
||||||
|
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
    """Rerank node uuids by the number of episodes that mention each node.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the per-node count queries.
    node_uuids : list[list[str]]
        One ranked uuid list per search method; they are fused with ``rrf``
        before the mention counts are applied.

    Returns
    -------
    list[str]
        The fused uuids ordered from most-mentioned to least-mentioned.
        Nodes with equal counts keep their relative ``rrf`` order because
        ``list.sort`` is stable.
    """
    # use rrf as a preliminary ranker
    sorted_uuids = rrf(node_uuids)
    scores: dict[str, float] = {}

    # Count the episodes that mention each candidate node
    query = Query("""
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
        RETURN count(*) AS score
        """)

    result_scores = await asyncio.gather(
        *[
            driver.execute_query(
                query,
                node_uuid=uuid,
            )
            for uuid in sorted_uuids
        ]
    )

    for uuid, result in zip(sorted_uuids, result_scores):
        # The first element of the query result is the record list; the
        # count(*) aggregation yields a single row carrying the score.
        record = result[0][0]
        scores[uuid] = record['score']

    # rerank on mention count, most-mentioned first (matches the edge
    # reranker, which sorts by episode count with reverse=True)
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])

    return sorted_uuids
|
||||||
|
|
|
||||||
|
|
@ -163,6 +163,8 @@ async def dedupe_extracted_edges(
|
||||||
if edge.uuid in duplicate_uuid_map:
|
if edge.uuid in duplicate_uuid_map:
|
||||||
existing_uuid = duplicate_uuid_map[edge.uuid]
|
existing_uuid = duplicate_uuid_map[edge.uuid]
|
||||||
existing_edge = edge_map[existing_uuid]
|
existing_edge = edge_map[existing_uuid]
|
||||||
|
# Add current episode to the episodes list
|
||||||
|
existing_edge.episodes += edge.episodes
|
||||||
edges.append(existing_edge)
|
edges.append(existing_edge)
|
||||||
else:
|
else:
|
||||||
edges.append(edge)
|
edges.append(edge)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue