graphiti/graphiti_core/search/search.py
Pavlo Paliychuk 19a6ebc6fe
Fix groupless search (#118)
* fix(search): 🐛 Search across null group_ids

* chore: Version bump

* chore: Set group_ids to none if it's an empty list

* fix: Check for group ids being a list before setting it to None if empty

* fix check

* chore: Simplify group_ids check

* chore: Simplify the check further
2024-09-16 16:23:07 -04:00

242 lines
7.9 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from time import time
from neo4j import AsyncDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.errors import SearchRerankerError
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
CommunityReranker,
CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
SearchResults,
)
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_fulltext_search,
edge_similarity_search,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
rrf,
)
logger = logging.getLogger(__name__)


async def search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: SearchConfig,
    center_node_uuid: str | None = None,
) -> SearchResults:
    """Run every search enabled in ``config`` and bundle the results.

    The edge, node and community searches are awaited one after another, in
    that order, and each runs only when its sub-config is not None. Each
    result list is truncated to ``config.limit``. The elapsed wall-clock time
    is logged at INFO level.

    Args:
        driver: Neo4j async driver passed through to the per-type searches.
        embedder: Embedding client passed through to the per-type searches.
        query: Search text; newlines are replaced by spaces before use.
        group_ids: Group filter; an empty list is normalized to None before
            being passed on.
        config: Which searches to run, their settings, and the result limit.
        center_node_uuid: Optional node used by the node-distance rerankers
            of the edge and node searches.

    Returns:
        A SearchResults holding the (possibly empty) edges, nodes and
        communities.
    """
    started_at = time()
    query = query.replace('\n', ' ')

    # Normalize an empty group list to None before handing it to the helpers.
    if not group_ids:
        group_ids = None

    edges = []
    if config.edge_config is not None:
        edges = await edge_search(
            driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
        )

    nodes = []
    if config.node_config is not None:
        nodes = await node_search(
            driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
        )

    communities = []
    if config.community_config is not None:
        communities = await community_search(
            driver, embedder, query, group_ids, config.community_config, config.limit
        )

    max_results = config.limit
    results = SearchResults(
        edges=edges[:max_results],
        nodes=nodes[:max_results],
        communities=communities[:max_results],
    )

    elapsed_ms = (time() - started_at) * 1000
    logger.info(f'search returned context for query {query} in {elapsed_ms} ms')
    return results
async def edge_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: EdgeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """Search entity edges with the methods enabled in ``config`` and rerank them.

    Each enabled method (BM25 full-text, cosine similarity) is called with
    ``2 * limit`` as its size argument. This function does not truncate the
    final list to ``limit`` itself. For cosine similarity, the query is
    embedded with ``text-embedding-3-small`` and cut to ``EMBEDDING_DIM``.

    Args:
        driver: Neo4j async driver.
        embedder: Client exposing ``create(input=..., model=...)``.
        query: Search text.
        group_ids: Group filter forwarded to the search helpers.
        config: Enabled search methods and the reranker to combine them.
        center_node_uuid: Required when ``config.reranker`` is node distance.
        limit: Base result size; each helper is asked for twice this many.

    Returns:
        Edges in reranked order with each edge appearing at most once. With
        no reranker configured, the results of the single enabled search are
        returned in their original order.

    Raises:
        SearchRerankerError: If more than one search method is enabled
            without a reranker, or the node-distance reranker is requested
            without ``center_node_uuid``.
    """
    search_results: list[list[EntityEdge]] = []

    if EdgeSearchMethod.bm25 in config.search_methods:
        text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
        search_results.append(text_search)

    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )
        similarity_search = await edge_similarity_search(
            driver, search_vector, None, None, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple edge searches enabled without a reranker')

    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one search ran (enforced above): keep its own ordering
        # instead of falling through and returning nothing.
        reranked_uuids = [edge.uuid for result in search_results for edge in result]
    elif config.reranker == EdgeReranker.rrf:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # Several edges can share one source node; keep all of them per node
        # (a plain {source: edge} dict would silently drop all but the last).
        source_to_edge_uuids: dict[str, list[str]] = {}
        for result in search_results:
            for edge in result:
                source_to_edge_uuids.setdefault(edge.source_node_uuid, []).append(edge.uuid)

        source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
        reranked_uuids = [
            edge_uuid
            for node_uuid in reranked_node_uuids
            for edge_uuid in source_to_edge_uuids[node_uuid]
        ]

    # The same edge may be found by more than one search method; keep only its
    # first (best-ranked) occurrence. dict preserves insertion order.
    unique_uuids = dict.fromkeys(reranked_uuids)
    return [edge_uuid_map[uuid] for uuid in unique_uuids]
async def node_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: NodeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """Search entity nodes with the methods enabled in ``config`` and rerank them.

    Each enabled method (BM25 full-text, cosine similarity) is called with
    ``2 * limit`` as its size argument. This function does not truncate the
    final list to ``limit`` itself. For cosine similarity, the query is
    embedded with ``text-embedding-3-small`` and cut to ``EMBEDDING_DIM``.

    Args:
        driver: Neo4j async driver.
        embedder: Client exposing ``create(input=..., model=...)``.
        query: Search text.
        group_ids: Group filter forwarded to the search helpers.
        config: Enabled search methods and the reranker to combine them.
        center_node_uuid: Required when ``config.reranker`` is node distance.
        limit: Base result size; each helper is asked for twice this many.

    Returns:
        Nodes in reranked order with each node appearing at most once. With
        no reranker configured, the results of the single enabled search are
        returned in their original order.

    Raises:
        SearchRerankerError: If more than one search method is enabled
            without a reranker, or the node-distance reranker is requested
            without ``center_node_uuid``.
    """
    search_results: list[list[EntityNode]] = []

    if NodeSearchMethod.bm25 in config.search_methods:
        text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if NodeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )
        similarity_search = await node_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple node searches enabled without a reranker')

    search_result_uuids = [[node.uuid for node in result] for result in search_results]
    node_uuid_map = {node.uuid: node for result in search_results for node in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one search ran (enforced above): keep its own ordering
        # instead of falling through and returning nothing.
        reranked_uuids = [uuid for result_uuids in search_result_uuids for uuid in result_uuids]
    elif config.reranker == NodeReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)

    # A node found by several search methods keeps only its first (best-ranked)
    # occurrence. dict preserves insertion order.
    unique_uuids = dict.fromkeys(reranked_uuids)
    return [node_uuid_map[uuid] for uuid in unique_uuids]
async def community_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: CommunitySearchConfig,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """Search community nodes with the methods enabled in ``config`` and rerank them.

    Each enabled method (BM25 full-text, cosine similarity) is called with
    ``2 * limit`` as its size argument. This function does not truncate the
    final list to ``limit`` itself. For cosine similarity, the query is
    embedded with ``text-embedding-3-small`` and cut to ``EMBEDDING_DIM``.

    Args:
        driver: Neo4j async driver.
        embedder: Client exposing ``create(input=..., model=...)``.
        query: Search text.
        group_ids: Group filter forwarded to the search helpers.
        config: Enabled search methods and the reranker to combine them.
        limit: Base result size; each helper is asked for twice this many.

    Returns:
        Communities in reranked order with each community appearing at most
        once. With no reranker configured, the results of the single enabled
        search are returned in their original order.

    Raises:
        SearchRerankerError: If more than one search method is enabled
            without a reranker.
    """
    search_results: list[list[CommunityNode]] = []

    if CommunitySearchMethod.bm25 in config.search_methods:
        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )
        similarity_search = await community_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple community searches enabled without a reranker')

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one search ran (enforced above): keep its own ordering
        # instead of falling through and returning nothing.
        reranked_uuids = [uuid for result_uuids in search_result_uuids for uuid in result_uuids]
    elif config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)

    # A community found by several search methods keeps only its first
    # (best-ranked) occurrence. dict preserves insertion order.
    unique_uuids = dict.fromkeys(reranked_uuids)
    return [community_uuid_map[uuid] for uuid in unique_uuids]