Search node centering (#45)
* add new search reranker and update search * node distance reranking * format * rebase * no need for enumerate * mypy typing * defaultdict update * rrf prelim ranking
This commit is contained in:
parent
fc4bf3bde2
commit
2d01e5d7b7
3 changed files with 101 additions and 22 deletions
|
|
@ -26,7 +26,7 @@ from neo4j import AsyncGraphDatabase
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import SearchConfig, hybrid_search
|
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
get_relevant_edges,
|
get_relevant_edges,
|
||||||
get_relevant_nodes,
|
get_relevant_nodes,
|
||||||
|
|
@ -515,7 +515,7 @@ class Graphiti:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def search(self, query: str, num_results=10):
|
async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search on the knowledge graph.
|
Perform a hybrid search on the knowledge graph.
|
||||||
|
|
||||||
|
|
@ -526,6 +526,8 @@ class Graphiti:
|
||||||
----------
|
----------
|
||||||
query : str
|
query : str
|
||||||
The search query string.
|
The search query string.
|
||||||
|
center_node_uuid: str, optional
|
||||||
|
Facts will be reranked based on proximity to this node
|
||||||
num_results : int, optional
|
num_results : int, optional
|
||||||
The maximum number of results to return. Defaults to 10.
|
The maximum number of results to return. Defaults to 10.
|
||||||
|
|
||||||
|
|
@ -543,7 +545,14 @@ class Graphiti:
|
||||||
The search is performed using the current date and time as the reference
|
The search is performed using the current date and time as the reference
|
||||||
point for temporal relevance.
|
point for temporal relevance.
|
||||||
"""
|
"""
|
||||||
search_config = SearchConfig(num_episodes=0, num_results=num_results)
|
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
|
||||||
|
search_config = SearchConfig(
|
||||||
|
num_episodes=0,
|
||||||
|
num_edges=num_results,
|
||||||
|
num_nodes=0,
|
||||||
|
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
||||||
|
reranker=reranker,
|
||||||
|
)
|
||||||
edges = (
|
edges = (
|
||||||
await hybrid_search(
|
await hybrid_search(
|
||||||
self.driver,
|
self.driver,
|
||||||
|
|
@ -551,6 +560,7 @@ class Graphiti:
|
||||||
query,
|
query,
|
||||||
datetime.now(),
|
datetime.now(),
|
||||||
search_config,
|
search_config,
|
||||||
|
center_node_uuid,
|
||||||
)
|
)
|
||||||
).edges
|
).edges
|
||||||
|
|
||||||
|
|
@ -558,7 +568,13 @@ class Graphiti:
|
||||||
|
|
||||||
return facts
|
return facts
|
||||||
|
|
||||||
async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
|
async def _search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
timestamp: datetime,
|
||||||
|
config: SearchConfig,
|
||||||
|
center_node_uuid: str | None = None,
|
||||||
|
):
|
||||||
return await hybrid_search(
|
return await hybrid_search(
|
||||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config
|
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
|
@ -28,6 +29,7 @@ from graphiti_core.search.search_utils import (
|
||||||
edge_fulltext_search,
|
edge_fulltext_search,
|
||||||
edge_similarity_search,
|
edge_similarity_search,
|
||||||
get_mentioned_nodes,
|
get_mentioned_nodes,
|
||||||
|
node_distance_reranker,
|
||||||
rrf,
|
rrf,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils import retrieve_episodes
|
from graphiti_core.utils import retrieve_episodes
|
||||||
|
|
@ -36,12 +38,22 @@ from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchMethod(Enum):
|
||||||
|
cosine_similarity = 'cosine_similarity'
|
||||||
|
bm25 = 'bm25'
|
||||||
|
|
||||||
|
|
||||||
|
class Reranker(Enum):
|
||||||
|
rrf = 'reciprocal_rank_fusion'
|
||||||
|
node_distance = 'node_distance'
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(BaseModel):
|
class SearchConfig(BaseModel):
|
||||||
num_results: int = 10
|
num_edges: int = 10
|
||||||
|
num_nodes: int = 10
|
||||||
num_episodes: int = EPISODE_WINDOW_LEN
|
num_episodes: int = EPISODE_WINDOW_LEN
|
||||||
similarity_search: str = 'cosine'
|
search_methods: list[SearchMethod]
|
||||||
text_search: str = 'BM25'
|
reranker: Reranker | None
|
||||||
reranker: str = 'rrf'
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResults(BaseModel):
|
class SearchResults(BaseModel):
|
||||||
|
|
@ -51,7 +63,12 @@ class SearchResults(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_search(
|
async def hybrid_search(
|
||||||
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
|
driver: AsyncDriver,
|
||||||
|
embedder,
|
||||||
|
query: str,
|
||||||
|
timestamp: datetime,
|
||||||
|
config: SearchConfig,
|
||||||
|
center_node_uuid: str | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
@ -65,11 +82,11 @@ async def hybrid_search(
|
||||||
episodes.extend(await retrieve_episodes(driver, timestamp))
|
episodes.extend(await retrieve_episodes(driver, timestamp))
|
||||||
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
||||||
|
|
||||||
if config.text_search == 'BM25':
|
if SearchMethod.bm25 in config.search_methods:
|
||||||
text_search = await edge_fulltext_search(query, driver)
|
text_search = await edge_fulltext_search(query, driver)
|
||||||
search_results.append(text_search)
|
search_results.append(text_search)
|
||||||
|
|
||||||
if config.similarity_search == 'cosine':
|
if SearchMethod.cosine_similarity in config.search_methods:
|
||||||
query_text = query.replace('\n', ' ')
|
query_text = query.replace('\n', ' ')
|
||||||
search_vector = (
|
search_vector = (
|
||||||
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
|
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
|
||||||
|
|
@ -80,19 +97,14 @@ async def hybrid_search(
|
||||||
similarity_search = await edge_similarity_search(search_vector, driver)
|
similarity_search = await edge_similarity_search(search_vector, driver)
|
||||||
search_results.append(similarity_search)
|
search_results.append(similarity_search)
|
||||||
|
|
||||||
if len(search_results) == 1:
|
if len(search_results) > 1 and config.reranker is None:
|
||||||
edges = search_results[0]
|
|
||||||
|
|
||||||
elif len(search_results) > 1 and config.reranker != 'rrf':
|
|
||||||
logger.exception('Multiple searches enabled without a reranker')
|
logger.exception('Multiple searches enabled without a reranker')
|
||||||
raise Exception('Multiple searches enabled without a reranker')
|
raise Exception('Multiple searches enabled without a reranker')
|
||||||
|
|
||||||
elif config.reranker == 'rrf':
|
else:
|
||||||
edge_uuid_map = {}
|
edge_uuid_map = {}
|
||||||
search_result_uuids = []
|
search_result_uuids = []
|
||||||
|
|
||||||
logger.info([[edge.fact for edge in result] for result in search_results])
|
|
||||||
|
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
result_uuids = []
|
result_uuids = []
|
||||||
for edge in result:
|
for edge in result:
|
||||||
|
|
@ -103,12 +115,23 @@ async def hybrid_search(
|
||||||
|
|
||||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||||
|
|
||||||
reranked_uuids = rrf(search_result_uuids)
|
reranked_uuids: list[str] = []
|
||||||
|
if config.reranker == Reranker.rrf:
|
||||||
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
elif config.reranker == Reranker.node_distance:
|
||||||
|
if center_node_uuid is None:
|
||||||
|
logger.exception('No center node provided for Node Distance reranker')
|
||||||
|
raise Exception('No center node provided for Node Distance reranker')
|
||||||
|
reranked_uuids = await node_distance_reranker(
|
||||||
|
driver, search_result_uuids, center_node_uuid
|
||||||
|
)
|
||||||
|
|
||||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
edges.extend(reranked_edges)
|
edges.extend(reranked_edges)
|
||||||
|
|
||||||
context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
|
context = SearchResults(
|
||||||
|
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
|
||||||
|
)
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -333,7 +333,7 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
# takes in a list of rankings of uuids
|
# takes in a list of rankings of uuids
|
||||||
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||||
scores: dict[str, int] = defaultdict(int)
|
scores: dict[str, float] = defaultdict(float)
|
||||||
for result in results:
|
for result in results:
|
||||||
for i, uuid in enumerate(result):
|
for i, uuid in enumerate(result):
|
||||||
scores[uuid] += 1 / (i + rank_const)
|
scores[uuid] += 1 / (i + rank_const)
|
||||||
|
|
@ -344,3 +344,43 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||||
sorted_uuids = [term[0] for term in scored_uuids]
|
sorted_uuids = [term[0] for term in scored_uuids]
|
||||||
|
|
||||||
return sorted_uuids
|
return sorted_uuids
|
||||||
|
|
||||||
|
|
||||||
|
async def node_distance_reranker(
|
||||||
|
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
|
||||||
|
) -> list[str]:
|
||||||
|
# use rrf as a preliminary ranker
|
||||||
|
sorted_uuids = rrf(results)
|
||||||
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
|
for uuid in sorted_uuids:
|
||||||
|
# Find shortest path to center node
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
||||||
|
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
|
||||||
|
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
|
||||||
|
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
||||||
|
""",
|
||||||
|
edge_uuid=uuid,
|
||||||
|
center_uuid=center_node_uuid,
|
||||||
|
)
|
||||||
|
distance = 0.01
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
if (
|
||||||
|
record['source_uuid'] == center_node_uuid
|
||||||
|
or record['target_uuid'] == center_node_uuid
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
distance = record['score']
|
||||||
|
|
||||||
|
if uuid in scores:
|
||||||
|
scores[uuid] = min(1 / distance, scores[uuid])
|
||||||
|
else:
|
||||||
|
scores[uuid] = 1 / distance
|
||||||
|
|
||||||
|
# rerank on shortest distance
|
||||||
|
sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
||||||
|
return sorted_uuids
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue