* ruff action * chore: Update Python version to 3.10 in lint.yml workflow * fix lint and formatting * cleanup
99 lines
2.5 KiB
Python
99 lines
2.5 KiB
Python
import logging
|
|
from datetime import datetime
|
|
from time import time
|
|
|
|
from neo4j import AsyncDriver
|
|
from pydantic import BaseModel
|
|
|
|
from core.edges import Edge
|
|
from core.llm_client.config import EMBEDDING_DIM
|
|
from core.nodes import Node
|
|
from core.search.search_utils import (
|
|
edge_fulltext_search,
|
|
edge_similarity_search,
|
|
get_mentioned_nodes,
|
|
rrf,
|
|
)
|
|
from core.utils import retrieve_episodes
|
|
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SearchConfig(BaseModel):
    """Options that control which steps ``hybrid_search`` runs.

    The string options are matched by exact comparison in ``hybrid_search``;
    a value other than the one noted next to each field disables that step.
    """

    # NOTE(review): not read by hybrid_search in this file -- confirm whether
    # results are supposed to be truncated to this many entries.
    num_results: int = 10

    # Episodes (and the nodes they mention) are fetched only when this is > 0.
    num_episodes: int = EPISODE_WINDOW_LEN

    # 'cosine' enables the embedding-based edge similarity search.
    similarity_search: str = 'cosine'

    # 'BM25' enables the full-text edge search.
    text_search: str = 'BM25'

    # 'rrf' is the only value hybrid_search accepts when more than one search
    # is enabled.
    reranker: str = 'rrf'
|
|
|
|
|
|
async def hybrid_search(
    driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, list[Node | Edge]]:
    """Build a search context of episodes, nodes and edges for ``query``.

    Steps, each gated by ``config``:

    * ``num_episodes > 0``: fetch episodes via ``retrieve_episodes`` and the
      nodes they mention via ``get_mentioned_nodes``.
    * ``text_search == 'BM25'``: run ``edge_fulltext_search``.
    * ``similarity_search == 'cosine'``: embed the query and run
      ``edge_similarity_search``.

    When exactly one search ran, its result is used as-is. When several ran,
    their results are merged through ``rrf``.

    Args:
        driver: Neo4j async driver handed to the search helpers.
        embedder: Client exposing an awaitable ``create(input=..., model=...)``
            whose result has ``.data[0].embedding``.
        query: Search text.
        timestamp: Reference time passed to ``retrieve_episodes``.
        config: Flags selecting which searches and which reranker to use.

    Returns:
        A dict with the keys ``'episodes'``, ``'nodes'`` and ``'edges'``.

    Raises:
        Exception: If more than one search is enabled and ``config.reranker``
            is not ``'rrf'``.
    """
    start = time()

    episodes = []
    nodes = []
    edges = []

    # One list of edges per search strategy that actually ran.
    search_results = []

    if config.num_episodes > 0:
        # NOTE(review): config.num_episodes only gates this step; it is not
        # forwarded to retrieve_episodes -- confirm that is intended.
        episodes.extend(await retrieve_episodes(driver, timestamp))
        nodes.extend(await get_mentioned_nodes(driver, episodes))

    if config.text_search == 'BM25':
        search_results.append(await edge_fulltext_search(query, driver))

    if config.similarity_search == 'cosine':
        # Newlines are replaced with spaces before the text is embedded.
        query_text = query.replace('\n', ' ')
        response = await embedder.create(input=[query_text], model='text-embedding-3-small')
        search_vector = response.data[0].embedding[:EMBEDDING_DIM]
        search_results.append(await edge_similarity_search(search_vector, driver))

    if len(search_results) == 1:
        # A single result set needs no reranking.
        edges = search_results[0]
    elif len(search_results) > 1 and config.reranker != 'rrf':
        # logger.error rather than logger.exception: no exception is being
        # handled here, so there is no traceback to attach.
        logger.error('Multiple searches enabled without a reranker')
        raise Exception('Multiple searches enabled without a reranker')
    elif config.reranker == 'rrf':
        logger.info([[edge.fact for edge in result] for result in search_results])

        # uuid -> edge lookup; when a uuid occurs in several result sets the
        # edge from the last set wins.
        edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

        reranked_uuids = rrf(search_result_uuids)
        edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

    # NOTE(review): config.num_results is not applied anywhere in this
    # function -- confirm whether the edge list should be truncated.
    context = {
        'episodes': episodes,
        'nodes': nodes,
        'edges': edges,
    }

    end = time()

    logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')

    return context
|