Search refactor + Community search (#111)
* WIP * WIP * WIP * community search * WIP * WIP * integration tested * tests * tests * mypy * mypy * format
This commit is contained in:
parent
e4ee8d62fa
commit
d7c20c1f59
13 changed files with 780 additions and 329 deletions
|
|
@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
|
|||
messages = parse_podcast_messages()
|
||||
|
||||
if not use_bulk:
|
||||
for i, message in enumerate(messages[3:130]):
|
||||
for i, message in enumerate(messages[3:20]):
|
||||
await client.add_episode(
|
||||
name=f'Message {i}',
|
||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
|
|
|
|||
|
|
@ -1,3 +1,20 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
class GraphitiError(Exception):
|
||||
"""Base exception class for Graphiti Core."""
|
||||
|
||||
|
|
@ -16,3 +33,11 @@ class NodeNotFoundError(GraphitiError):
|
|||
def __init__(self, uuid: str):
|
||||
self.message = f'node {uuid} not found'
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class SearchRerankerError(GraphitiError):
    """Raised when search reranking fails or is misconfigured.

    Examples from the search module: multiple search methods enabled without
    a reranker, or the node-distance reranker invoked without a center node.
    (Fixes copy-pasted docstring that said "Raised when a node is not found.")
    """

    def __init__(self, text: str):
        # The caller supplies a human-readable description of the problem.
        self.message = text
        super().__init__(self.message)
|
||||
|
|
|
|||
|
|
@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
|
|||
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||
from graphiti_core.llm_client.utils import generate_embedding
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
|
||||
from graphiti_core.search.search import SearchConfig, search
|
||||
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
||||
from graphiti_core.search.search_config_recipes import (
|
||||
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
|
||||
EDGE_HYBRID_SEARCH_RRF,
|
||||
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
||||
NODE_HYBRID_SEARCH_RRF,
|
||||
)
|
||||
from graphiti_core.search.search_utils import (
|
||||
RELEVANT_SCHEMA_LIMIT,
|
||||
get_relevant_edges,
|
||||
get_relevant_nodes,
|
||||
hybrid_node_search,
|
||||
)
|
||||
from graphiti_core.utils import (
|
||||
build_episodic_edges,
|
||||
|
|
@ -548,7 +553,7 @@ class Graphiti:
|
|||
query: str,
|
||||
center_node_uuid: str | None = None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
num_results=10,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
):
|
||||
"""
|
||||
Perform a hybrid search on the knowledge graph.
|
||||
|
|
@ -564,7 +569,7 @@ class Graphiti:
|
|||
Facts will be reranked based on proximity to this node
|
||||
group_ids : list[str | None] | None, optional
|
||||
The graph partitions to return data from.
|
||||
num_results : int, optional
|
||||
limit : int, optional
|
||||
The maximum number of results to return. Defaults to 10.
|
||||
|
||||
Returns
|
||||
|
|
@ -581,21 +586,17 @@ class Graphiti:
|
|||
The search is performed using the current date and time as the reference
|
||||
point for temporal relevance.
|
||||
"""
|
||||
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
|
||||
search_config = SearchConfig(
|
||||
num_episodes=0,
|
||||
num_edges=num_results,
|
||||
num_nodes=0,
|
||||
group_ids=group_ids,
|
||||
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
||||
reranker=reranker,
|
||||
search_config = (
|
||||
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
||||
)
|
||||
search_config.limit = num_results
|
||||
|
||||
edges = (
|
||||
await hybrid_search(
|
||||
await search(
|
||||
self.driver,
|
||||
self.llm_client.get_embedder(),
|
||||
query,
|
||||
datetime.now(),
|
||||
group_ids,
|
||||
search_config,
|
||||
center_node_uuid,
|
||||
)
|
||||
|
|
@ -606,19 +607,20 @@ class Graphiti:
|
|||
async def _search(
|
||||
self,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
group_ids: list[str | None] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
):
|
||||
return await hybrid_search(
|
||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
||||
) -> SearchResults:
|
||||
return await search(
|
||||
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
|
||||
)
|
||||
|
||||
async def get_nodes_by_query(
|
||||
self,
|
||||
query: str,
|
||||
center_node_uuid: str | None = None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
limit: int = DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
"""
|
||||
Retrieve nodes from the graph database based on a text query.
|
||||
|
|
@ -629,7 +631,9 @@ class Graphiti:
|
|||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The text query to search for in the graph.
|
||||
The text query to search for in the graph
|
||||
center_node_uuid: str, optional
|
||||
Facts will be reranked based on proximity to this node.
|
||||
group_ids : list[str | None] | None, optional
|
||||
The graph partitions to return data from.
|
||||
limit : int | None, optional
|
||||
|
|
@ -655,8 +659,12 @@ class Graphiti:
|
|||
If not specified, a default limit (defined in the search functions) will be used.
|
||||
"""
|
||||
embedder = self.llm_client.get_embedder()
|
||||
query_embedding = await generate_embedding(embedder, query)
|
||||
relevant_nodes = await hybrid_node_search(
|
||||
[query], [query_embedding], self.driver, group_ids, limit
|
||||
search_config = (
|
||||
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
|
||||
)
|
||||
return relevant_nodes
|
||||
search_config.limit = limit
|
||||
|
||||
nodes = (
|
||||
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
||||
).nodes
|
||||
return nodes
|
||||
|
|
|
|||
|
|
@ -1,3 +1,19 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from neo4j import time as neo4j_time
|
||||
|
|
|
|||
|
|
@ -1,3 +1,20 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
class RateLimitError(Exception):
|
||||
"""Exception raised when the rate limit is exceeded."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,19 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import typing
|
||||
from time import time
|
||||
|
|
@ -17,6 +33,6 @@ async def generate_embedding(
|
|||
embedding = embedding[:EMBEDDING_DIM]
|
||||
|
||||
end = time()
|
||||
logger.debug(f'embedded text of length {len(text)} in {end-start} ms')
|
||||
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
|
||||
|
||||
return embedding
|
||||
|
|
|
|||
|
|
@ -1,3 +1,19 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
from .models import Message, PromptFunction, PromptVersion
|
||||
|
|
|
|||
|
|
@ -15,131 +15,227 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.errors import SearchRerankerError
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
from graphiti_core.search.search_config import (
|
||||
DEFAULT_SEARCH_LIMIT,
|
||||
CommunityReranker,
|
||||
CommunitySearchConfig,
|
||||
CommunitySearchMethod,
|
||||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
NodeReranker,
|
||||
NodeSearchConfig,
|
||||
NodeSearchMethod,
|
||||
SearchConfig,
|
||||
SearchResults,
|
||||
)
|
||||
from graphiti_core.search.search_utils import (
|
||||
community_fulltext_search,
|
||||
community_similarity_search,
|
||||
edge_fulltext_search,
|
||||
edge_similarity_search,
|
||||
get_mentioned_nodes,
|
||||
node_distance_reranker,
|
||||
node_fulltext_search,
|
||||
node_similarity_search,
|
||||
rrf,
|
||||
)
|
||||
from graphiti_core.utils import retrieve_episodes
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SearchMethod(Enum):
|
||||
cosine_similarity = 'cosine_similarity'
|
||||
bm25 = 'bm25'
|
||||
|
||||
|
||||
class Reranker(Enum):
|
||||
rrf = 'reciprocal_rank_fusion'
|
||||
node_distance = 'node_distance'
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
num_edges: int = Field(default=10)
|
||||
num_nodes: int = Field(default=10)
|
||||
num_episodes: int = EPISODE_WINDOW_LEN
|
||||
group_ids: list[str | None] | None
|
||||
search_methods: list[SearchMethod]
|
||||
reranker: Reranker | None
|
||||
|
||||
|
||||
class SearchResults(BaseModel):
|
||||
episodes: list[EpisodicNode]
|
||||
nodes: list[EntityNode]
|
||||
edges: list[EntityEdge]
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
async def search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
group_ids: list[str | None] | None,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
query = query.replace('\n', ' ')
|
||||
|
||||
episodes = []
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
search_results = []
|
||||
|
||||
if config.num_episodes > 0:
|
||||
episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
|
||||
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
||||
|
||||
if SearchMethod.bm25 in config.search_methods:
|
||||
text_search = await edge_fulltext_search(
|
||||
driver, query, None, None, config.group_ids, 2 * config.num_edges
|
||||
edges = (
|
||||
await edge_search(
|
||||
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
|
||||
)
|
||||
search_results.append(text_search)
|
||||
|
||||
if SearchMethod.cosine_similarity in config.search_methods:
|
||||
query_text = query.replace('\n', ' ')
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
|
||||
.data[0]
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
if config.edge_config is not None
|
||||
else []
|
||||
)
|
||||
nodes = (
|
||||
await node_search(
|
||||
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
|
||||
)
|
||||
|
||||
similarity_search = await edge_similarity_search(
|
||||
driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
|
||||
if config.node_config is not None
|
||||
else []
|
||||
)
|
||||
communities = (
|
||||
await community_search(
|
||||
driver, embedder, query, group_ids, config.community_config, config.limit
|
||||
)
|
||||
search_results.append(similarity_search)
|
||||
if config.community_config is not None
|
||||
else []
|
||||
)
|
||||
|
||||
if len(search_results) > 1 and config.reranker is None:
|
||||
logger.exception('Multiple searches enabled without a reranker')
|
||||
raise Exception('Multiple searches enabled without a reranker')
|
||||
|
||||
else:
|
||||
edge_uuid_map = {}
|
||||
search_result_uuids = []
|
||||
|
||||
for result in search_results:
|
||||
result_uuids = []
|
||||
for edge in result:
|
||||
result_uuids.append(edge.uuid)
|
||||
edge_uuid_map[edge.uuid] = edge
|
||||
|
||||
search_result_uuids.append(result_uuids)
|
||||
|
||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||
|
||||
reranked_uuids: list[str] = []
|
||||
if config.reranker == Reranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids)
|
||||
elif config.reranker == Reranker.node_distance:
|
||||
if center_node_uuid is None:
|
||||
logger.exception('No center node provided for Node Distance reranker')
|
||||
raise Exception('No center node provided for Node Distance reranker')
|
||||
reranked_uuids = await node_distance_reranker(
|
||||
driver, search_result_uuids, center_node_uuid
|
||||
)
|
||||
|
||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
edges.extend(reranked_edges)
|
||||
|
||||
context = SearchResults(
|
||||
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
|
||||
results = SearchResults(
|
||||
edges=edges[: config.limit],
|
||||
nodes=nodes[: config.limit],
|
||||
communities=communities[: config.limit],
|
||||
)
|
||||
|
||||
end = time()
|
||||
|
||||
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
|
||||
|
||||
return context
|
||||
return results
|
||||
|
||||
|
||||
async def edge_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: EdgeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """Run the configured edge retrieval methods and rerank the merged results.

    Each enabled method fetches up to 2 * limit candidate edges; the configured
    reranker then fuses the candidate lists into a single ordering. Raises
    SearchRerankerError when multiple methods run without a reranker, or when
    node-distance reranking is requested without a center node.
    """
    candidate_lists: list[list[EntityEdge]] = []

    # BM25 full-text retrieval over edge facts.
    if EdgeSearchMethod.bm25 in config.search_methods:
        candidate_lists.append(
            await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
        )

    # Embedding-similarity retrieval over edge facts.
    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        embedding_response = await embedder.create(input=[query], model='text-embedding-3-small')
        query_vector = embedding_response.data[0].embedding[:EMBEDDING_DIM]

        candidate_lists.append(
            await edge_similarity_search(driver, query_vector, None, None, group_ids, 2 * limit)
        )

    if len(candidate_lists) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple edge searches enabled without a reranker')

    uuid_to_edge = {edge.uuid: edge for candidates in candidate_lists for edge in candidates}

    ordered_uuids: list[str] = []
    if config.reranker == EdgeReranker.rrf:
        ordered_uuids = rrf([[edge.uuid for edge in candidates] for candidates in candidate_lists])
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # NOTE(review): when several edges share a source node, only the last
        # edge seen survives this mapping — confirm that is intended.
        source_to_edge = {
            edge.source_node_uuid: edge.uuid
            for candidates in candidate_lists
            for edge in candidates
        }
        source_uuid_lists = [
            [edge.source_node_uuid for edge in candidates] for candidates in candidate_lists
        ]

        ordered_source_uuids = await node_distance_reranker(
            driver, source_uuid_lists, center_node_uuid
        )

        ordered_uuids = [source_to_edge[source_uuid] for source_uuid in ordered_source_uuids]

    return [uuid_to_edge[uuid] for uuid in ordered_uuids]
|
||||
|
||||
|
||||
async def node_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: NodeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """Run the configured entity-node retrieval methods and rerank the results.

    Each enabled method fetches up to 2 * limit candidate nodes; the configured
    reranker then fuses the candidate lists into a single ordering. Raises
    SearchRerankerError when multiple methods run without a reranker, or when
    node-distance reranking is requested without a center node.
    """
    candidate_lists: list[list[EntityNode]] = []

    # BM25 full-text retrieval over node names/summaries.
    if NodeSearchMethod.bm25 in config.search_methods:
        candidate_lists.append(await node_fulltext_search(driver, query, group_ids, 2 * limit))

    # Embedding-similarity retrieval over node name embeddings.
    if NodeSearchMethod.cosine_similarity in config.search_methods:
        embedding_response = await embedder.create(input=[query], model='text-embedding-3-small')
        query_vector = embedding_response.data[0].embedding[:EMBEDDING_DIM]

        candidate_lists.append(
            await node_similarity_search(driver, query_vector, group_ids, 2 * limit)
        )

    if len(candidate_lists) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple node searches enabled without a reranker')

    uuid_lists = [[node.uuid for node in candidates] for candidates in candidate_lists]
    uuid_to_node = {node.uuid: node for candidates in candidate_lists for node in candidates}

    ordered_uuids: list[str] = []
    if config.reranker == NodeReranker.rrf:
        ordered_uuids = rrf(uuid_lists)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        ordered_uuids = await node_distance_reranker(driver, uuid_lists, center_node_uuid)

    return [uuid_to_node[uuid] for uuid in ordered_uuids]
|
||||
|
||||
|
||||
async def community_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: CommunitySearchConfig,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """Run the configured community retrieval methods and rerank the results.

    Each enabled method fetches up to 2 * limit candidate communities; the
    configured reranker (RRF only) fuses the candidate lists into a single
    ordering. Raises SearchRerankerError when multiple methods are enabled
    without a reranker.
    """
    search_results: list[list[CommunityNode]] = []

    # BM25 full-text retrieval over community names/summaries.
    if CommunitySearchMethod.bm25 in config.search_methods:
        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    # Embedding-similarity retrieval over community embeddings.
    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await community_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        # Fixed: the message previously said "node searches" (copy-paste from
        # node_search), which misreported which search was misconfigured.
        raise SearchRerankerError('Multiple community searches enabled without a reranker')

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_communities
|
||||
|
|
|
|||
81
graphiti_core/search/search_config.py
Normal file
81
graphiti_core/search/search_config.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 10
|
||||
|
||||
|
||||
class EdgeSearchMethod(Enum):
    """Retrieval strategies available for edge (fact) search."""

    cosine_similarity = 'cosine_similarity'
    bm25 = 'bm25'


class NodeSearchMethod(Enum):
    """Retrieval strategies available for entity-node search."""

    cosine_similarity = 'cosine_similarity'
    bm25 = 'bm25'


class CommunitySearchMethod(Enum):
    """Retrieval strategies available for community-node search."""

    cosine_similarity = 'cosine_similarity'
    bm25 = 'bm25'


class EdgeReranker(Enum):
    """Rerankers that fuse/reorder edge search results."""

    rrf = 'reciprocal_rank_fusion'
    node_distance = 'node_distance'


class NodeReranker(Enum):
    """Rerankers that fuse/reorder node search results."""

    rrf = 'reciprocal_rank_fusion'
    node_distance = 'node_distance'


class CommunityReranker(Enum):
    """Rerankers for community search results (RRF only)."""

    rrf = 'reciprocal_rank_fusion'
|
||||
class EdgeSearchConfig(BaseModel):
    """Which retrieval methods to run for edges and how to rerank them."""

    search_methods: list[EdgeSearchMethod]
    reranker: EdgeReranker | None


class NodeSearchConfig(BaseModel):
    """Which retrieval methods to run for entity nodes and how to rerank them."""

    search_methods: list[NodeSearchMethod]
    reranker: NodeReranker | None


class CommunitySearchConfig(BaseModel):
    """Which retrieval methods to run for communities and how to rerank them."""

    search_methods: list[CommunitySearchMethod]
    reranker: CommunityReranker | None


class SearchConfig(BaseModel):
    """Top-level search configuration.

    A sub-config left as None disables that result type; `limit` caps the
    number of results returned per result type.
    """

    # Plain default assignments are equivalent to Field(default=...) in pydantic.
    edge_config: EdgeSearchConfig | None = None
    node_config: NodeSearchConfig | None = None
    community_config: CommunitySearchConfig | None = None
    limit: int = DEFAULT_SEARCH_LIMIT


class SearchResults(BaseModel):
    """Container for the edges, nodes, and communities returned by `search`."""

    edges: list[EntityEdge]
    nodes: list[EntityNode]
    communities: list[CommunityNode]
|
||||
84
graphiti_core/search/search_config_recipes.py
Normal file
84
graphiti_core/search/search_config_recipes.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from graphiti_core.search.search_config import (
|
||||
CommunityReranker,
|
||||
CommunitySearchConfig,
|
||||
CommunitySearchMethod,
|
||||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
NodeReranker,
|
||||
NodeSearchConfig,
|
||||
NodeSearchMethod,
|
||||
SearchConfig,
|
||||
)
|
||||
|
||||
# Performs a hybrid search with RRF reranking over edges, nodes, and communities.
COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.rrf,
    ),
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.rrf,
    ),
    community_config=CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=CommunityReranker.rrf,
    ),
)

# Performs a hybrid search over edges with RRF reranking.
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.rrf,
    )
)

# Performs a hybrid search over edges with node-distance reranking
# (requires a center node at search time).
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.node_distance,
    )
)

# Performs a hybrid search over nodes with RRF reranking.
NODE_HYBRID_SEARCH_RRF = SearchConfig(
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.rrf,
    )
)

# Performs a hybrid search over nodes with node-distance reranking
# (requires a center node at search time).
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.node_distance,
    )
)

# Performs a hybrid search over communities with RRF reranking.
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
    community_config=CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=CommunityReranker.rrf,
    )
)
|
||||
|
|
@ -1,3 +1,19 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -7,7 +23,13 @@ from time import time
|
|||
from neo4j import AsyncDriver, Query
|
||||
|
||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
EpisodicNode,
|
||||
get_community_node_from_record,
|
||||
get_entity_node_from_record,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,181 +57,6 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
|||
return nodes
|
||||
|
||||
|
||||
async def edge_similarity_search(
|
||||
driver: AsyncDriver,
|
||||
search_vector: list[float],
|
||||
source_node_uuid: str | None,
|
||||
target_node_uuid: str | None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
# vector similarity search over embedded facts
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
|
||||
if source_node_uuid is None and target_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
elif source_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
elif target_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
search_vector=search_vector,
|
||||
source_uuid=source_node_uuid,
|
||||
target_uuid=target_node_uuid,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
async def entity_similarity_search(
|
||||
search_vector: list[float],
|
||||
driver: AsyncDriver,
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
||||
YIELD node AS n, score
|
||||
MATCH (n WHERE n.group_id IN $group_ids)
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
ORDER BY score DESC
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
async def entity_fulltext_search(
|
||||
query: str,
|
||||
driver: AsyncDriver,
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
||||
YIELD node AS n, score
|
||||
MATCH (n WHERE n.group_id in $group_ids)
|
||||
RETURN
|
||||
n.uuid AS uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
async def edge_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
|
|
@ -322,6 +169,247 @@ async def edge_fulltext_search(
|
|||
return edges
|
||||
|
||||
|
||||
async def edge_similarity_search(
|
||||
driver: AsyncDriver,
|
||||
search_vector: list[float],
|
||||
source_node_uuid: str | None,
|
||||
target_node_uuid: str | None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
# vector similarity search over embedded facts
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
|
||||
if source_node_uuid is None and target_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
elif source_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
elif target_node_uuid is None:
|
||||
query = Query("""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC
|
||||
""")
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
search_vector=search_vector,
|
||||
source_uuid=source_node_uuid,
|
||||
target_uuid=target_node_uuid,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
async def node_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
||||
YIELD node AS n, score
|
||||
MATCH (n WHERE n.group_id in $group_ids)
|
||||
RETURN
|
||||
n.uuid AS uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
async def node_similarity_search(
|
||||
driver: AsyncDriver,
|
||||
search_vector: list[float],
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
||||
YIELD node AS n, score
|
||||
MATCH (n WHERE n.group_id IN $group_ids)
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
ORDER BY score DESC
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
async def community_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# BM25 search to get top communities
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||
YIELD node AS comm, score
|
||||
MATCH (comm WHERE comm.group_id in $group_ids)
|
||||
RETURN
|
||||
comm.uuid AS uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
comm.name_embedding AS name_embedding,
|
||||
comm.created_at AS created_at,
|
||||
comm.summary AS summary
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
return communities
|
||||
|
||||
|
||||
async def community_similarity_search(
|
||||
driver: AsyncDriver,
|
||||
search_vector: list[float],
|
||||
group_ids: list[str | None] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
|
||||
YIELD node AS comm, score
|
||||
MATCH (comm WHERE comm.group_id IN $group_ids)
|
||||
RETURN
|
||||
comm.uuid As uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
comm.name_embedding AS name_embedding,
|
||||
comm.created_at AS created_at,
|
||||
comm.summary AS summary
|
||||
ORDER BY score DESC
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
)
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
return communities
|
||||
|
||||
|
||||
async def hybrid_node_search(
|
||||
queries: list[str],
|
||||
embeddings: list[list[float]],
|
||||
|
|
@ -371,8 +459,8 @@ async def hybrid_node_search(
|
|||
|
||||
results: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
|
||||
*[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
|
||||
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
||||
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -490,24 +578,23 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
|||
|
||||
|
||||
async def node_distance_reranker(
|
||||
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
|
||||
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
|
||||
) -> list[str]:
|
||||
# use rrf as a preliminary ranker
|
||||
sorted_uuids = rrf(results)
|
||||
sorted_uuids = rrf(node_uuids)
|
||||
scores: dict[str, float] = {}
|
||||
|
||||
# Find the shortest path to center node
|
||||
query = Query("""
|
||||
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
|
||||
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
|
||||
RETURN length(p) AS score
|
||||
""")
|
||||
|
||||
path_results = await asyncio.gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
edge_uuid=uuid,
|
||||
node_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
)
|
||||
for uuid in sorted_uuids
|
||||
|
|
@ -518,15 +605,8 @@ async def node_distance_reranker(
|
|||
records = result[0]
|
||||
record = records[0] if len(records) > 0 else None
|
||||
distance: float = record['score'] if record is not None else float('inf')
|
||||
if record is not None and (
|
||||
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
|
||||
):
|
||||
distance = 0
|
||||
|
||||
if uuid in scores:
|
||||
scores[uuid] = min(distance, scores[uuid])
|
||||
else:
|
||||
scores[uuid] = distance
|
||||
distance = 0 if uuid == center_node_uuid else distance
|
||||
scores[uuid] = distance
|
||||
|
||||
# rerank on shortest distance
|
||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from dotenv import load_dotenv
|
|||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
|
@ -81,6 +82,17 @@ async def test_graphiti_init():
|
|||
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
|
||||
|
||||
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
||||
|
||||
results = await graphiti._search(
|
||||
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1']
|
||||
)
|
||||
pretty_results = {
|
||||
'edges': [edge.fact for edge in results.edges],
|
||||
'nodes': [node.name for node in results.nodes],
|
||||
'communities': [community.name for community in results.communities],
|
||||
}
|
||||
|
||||
logger.info(pretty_results)
|
||||
graphiti.close()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,11 +11,11 @@ async def test_hybrid_node_search_deduplication():
|
|||
# Mock the database driver
|
||||
mock_driver = AsyncMock()
|
||||
|
||||
# Mock the entity_fulltext_search and entity_similarity_search functions
|
||||
# Mock the node_fulltext_search and entity_similarity_search functions
|
||||
with patch(
|
||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
||||
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||
) as mock_fulltext_search, patch(
|
||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
||||
'graphiti_core.search.search_utils.node_similarity_search'
|
||||
) as mock_similarity_search:
|
||||
# Set up mock return values
|
||||
mock_fulltext_search.side_effect = [
|
||||
|
|
@ -47,9 +47,9 @@ async def test_hybrid_node_search_empty_results():
|
|||
mock_driver = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
||||
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||
) as mock_fulltext_search, patch(
|
||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
||||
'graphiti_core.search.search_utils.node_similarity_search'
|
||||
) as mock_similarity_search:
|
||||
mock_fulltext_search.return_value = []
|
||||
mock_similarity_search.return_value = []
|
||||
|
|
@ -66,9 +66,9 @@ async def test_hybrid_node_search_only_fulltext():
|
|||
mock_driver = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
||||
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||
) as mock_fulltext_search, patch(
|
||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
||||
'graphiti_core.search.search_utils.node_similarity_search'
|
||||
) as mock_similarity_search:
|
||||
mock_fulltext_search.return_value = [
|
||||
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
|
||||
|
|
@ -90,9 +90,9 @@ async def test_hybrid_node_search_with_limit():
|
|||
mock_driver = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
||||
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||
) as mock_fulltext_search, patch(
|
||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
||||
'graphiti_core.search.search_utils.node_similarity_search'
|
||||
) as mock_similarity_search:
|
||||
mock_fulltext_search.return_value = [
|
||||
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
|
||||
|
|
@ -120,8 +120,8 @@ async def test_hybrid_node_search_with_limit():
|
|||
assert mock_fulltext_search.call_count == 1
|
||||
assert mock_similarity_search.call_count == 1
|
||||
# Verify that the limit was passed to the search functions
|
||||
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2)
|
||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2)
|
||||
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2)
|
||||
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -129,9 +129,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
|
|||
mock_driver = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
||||
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||
) as mock_fulltext_search, patch(
|
||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
||||
'graphiti_core.search.search_utils.node_similarity_search'
|
||||
) as mock_similarity_search:
|
||||
mock_fulltext_search.return_value = [
|
||||
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
|
||||
|
|
@ -155,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
|
|||
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
|
||||
assert mock_fulltext_search.call_count == 1
|
||||
assert mock_similarity_search.call_count == 1
|
||||
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4)
|
||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4)
|
||||
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4)
|
||||
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue