Search node centering (#45)

* add new search reranker and update search

* node distance reranking

* format

* rebase

* no need for enumerate

* mypy typing

* defaultdict update

* rrf prelim ranking
This commit is contained in:
Preston Rasmussen 2024-08-26 18:34:57 -04:00 committed by GitHub
parent fc4bf3bde2
commit 2d01e5d7b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 22 deletions

View file

@ -26,7 +26,7 @@ from neo4j import AsyncGraphDatabase
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, hybrid_search from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
from graphiti_core.search.search_utils import ( from graphiti_core.search.search_utils import (
get_relevant_edges, get_relevant_edges,
get_relevant_nodes, get_relevant_nodes,
@ -515,7 +515,7 @@ class Graphiti:
except Exception as e: except Exception as e:
raise e raise e
async def search(self, query: str, num_results=10): async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
""" """
Perform a hybrid search on the knowledge graph. Perform a hybrid search on the knowledge graph.
@ -526,6 +526,8 @@ class Graphiti:
---------- ----------
query : str query : str
The search query string. The search query string.
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node
num_results : int, optional num_results : int, optional
The maximum number of results to return. Defaults to 10. The maximum number of results to return. Defaults to 10.
@ -543,7 +545,14 @@ class Graphiti:
The search is performed using the current date and time as the reference The search is performed using the current date and time as the reference
point for temporal relevance. point for temporal relevance.
""" """
search_config = SearchConfig(num_episodes=0, num_results=num_results) reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
search_config = SearchConfig(
num_episodes=0,
num_edges=num_results,
num_nodes=0,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
)
edges = ( edges = (
await hybrid_search( await hybrid_search(
self.driver, self.driver,
@ -551,6 +560,7 @@ class Graphiti:
query, query,
datetime.now(), datetime.now(),
search_config, search_config,
center_node_uuid,
) )
).edges ).edges
@ -558,7 +568,13 @@ class Graphiti:
return facts return facts
async def _search(self, query: str, timestamp: datetime, config: SearchConfig): async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search( return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
) )

View file

@ -16,6 +16,7 @@ limitations under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from enum import Enum
from time import time from time import time
from neo4j import AsyncDriver from neo4j import AsyncDriver
@ -28,6 +29,7 @@ from graphiti_core.search.search_utils import (
edge_fulltext_search, edge_fulltext_search,
edge_similarity_search, edge_similarity_search,
get_mentioned_nodes, get_mentioned_nodes,
node_distance_reranker,
rrf, rrf,
) )
from graphiti_core.utils import retrieve_episodes from graphiti_core.utils import retrieve_episodes
@ -36,12 +38,22 @@ from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SearchMethod(Enum):
    """Candidate-retrieval strategies that hybrid search can run."""

    # embedding-based similarity search (query is embedded, then matched by vector)
    cosine_similarity = 'cosine_similarity'
    # keyword / full-text search over edge facts
    bm25 = 'bm25'
class Reranker(Enum):
    """Strategies for fusing and reordering results from multiple search methods."""

    # reciprocal rank fusion across the per-method ranked lists
    rrf = 'reciprocal_rank_fusion'
    # rerank by graph distance to a caller-supplied center node
    node_distance = 'node_distance'
class SearchConfig(BaseModel): class SearchConfig(BaseModel):
num_results: int = 10 num_edges: int = 10
num_nodes: int = 10
num_episodes: int = EPISODE_WINDOW_LEN num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = 'cosine' search_methods: list[SearchMethod]
text_search: str = 'BM25' reranker: Reranker | None
reranker: str = 'rrf'
class SearchResults(BaseModel): class SearchResults(BaseModel):
@ -51,7 +63,12 @@ class SearchResults(BaseModel):
async def hybrid_search( async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults: ) -> SearchResults:
start = time() start = time()
@ -65,11 +82,11 @@ async def hybrid_search(
episodes.extend(await retrieve_episodes(driver, timestamp)) episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes)) nodes.extend(await get_mentioned_nodes(driver, episodes))
if config.text_search == 'BM25': if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver) text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search) search_results.append(text_search)
if config.similarity_search == 'cosine': if SearchMethod.cosine_similarity in config.search_methods:
query_text = query.replace('\n', ' ') query_text = query.replace('\n', ' ')
search_vector = ( search_vector = (
(await embedder.create(input=[query_text], model='text-embedding-3-small')) (await embedder.create(input=[query_text], model='text-embedding-3-small'))
@ -80,19 +97,14 @@ async def hybrid_search(
similarity_search = await edge_similarity_search(search_vector, driver) similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search) search_results.append(similarity_search)
if len(search_results) == 1: if len(search_results) > 1 and config.reranker is None:
edges = search_results[0]
elif len(search_results) > 1 and config.reranker != 'rrf':
logger.exception('Multiple searches enabled without a reranker') logger.exception('Multiple searches enabled without a reranker')
raise Exception('Multiple searches enabled without a reranker') raise Exception('Multiple searches enabled without a reranker')
elif config.reranker == 'rrf': else:
edge_uuid_map = {} edge_uuid_map = {}
search_result_uuids = [] search_result_uuids = []
logger.info([[edge.fact for edge in result] for result in search_results])
for result in search_results: for result in search_results:
result_uuids = [] result_uuids = []
for edge in result: for edge in result:
@ -103,12 +115,23 @@ async def hybrid_search(
search_result_uuids = [[edge.uuid for edge in result] for result in search_results] search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids) reranked_uuids: list[str] = []
if config.reranker == Reranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == Reranker.node_distance:
if center_node_uuid is None:
logger.exception('No center node provided for Node Distance reranker')
raise Exception('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(
driver, search_result_uuids, center_node_uuid
)
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges) edges.extend(reranked_edges)
context = SearchResults(episodes=episodes, nodes=nodes, edges=edges) context = SearchResults(
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
)
end = time() end = time()

View file

@ -333,7 +333,7 @@ async def get_relevant_edges(
# takes in a list of rankings of uuids # takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1) -> list[str]: def rrf(results: list[list[str]], rank_const=1) -> list[str]:
scores: dict[str, int] = defaultdict(int) scores: dict[str, float] = defaultdict(float)
for result in results: for result in results:
for i, uuid in enumerate(result): for i, uuid in enumerate(result):
scores[uuid] += 1 / (i + rank_const) scores[uuid] += 1 / (i + rank_const)
@ -344,3 +344,43 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
sorted_uuids = [term[0] for term in scored_uuids] sorted_uuids = [term[0] for term in scored_uuids]
return sorted_uuids return sorted_uuids
async def node_distance_reranker(
    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
) -> list[str]:
    """Rerank edge search results by graph proximity to a center node.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j async driver used to run the shortest-path queries.
    results : list[list[str]]
        One list of edge UUIDs per search method, each in that method's rank order.
    center_node_uuid : str
        UUID of the Entity node that distances are measured from.

    Returns
    -------
    list[str]
        Edge UUIDs sorted best-first by 1/distance to the center node.
    """
    # use rrf as a preliminary ranker
    sorted_uuids = rrf(results)
    scores: dict[str, float] = {}
    for uuid in sorted_uuids:
        # Find shortest path to center node
        # The query resolves the edge's endpoints, then finds the shortest
        # RELATES_TO path from the center node to either endpoint.
        records, _, _ = await driver.execute_query(
            """
            MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
            MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
            WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
            RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
            """,
            edge_uuid=uuid,
            center_uuid=center_node_uuid,
        )
        # Fallback distance: yields score 1 / 0.01 = 100, which outranks any
        # real path length (>= 1). Edges whose source or target IS the center
        # node keep this fallback because their rows are skipped below.
        # NOTE(review): edges with no path to the center node also keep the
        # fallback and therefore rank highest — confirm that is intended.
        distance = 0.01
        for record in records:
            if (
                record['source_uuid'] == center_node_uuid
                or record['target_uuid'] == center_node_uuid
            ):
                continue
            # 'score' is min(length(p)) from the query, i.e. the shortest
            # path length from the center node to an endpoint of this edge.
            distance = record['score']
        # Keep the best (largest) 1/distance seen for this uuid.
        if uuid in scores:
            scores[uuid] = min(1 / distance, scores[uuid])
        else:
            scores[uuid] = 1 / distance
    # rerank on shortest distance
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
    return sorted_uuids