graphiti/core/search/search.py

import asyncio
import logging
from datetime import datetime
from time import time

from neo4j import AsyncDriver
from pydantic import BaseModel

from core.edges import EntityEdge, Edge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.search.search_utils import (
    edge_similarity_search,
    edge_fulltext_search,
    get_mentioned_nodes,
    rrf,
)
from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN

logger = logging.getLogger(__name__)
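

# Tunable knobs for hybrid_search; the defaults run BM25 and cosine similarity
# searches and fuse the results with Reciprocal Rank Fusion ("rrf").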
class SearchConfig(BaseModel):
    num_results: int = 10
    num_episodes: int = EPISODE_WINDOW_LEN
    similarity_search: str = "cosine"
    text_search: str = "BM25"
    reranker: str = "rrf"


async def hybrid_search(
    driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, list[Node | Edge]]:
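    """Run the configured searches over the graph and return a context dict.

    The returned dict has the keys "episodes", "nodes", and "edges". When more than
    one search strategy is enabled, the edge results are fused with Reciprocal Rank
    Fusion via ``rrf``.

    Illustrative call (a sketch; assumes an OpenAI-style async embeddings client such
    as ``AsyncOpenAI().embeddings`` and a connected Neo4j ``AsyncDriver``)::

        context = await hybrid_search(
            driver, embedder, "example query", datetime.now(), SearchConfig()
        )
    """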
    start = time()

    episodes = []
    nodes = []
    edges = []
    search_results = []
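
    # Seed the context with recent episodes and the nodes they mention.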
    if config.num_episodes > 0:
        episodes.extend(await retrieve_episodes(driver, timestamp))
        nodes.extend(await get_mentioned_nodes(driver, episodes))
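
    # Full-text (BM25) keyword search over edges.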
    if config.text_search == "BM25":
        text_search = await edge_fulltext_search(query, driver)
        search_results.append(text_search)
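
    # Cosine similarity search: embed the query and run a vector search over edges.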
    if config.similarity_search == "cosine":
        query_text = query.replace("\n", " ")
        search_vector = (
            (await embedder.create(input=[query_text], model="text-embedding-3-small"))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await edge_similarity_search(search_vector, driver)
        search_results.append(similarity_search)
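
    # Merge results: a single result set passes through unchanged; multiple result
    # sets must be fused by a reranker.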
    if len(search_results) == 1:
        edges = search_results[0]

    elif len(search_results) > 1 and config.reranker != "rrf":
        logger.error("Multiple searches enabled without a reranker")
        raise Exception("Multiple searches enabled without a reranker")

    elif config.reranker == "rrf":
        logger.info([[edge.fact for edge in result] for result in search_results])

        # Reciprocal Rank Fusion: fuse the per-search uuid rankings, then map the
        # fused ordering back to the edge objects.
        edge_uuid_map = {
            edge.uuid: edge for result in search_results for edge in result
        }
        search_result_uuids = [
            [edge.uuid for edge in result] for result in search_results
        ]

        reranked_uuids = rrf(search_result_uuids)
        reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
        edges.extend(reranked_edges)
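
    # Package episodes, nodes, and edges into the context returned to the caller.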
    context = {
        "episodes": episodes,
        "nodes": nodes,
        "edges": edges,
    }

    end = time()
    logger.info(
        f"search returned context for query {query} in {(end - start) * 1000} ms"
    )

    return context