* fix(search): 🐛 Search across null group_ids
* chore: Version bump
* chore: Set group_ids to none if it's an empty list
* fix: Check for group ids being a list before setting it to None if empty
* fix check
* chore: Simplify group_ids check
* chore: Simplify the check further
242 lines
7.9 KiB
Python
242 lines
7.9 KiB
Python
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import logging
|
|
from time import time
|
|
|
|
from neo4j import AsyncDriver
|
|
|
|
from graphiti_core.edges import EntityEdge
|
|
from graphiti_core.errors import SearchRerankerError
|
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
from graphiti_core.search.search_config import (
|
|
DEFAULT_SEARCH_LIMIT,
|
|
CommunityReranker,
|
|
CommunitySearchConfig,
|
|
CommunitySearchMethod,
|
|
EdgeReranker,
|
|
EdgeSearchConfig,
|
|
EdgeSearchMethod,
|
|
NodeReranker,
|
|
NodeSearchConfig,
|
|
NodeSearchMethod,
|
|
SearchConfig,
|
|
SearchResults,
|
|
)
|
|
from graphiti_core.search.search_utils import (
|
|
community_fulltext_search,
|
|
community_similarity_search,
|
|
edge_fulltext_search,
|
|
edge_similarity_search,
|
|
node_distance_reranker,
|
|
node_fulltext_search,
|
|
node_similarity_search,
|
|
rrf,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: SearchConfig,
    center_node_uuid: str | None = None,
) -> SearchResults:
    """Run the edge, node and community searches enabled in ``config``.

    A sub-search runs only when its section of ``config`` (``edge_config``,
    ``node_config``, ``community_config``) is not None; a disabled section
    contributes an empty list. The sub-searches are awaited one after another
    in the order edges, nodes, communities.

    Args:
        driver: Neo4j async driver, passed through to the sub-searches.
        embedder: Embedding client, passed through to the sub-searches.
        query: Search text; newlines are replaced by spaces before use.
        group_ids: Group ids forwarded to the sub-searches. An empty list is
            normalised to None.
        config: Selects which sub-searches run and supplies ``limit``.
        center_node_uuid: Forwarded to the edge and node searches only.

    Returns:
        A ``SearchResults`` whose ``edges``, ``nodes`` and ``communities``
        are each truncated to ``config.limit`` entries.
    """
    started_at = time()

    query = query.replace('\n', ' ')
    # Collapse an empty group_ids list to None.
    group_ids = group_ids or None

    edges: list[EntityEdge] = []
    if config.edge_config is not None:
        edges = await edge_search(
            driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
        )

    nodes: list[EntityNode] = []
    if config.node_config is not None:
        nodes = await node_search(
            driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
        )

    communities: list[CommunityNode] = []
    if config.community_config is not None:
        # Community search takes no center node.
        communities = await community_search(
            driver, embedder, query, group_ids, config.community_config, config.limit
        )

    results = SearchResults(
        edges=edges[: config.limit],
        nodes=nodes[: config.limit],
        communities=communities[: config.limit],
    )

    elapsed_ms = (time() - started_at) * 1000
    logger.info(f'search returned context for query {query} in {elapsed_ms} ms')

    return results
|
|
|
|
|
|
async def edge_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: EdgeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """Search for edges with the configured methods and rerank the hits.

    Each enabled method (BM25 full-text, cosine similarity) is asked for
    ``2 * limit`` candidates. The candidate lists are then ordered by
    ``config.reranker``.

    Args:
        driver: Neo4j async driver used by the search helpers.
        embedder: Client whose awaitable ``create(input=..., model=...)``
            returns an object with ``.data[0].embedding``.
        query: Search text.
        group_ids: Group ids forwarded to the search helpers.
        config: Search methods to run and the reranker to apply.
        center_node_uuid: Required when the node-distance reranker is used.
        limit: Base result count; every method fetches twice this many.

    Returns:
        The reranked edges. This function does not truncate to ``limit``.

    Raises:
        SearchRerankerError: If more than one search method ran without a
            reranker, or the node-distance reranker is selected without a
            ``center_node_uuid``.
    """
    search_results: list[list[EntityEdge]] = []

    if EdgeSearchMethod.bm25 in config.search_methods:
        text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
        search_results.append(text_search)

    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await edge_similarity_search(
            driver, search_vector, None, None, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple edge searches enabled without a reranker')

    # Keys keep first-seen order, so this doubles as a de-duplicated ordering.
    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list reaches this point (see the check above).
        # Return its hits in search order instead of discarding them.
        reranked_uuids = list(edge_uuid_map)
    elif config.reranker == EdgeReranker.rrf:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # Several edges can share a source node. Group them so that ranking
        # the source nodes does not drop all but one edge per node.
        edge_uuids_by_source: dict[str, list[str]] = {}
        for result in search_results:
            for edge in result:
                bucket = edge_uuids_by_source.setdefault(edge.source_node_uuid, [])
                # The same edge may be returned by both search methods.
                if edge.uuid not in bucket:
                    bucket.append(edge.uuid)

        source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]

        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)

        # Expand each ranked source node back into its edges.
        for node_uuid in reranked_node_uuids:
            reranked_uuids.extend(edge_uuids_by_source[node_uuid])

    reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_edges
|
|
|
|
|
|
async def node_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: NodeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """Search for entity nodes with the configured methods and rerank the hits.

    Each enabled method (BM25 full-text, cosine similarity) is asked for
    ``2 * limit`` candidates. The candidate lists are then ordered by
    ``config.reranker``.

    Args:
        driver: Neo4j async driver used by the search helpers.
        embedder: Client whose awaitable ``create(input=..., model=...)``
            returns an object with ``.data[0].embedding``.
        query: Search text.
        group_ids: Group ids forwarded to the search helpers.
        config: Search methods to run and the reranker to apply.
        center_node_uuid: Required when the node-distance reranker is used.
        limit: Base result count; every method fetches twice this many.

    Returns:
        The reranked nodes. This function does not truncate to ``limit``.

    Raises:
        SearchRerankerError: If more than one search method ran without a
            reranker, or the node-distance reranker is selected without a
            ``center_node_uuid``.
    """
    search_results: list[list[EntityNode]] = []

    if NodeSearchMethod.bm25 in config.search_methods:
        text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if NodeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await node_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple node searches enabled without a reranker')

    search_result_uuids = [[node.uuid for node in result] for result in search_results]
    # Keys keep first-seen order, so this doubles as a de-duplicated ordering.
    node_uuid_map = {node.uuid: node for result in search_results for node in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list reaches this point (see the check above).
        # Return its hits in search order instead of discarding them.
        reranked_uuids = list(node_uuid_map)
    elif config.reranker == NodeReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)

    reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_nodes
|
|
|
|
|
|
async def community_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: CommunitySearchConfig,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """Search for community nodes with the configured methods and rerank the hits.

    Each enabled method (BM25 full-text, cosine similarity) is asked for
    ``2 * limit`` candidates. The candidate lists are then ordered by
    ``config.reranker``; RRF is the only reranker handled here.

    Args:
        driver: Neo4j async driver used by the search helpers.
        embedder: Client whose awaitable ``create(input=..., model=...)``
            returns an object with ``.data[0].embedding``.
        query: Search text.
        group_ids: Group ids forwarded to the search helpers.
        config: Search methods to run and the reranker to apply.
        limit: Base result count; every method fetches twice this many.

    Returns:
        The reranked communities. This function does not truncate to
        ``limit``.

    Raises:
        SearchRerankerError: If more than one search method ran without a
            reranker.
    """
    search_results: list[list[CommunityNode]] = []

    if CommunitySearchMethod.bm25 in config.search_methods:
        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await community_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        # This is the community path; the message previously said "node".
        raise SearchRerankerError('Multiple community searches enabled without a reranker')

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    # Keys keep first-seen order, so this doubles as a de-duplicated ordering.
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list reaches this point (see the check above).
        # Return its hits in search order instead of discarding them.
        reranked_uuids = list(community_uuid_map)
    elif config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_communities
|