graphiti/core/search/search.py
Daniel Chalef 73ec0146ff
ruff action (#17)
* ruff action

* chore: Update Python version to 3.10 in lint.yml workflow

* fix lint and formatting

* cleanup
2024-08-22 13:06:42 -07:00

99 lines
2.5 KiB
Python

import logging
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from pydantic import BaseModel
from core.edges import Edge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.search.search_utils import (
edge_fulltext_search,
edge_similarity_search,
get_mentioned_nodes,
rrf,
)
from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
logger = logging.getLogger(__name__)
class SearchConfig(BaseModel):
num_results: int = 10
num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = 'cosine'
text_search: str = 'BM25'
reranker: str = 'rrf'
async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, [Node | Edge]]:
start = time()
episodes = []
nodes = []
edges = []
search_results = []
if config.num_episodes > 0:
episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes))
if config.text_search == 'BM25':
text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search)
if config.similarity_search == 'cosine':
query_text = query.replace('\n', ' ')
search_vector = (
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
.data[0]
.embedding[:EMBEDDING_DIM]
)
similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search)
if len(search_results) == 1:
edges = search_results[0]
elif len(search_results) > 1 and config.reranker != 'rrf':
logger.exception('Multiple searches enabled without a reranker')
raise Exception('Multiple searches enabled without a reranker')
elif config.reranker == 'rrf':
edge_uuid_map = {}
search_result_uuids = []
logger.info([[edge.fact for edge in result] for result in search_results])
for result in search_results:
result_uuids = []
for edge in result:
result_uuids.append(edge.uuid)
edge_uuid_map[edge.uuid] = edge
search_result_uuids.append(result_uuids)
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids)
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)
context = {
'episodes': episodes,
'nodes': nodes,
'edges': edges,
}
end = time()
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
return context