Search refactor + Community search (#111)
* WIP * WIP * WIP * community search * WIP * WIP * integration tested * tests * tests * mypy * mypy * format
This commit is contained in:
parent
e4ee8d62fa
commit
d7c20c1f59
13 changed files with 780 additions and 329 deletions
|
|
@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
|
|
||||||
if not use_bulk:
|
if not use_bulk:
|
||||||
for i, message in enumerate(messages[3:130]):
|
for i, message in enumerate(messages[3:20]):
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,20 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphitiError(Exception):
|
class GraphitiError(Exception):
|
||||||
"""Base exception class for Graphiti Core."""
|
"""Base exception class for Graphiti Core."""
|
||||||
|
|
||||||
|
|
@ -16,3 +33,11 @@ class NodeNotFoundError(GraphitiError):
|
||||||
def __init__(self, uuid: str):
|
def __init__(self, uuid: str):
|
||||||
self.message = f'node {uuid} not found'
|
self.message = f'node {uuid} not found'
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchRerankerError(GraphitiError):
|
||||||
|
"""Raised when search results cannot be reranked with the given search configuration."""
|
||||||
|
|
||||||
|
def __init__(self, text: str):
|
||||||
|
self.message = text
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
|
||||||
|
|
@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.llm_client.utils import generate_embedding
|
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
|
from graphiti_core.search.search import SearchConfig, search
|
||||||
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
||||||
|
from graphiti_core.search.search_config_recipes import (
|
||||||
|
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
|
||||||
|
EDGE_HYBRID_SEARCH_RRF,
|
||||||
|
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
||||||
|
NODE_HYBRID_SEARCH_RRF,
|
||||||
|
)
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
get_relevant_edges,
|
get_relevant_edges,
|
||||||
get_relevant_nodes,
|
get_relevant_nodes,
|
||||||
hybrid_node_search,
|
|
||||||
)
|
)
|
||||||
from graphiti_core.utils import (
|
from graphiti_core.utils import (
|
||||||
build_episodic_edges,
|
build_episodic_edges,
|
||||||
|
|
@ -548,7 +553,7 @@ class Graphiti:
|
||||||
query: str,
|
query: str,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str | None] | None = None,
|
||||||
num_results=10,
|
num_results=DEFAULT_SEARCH_LIMIT,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search on the knowledge graph.
|
Perform a hybrid search on the knowledge graph.
|
||||||
|
|
@ -564,7 +569,7 @@ class Graphiti:
|
||||||
Facts will be reranked based on proximity to this node
|
Facts will be reranked based on proximity to this node
|
||||||
group_ids : list[str | None] | None, optional
|
group_ids : list[str | None] | None, optional
|
||||||
The graph partitions to return data from.
|
The graph partitions to return data from.
|
||||||
num_results : int, optional
|
num_results : int, optional
|
||||||
The maximum number of results to return. Defaults to 10.
|
The maximum number of results to return. Defaults to 10.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
|
@ -581,21 +586,17 @@ class Graphiti:
|
||||||
The search is performed using the current date and time as the reference
|
The search is performed using the current date and time as the reference
|
||||||
point for temporal relevance.
|
point for temporal relevance.
|
||||||
"""
|
"""
|
||||||
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
|
search_config = (
|
||||||
search_config = SearchConfig(
|
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
||||||
num_episodes=0,
|
|
||||||
num_edges=num_results,
|
|
||||||
num_nodes=0,
|
|
||||||
group_ids=group_ids,
|
|
||||||
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
|
||||||
reranker=reranker,
|
|
||||||
)
|
)
|
||||||
|
search_config.limit = num_results
|
||||||
|
|
||||||
edges = (
|
edges = (
|
||||||
await hybrid_search(
|
await search(
|
||||||
self.driver,
|
self.driver,
|
||||||
self.llm_client.get_embedder(),
|
self.llm_client.get_embedder(),
|
||||||
query,
|
query,
|
||||||
datetime.now(),
|
group_ids,
|
||||||
search_config,
|
search_config,
|
||||||
center_node_uuid,
|
center_node_uuid,
|
||||||
)
|
)
|
||||||
|
|
@ -606,19 +607,20 @@ class Graphiti:
|
||||||
async def _search(
|
async def _search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
timestamp: datetime,
|
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
|
group_ids: list[str | None] | None = None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
):
|
) -> SearchResults:
|
||||||
return await hybrid_search(
|
return await search(
|
||||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_nodes_by_query(
|
async def get_nodes_by_query(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
center_node_uuid: str | None = None,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str | None] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve nodes from the graph database based on a text query.
|
Retrieve nodes from the graph database based on a text query.
|
||||||
|
|
@ -629,7 +631,9 @@ class Graphiti:
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
query : str
|
query : str
|
||||||
The text query to search for in the graph.
|
The text query to search for in the graph.
|
||||||
|
center_node_uuid: str, optional
|
||||||
|
Nodes will be reranked based on proximity to this node.
|
||||||
group_ids : list[str | None] | None, optional
|
group_ids : list[str | None] | None, optional
|
||||||
The graph partitions to return data from.
|
The graph partitions to return data from.
|
||||||
limit : int | None, optional
|
limit : int | None, optional
|
||||||
|
|
@ -655,8 +659,12 @@ class Graphiti:
|
||||||
If not specified, a default limit (defined in the search functions) will be used.
|
If not specified, a default limit (defined in the search functions) will be used.
|
||||||
"""
|
"""
|
||||||
embedder = self.llm_client.get_embedder()
|
embedder = self.llm_client.get_embedder()
|
||||||
query_embedding = await generate_embedding(embedder, query)
|
search_config = (
|
||||||
relevant_nodes = await hybrid_node_search(
|
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
|
||||||
[query], [query_embedding], self.driver, group_ids, limit
|
|
||||||
)
|
)
|
||||||
return relevant_nodes
|
search_config.limit = limit
|
||||||
|
|
||||||
|
nodes = (
|
||||||
|
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
||||||
|
).nodes
|
||||||
|
return nodes
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from neo4j import time as neo4j_time
|
from neo4j import time as neo4j_time
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,20 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class RateLimitError(Exception):
|
class RateLimitError(Exception):
|
||||||
"""Exception raised when the rate limit is exceeded."""
|
"""Exception raised when the rate limit is exceeded."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
from time import time
|
from time import time
|
||||||
|
|
@ -17,6 +33,6 @@ async def generate_embedding(
|
||||||
embedding = embedding[:EMBEDDING_DIM]
|
embedding = embedding[:EMBEDDING_DIM]
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.debug(f'embedded text of length {len(text)} in {end-start} ms')
|
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Any, Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
|
||||||
|
|
@ -15,131 +15,227 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
|
from graphiti_core.errors import SearchRerankerError
|
||||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||||
|
from graphiti_core.search.search_config import (
|
||||||
|
DEFAULT_SEARCH_LIMIT,
|
||||||
|
CommunityReranker,
|
||||||
|
CommunitySearchConfig,
|
||||||
|
CommunitySearchMethod,
|
||||||
|
EdgeReranker,
|
||||||
|
EdgeSearchConfig,
|
||||||
|
EdgeSearchMethod,
|
||||||
|
NodeReranker,
|
||||||
|
NodeSearchConfig,
|
||||||
|
NodeSearchMethod,
|
||||||
|
SearchConfig,
|
||||||
|
SearchResults,
|
||||||
|
)
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
|
community_fulltext_search,
|
||||||
|
community_similarity_search,
|
||||||
edge_fulltext_search,
|
edge_fulltext_search,
|
||||||
edge_similarity_search,
|
edge_similarity_search,
|
||||||
get_mentioned_nodes,
|
|
||||||
node_distance_reranker,
|
node_distance_reranker,
|
||||||
|
node_fulltext_search,
|
||||||
|
node_similarity_search,
|
||||||
rrf,
|
rrf,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils import retrieve_episodes
|
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SearchMethod(Enum):
|
async def search(
|
||||||
cosine_similarity = 'cosine_similarity'
|
|
||||||
bm25 = 'bm25'
|
|
||||||
|
|
||||||
|
|
||||||
class Reranker(Enum):
|
|
||||||
rrf = 'reciprocal_rank_fusion'
|
|
||||||
node_distance = 'node_distance'
|
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(BaseModel):
|
|
||||||
num_edges: int = Field(default=10)
|
|
||||||
num_nodes: int = Field(default=10)
|
|
||||||
num_episodes: int = EPISODE_WINDOW_LEN
|
|
||||||
group_ids: list[str | None] | None
|
|
||||||
search_methods: list[SearchMethod]
|
|
||||||
reranker: Reranker | None
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResults(BaseModel):
|
|
||||||
episodes: list[EpisodicNode]
|
|
||||||
nodes: list[EntityNode]
|
|
||||||
edges: list[EntityEdge]
|
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_search(
|
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder,
|
embedder,
|
||||||
query: str,
|
query: str,
|
||||||
timestamp: datetime,
|
group_ids: list[str | None] | None,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
|
query = query.replace('\n', ' ')
|
||||||
|
|
||||||
episodes = []
|
edges = (
|
||||||
nodes = []
|
await edge_search(
|
||||||
edges = []
|
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
|
||||||
|
|
||||||
search_results = []
|
|
||||||
|
|
||||||
if config.num_episodes > 0:
|
|
||||||
episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
|
|
||||||
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
|
||||||
|
|
||||||
if SearchMethod.bm25 in config.search_methods:
|
|
||||||
text_search = await edge_fulltext_search(
|
|
||||||
driver, query, None, None, config.group_ids, 2 * config.num_edges
|
|
||||||
)
|
)
|
||||||
search_results.append(text_search)
|
if config.edge_config is not None
|
||||||
|
else []
|
||||||
if SearchMethod.cosine_similarity in config.search_methods:
|
)
|
||||||
query_text = query.replace('\n', ' ')
|
nodes = (
|
||||||
search_vector = (
|
await node_search(
|
||||||
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
|
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
|
||||||
.data[0]
|
|
||||||
.embedding[:EMBEDDING_DIM]
|
|
||||||
)
|
)
|
||||||
|
if config.node_config is not None
|
||||||
similarity_search = await edge_similarity_search(
|
else []
|
||||||
driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
|
)
|
||||||
|
communities = (
|
||||||
|
await community_search(
|
||||||
|
driver, embedder, query, group_ids, config.community_config, config.limit
|
||||||
)
|
)
|
||||||
search_results.append(similarity_search)
|
if config.community_config is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
if len(search_results) > 1 and config.reranker is None:
|
results = SearchResults(
|
||||||
logger.exception('Multiple searches enabled without a reranker')
|
edges=edges[: config.limit],
|
||||||
raise Exception('Multiple searches enabled without a reranker')
|
nodes=nodes[: config.limit],
|
||||||
|
communities=communities[: config.limit],
|
||||||
else:
|
|
||||||
edge_uuid_map = {}
|
|
||||||
search_result_uuids = []
|
|
||||||
|
|
||||||
for result in search_results:
|
|
||||||
result_uuids = []
|
|
||||||
for edge in result:
|
|
||||||
result_uuids.append(edge.uuid)
|
|
||||||
edge_uuid_map[edge.uuid] = edge
|
|
||||||
|
|
||||||
search_result_uuids.append(result_uuids)
|
|
||||||
|
|
||||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
||||||
|
|
||||||
reranked_uuids: list[str] = []
|
|
||||||
if config.reranker == Reranker.rrf:
|
|
||||||
reranked_uuids = rrf(search_result_uuids)
|
|
||||||
elif config.reranker == Reranker.node_distance:
|
|
||||||
if center_node_uuid is None:
|
|
||||||
logger.exception('No center node provided for Node Distance reranker')
|
|
||||||
raise Exception('No center node provided for Node Distance reranker')
|
|
||||||
reranked_uuids = await node_distance_reranker(
|
|
||||||
driver, search_result_uuids, center_node_uuid
|
|
||||||
)
|
|
||||||
|
|
||||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
|
||||||
edges.extend(reranked_edges)
|
|
||||||
|
|
||||||
context = SearchResults(
|
|
||||||
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
|
|
||||||
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
|
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
return context
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def edge_search(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
embedder,
|
||||||
|
query: str,
|
||||||
|
group_ids: list[str | None] | None,
|
||||||
|
config: EdgeSearchConfig,
|
||||||
|
center_node_uuid: str | None = None,
|
||||||
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
) -> list[EntityEdge]:
|
||||||
|
search_results: list[list[EntityEdge]] = []
|
||||||
|
|
||||||
|
if EdgeSearchMethod.bm25 in config.search_methods:
|
||||||
|
text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
|
||||||
|
search_results.append(text_search)
|
||||||
|
|
||||||
|
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
|
search_vector = (
|
||||||
|
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
||||||
|
.data[0]
|
||||||
|
.embedding[:EMBEDDING_DIM]
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_search = await edge_similarity_search(
|
||||||
|
driver, search_vector, None, None, group_ids, 2 * limit
|
||||||
|
)
|
||||||
|
search_results.append(similarity_search)
|
||||||
|
|
||||||
|
if len(search_results) > 1 and config.reranker is None:
|
||||||
|
raise SearchRerankerError('Multiple edge searches enabled without a reranker')
|
||||||
|
|
||||||
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
||||||
|
|
||||||
|
reranked_uuids: list[str] = []
|
||||||
|
if config.reranker == EdgeReranker.rrf:
|
||||||
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||||
|
|
||||||
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
elif config.reranker == EdgeReranker.node_distance:
|
||||||
|
if center_node_uuid is None:
|
||||||
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||||
|
|
||||||
|
source_to_edge_uuid_map = {
|
||||||
|
edge.source_node_uuid: edge.uuid for result in search_results for edge in result
|
||||||
|
}
|
||||||
|
source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
|
||||||
|
|
||||||
|
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
|
||||||
|
|
||||||
|
reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids]
|
||||||
|
|
||||||
|
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
return reranked_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def node_search(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
embedder,
|
||||||
|
query: str,
|
||||||
|
group_ids: list[str | None] | None,
|
||||||
|
config: NodeSearchConfig,
|
||||||
|
center_node_uuid: str | None = None,
|
||||||
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
) -> list[EntityNode]:
|
||||||
|
search_results: list[list[EntityNode]] = []
|
||||||
|
|
||||||
|
if NodeSearchMethod.bm25 in config.search_methods:
|
||||||
|
text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
|
||||||
|
search_results.append(text_search)
|
||||||
|
|
||||||
|
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
|
search_vector = (
|
||||||
|
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
||||||
|
.data[0]
|
||||||
|
.embedding[:EMBEDDING_DIM]
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_search = await node_similarity_search(
|
||||||
|
driver, search_vector, group_ids, 2 * limit
|
||||||
|
)
|
||||||
|
search_results.append(similarity_search)
|
||||||
|
|
||||||
|
if len(search_results) > 1 and config.reranker is None:
|
||||||
|
raise SearchRerankerError('Multiple node searches enabled without a reranker')
|
||||||
|
|
||||||
|
search_result_uuids = [[node.uuid for node in result] for result in search_results]
|
||||||
|
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
||||||
|
|
||||||
|
reranked_uuids: list[str] = []
|
||||||
|
if config.reranker == NodeReranker.rrf:
|
||||||
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
elif config.reranker == NodeReranker.node_distance:
|
||||||
|
if center_node_uuid is None:
|
||||||
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||||
|
reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)
|
||||||
|
|
||||||
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
return reranked_nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def community_search(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
embedder,
|
||||||
|
query: str,
|
||||||
|
group_ids: list[str | None] | None,
|
||||||
|
config: CommunitySearchConfig,
|
||||||
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
) -> list[CommunityNode]:
|
||||||
|
search_results: list[list[CommunityNode]] = []
|
||||||
|
|
||||||
|
if CommunitySearchMethod.bm25 in config.search_methods:
|
||||||
|
text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
|
||||||
|
search_results.append(text_search)
|
||||||
|
|
||||||
|
if CommunitySearchMethod.cosine_similarity in config.search_methods:
|
||||||
|
search_vector = (
|
||||||
|
(await embedder.create(input=[query], model='text-embedding-3-small'))
|
||||||
|
.data[0]
|
||||||
|
.embedding[:EMBEDDING_DIM]
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_search = await community_similarity_search(
|
||||||
|
driver, search_vector, group_ids, 2 * limit
|
||||||
|
)
|
||||||
|
search_results.append(similarity_search)
|
||||||
|
|
||||||
|
if len(search_results) > 1 and config.reranker is None:
|
||||||
|
raise SearchRerankerError('Multiple node searches enabled without a reranker')
|
||||||
|
|
||||||
|
search_result_uuids = [[community.uuid for community in result] for result in search_results]
|
||||||
|
community_uuid_map = {
|
||||||
|
community.uuid: community for result in search_results for community in result
|
||||||
|
}
|
||||||
|
|
||||||
|
reranked_uuids: list[str] = []
|
||||||
|
if config.reranker == CommunityReranker.rrf:
|
||||||
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
|
||||||
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
return reranked_communities
|
||||||
|
|
|
||||||
81
graphiti_core/search/search_config.py
Normal file
81
graphiti_core/search/search_config.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from graphiti_core.edges import EntityEdge
|
||||||
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||||
|
|
||||||
|
DEFAULT_SEARCH_LIMIT = 10
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeSearchMethod(Enum):
|
||||||
|
cosine_similarity = 'cosine_similarity'
|
||||||
|
bm25 = 'bm25'
|
||||||
|
|
||||||
|
|
||||||
|
class NodeSearchMethod(Enum):
|
||||||
|
cosine_similarity = 'cosine_similarity'
|
||||||
|
bm25 = 'bm25'
|
||||||
|
|
||||||
|
|
||||||
|
class CommunitySearchMethod(Enum):
|
||||||
|
cosine_similarity = 'cosine_similarity'
|
||||||
|
bm25 = 'bm25'
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeReranker(Enum):
|
||||||
|
rrf = 'reciprocal_rank_fusion'
|
||||||
|
node_distance = 'node_distance'
|
||||||
|
|
||||||
|
|
||||||
|
class NodeReranker(Enum):
|
||||||
|
rrf = 'reciprocal_rank_fusion'
|
||||||
|
node_distance = 'node_distance'
|
||||||
|
|
||||||
|
|
||||||
|
class CommunityReranker(Enum):
|
||||||
|
rrf = 'reciprocal_rank_fusion'
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeSearchConfig(BaseModel):
|
||||||
|
search_methods: list[EdgeSearchMethod]
|
||||||
|
reranker: EdgeReranker | None
|
||||||
|
|
||||||
|
|
||||||
|
class NodeSearchConfig(BaseModel):
|
||||||
|
search_methods: list[NodeSearchMethod]
|
||||||
|
reranker: NodeReranker | None
|
||||||
|
|
||||||
|
|
||||||
|
class CommunitySearchConfig(BaseModel):
|
||||||
|
search_methods: list[CommunitySearchMethod]
|
||||||
|
reranker: CommunityReranker | None
|
||||||
|
|
||||||
|
|
||||||
|
class SearchConfig(BaseModel):
|
||||||
|
edge_config: EdgeSearchConfig | None = Field(default=None)
|
||||||
|
node_config: NodeSearchConfig | None = Field(default=None)
|
||||||
|
community_config: CommunitySearchConfig | None = Field(default=None)
|
||||||
|
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResults(BaseModel):
|
||||||
|
edges: list[EntityEdge]
|
||||||
|
nodes: list[EntityNode]
|
||||||
|
communities: list[CommunityNode]
|
||||||
84
graphiti_core/search/search_config_recipes.py
Normal file
84
graphiti_core/search/search_config_recipes.py
Normal file
|
|
@ -0,0 +1,84 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from graphiti_core.search.search_config import (
|
||||||
|
CommunityReranker,
|
||||||
|
CommunitySearchConfig,
|
||||||
|
CommunitySearchMethod,
|
||||||
|
EdgeReranker,
|
||||||
|
EdgeSearchConfig,
|
||||||
|
EdgeSearchMethod,
|
||||||
|
NodeReranker,
|
||||||
|
NodeSearchConfig,
|
||||||
|
NodeSearchMethod,
|
||||||
|
SearchConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performs a hybrid search with rrf reranking over edges, nodes, and communities
|
||||||
|
COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
|
edge_config=EdgeSearchConfig(
|
||||||
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
|
reranker=EdgeReranker.rrf,
|
||||||
|
),
|
||||||
|
node_config=NodeSearchConfig(
|
||||||
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
||||||
|
reranker=NodeReranker.rrf,
|
||||||
|
),
|
||||||
|
community_config=CommunitySearchConfig(
|
||||||
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
||||||
|
reranker=CommunityReranker.rrf,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over edges with rrf reranking
|
||||||
|
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
|
edge_config=EdgeSearchConfig(
|
||||||
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
|
reranker=EdgeReranker.rrf,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over edges with node distance reranking
|
||||||
|
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
|
edge_config=EdgeSearchConfig(
|
||||||
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
|
reranker=EdgeReranker.node_distance,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over nodes with rrf reranking
|
||||||
|
NODE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
|
node_config=NodeSearchConfig(
|
||||||
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
||||||
|
reranker=NodeReranker.rrf,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over nodes with node distance reranking
|
||||||
|
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
|
node_config=NodeSearchConfig(
|
||||||
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
||||||
|
reranker=NodeReranker.node_distance,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over communities with rrf reranking
|
||||||
|
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
|
community_config=CommunitySearchConfig(
|
||||||
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
||||||
|
reranker=CommunityReranker.rrf,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
@ -7,7 +23,13 @@ from time import time
|
||||||
from neo4j import AsyncDriver, Query
|
from neo4j import AsyncDriver, Query
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
|
from graphiti_core.nodes import (
|
||||||
|
CommunityNode,
|
||||||
|
EntityNode,
|
||||||
|
EpisodicNode,
|
||||||
|
get_community_node_from_record,
|
||||||
|
get_entity_node_from_record,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -35,181 +57,6 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
async def edge_similarity_search(
|
|
||||||
driver: AsyncDriver,
|
|
||||||
search_vector: list[float],
|
|
||||||
source_node_uuid: str | None,
|
|
||||||
target_node_uuid: str | None,
|
|
||||||
group_ids: list[str | None] | None = None,
|
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
||||||
) -> list[EntityEdge]:
|
|
||||||
group_ids = group_ids if group_ids is not None else [None]
|
|
||||||
# vector similarity search over embedded facts
|
|
||||||
query = Query("""
|
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
||||||
YIELD relationship AS rel, score
|
|
||||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
||||||
WHERE r.group_id IN $group_ids
|
|
||||||
RETURN
|
|
||||||
r.uuid AS uuid,
|
|
||||||
r.group_id AS group_id,
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at
|
|
||||||
ORDER BY score DESC
|
|
||||||
""")
|
|
||||||
|
|
||||||
if source_node_uuid is None and target_node_uuid is None:
|
|
||||||
query = Query("""
|
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
||||||
YIELD relationship AS rel, score
|
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
|
||||||
WHERE r.group_id IN $group_ids
|
|
||||||
RETURN
|
|
||||||
r.uuid AS uuid,
|
|
||||||
r.group_id AS group_id,
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at
|
|
||||||
ORDER BY score DESC
|
|
||||||
""")
|
|
||||||
elif source_node_uuid is None:
|
|
||||||
query = Query("""
|
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
||||||
YIELD relationship AS rel, score
|
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
||||||
WHERE r.group_id IN $group_ids
|
|
||||||
RETURN
|
|
||||||
r.uuid AS uuid,
|
|
||||||
r.group_id AS group_id,
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at
|
|
||||||
ORDER BY score DESC
|
|
||||||
""")
|
|
||||||
elif target_node_uuid is None:
|
|
||||||
query = Query("""
|
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
||||||
YIELD relationship AS rel, score
|
|
||||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
|
||||||
WHERE r.group_id IN $group_ids
|
|
||||||
RETURN
|
|
||||||
r.uuid AS uuid,
|
|
||||||
r.group_id AS group_id,
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at
|
|
||||||
ORDER BY score DESC
|
|
||||||
""")
|
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
search_vector=search_vector,
|
|
||||||
source_uuid=source_node_uuid,
|
|
||||||
target_uuid=target_node_uuid,
|
|
||||||
group_ids=group_ids,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
|
||||||
|
|
||||||
return edges
|
|
||||||
|
|
||||||
|
|
||||||
async def entity_similarity_search(
|
|
||||||
search_vector: list[float],
|
|
||||||
driver: AsyncDriver,
|
|
||||||
group_ids: list[str | None] | None = None,
|
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
|
||||||
) -> list[EntityNode]:
|
|
||||||
group_ids = group_ids if group_ids is not None else [None]
|
|
||||||
|
|
||||||
# vector similarity search over entity names
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
|
||||||
YIELD node AS n, score
|
|
||||||
MATCH (n WHERE n.group_id IN $group_ids)
|
|
||||||
RETURN
|
|
||||||
n.uuid As uuid,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.name AS name,
|
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary
|
|
||||||
ORDER BY score DESC
|
|
||||||
""",
|
|
||||||
search_vector=search_vector,
|
|
||||||
group_ids=group_ids,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
|
||||||
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def entity_fulltext_search(
|
|
||||||
query: str,
|
|
||||||
driver: AsyncDriver,
|
|
||||||
group_ids: list[str | None] | None = None,
|
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
|
||||||
) -> list[EntityNode]:
|
|
||||||
group_ids = group_ids if group_ids is not None else [None]
|
|
||||||
|
|
||||||
# BM25 search to get top nodes
|
|
||||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
|
||||||
YIELD node AS n, score
|
|
||||||
MATCH (n WHERE n.group_id in $group_ids)
|
|
||||||
RETURN
|
|
||||||
n.uuid AS uuid,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.name AS name,
|
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
""",
|
|
||||||
query=fuzzy_query,
|
|
||||||
group_ids=group_ids,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
|
||||||
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
async def edge_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -322,6 +169,247 @@ async def edge_fulltext_search(
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
async def edge_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str | None] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """Find entity edges whose fact embedding is most similar to a vector.

    Asks the ``fact_embedding`` relationship vector index for ``limit``
    candidates, keeps those whose ``group_id`` is in ``group_ids`` and whose
    endpoints satisfy the optional uuid constraints, and returns them ordered
    by descending index score.

    Args:
        driver: Neo4j async driver used to run the query.
        search_vector: Embedding to compare against stored fact embeddings.
        source_node_uuid: If not None, the ``n`` endpoint of the matched
            pattern must have this uuid.
        target_node_uuid: If not None, the ``m`` endpoint of the matched
            pattern must have this uuid.
        group_ids: Group ids an edge may belong to. ``None`` is sent to the
            database as ``[None]``.
        limit: Number of candidates requested from the vector index.

    Returns:
        One ``EntityEdge`` per returned row, built with
        ``get_entity_edge_from_record``.
    """
    # NOTE(review): in Cypher `x IN [null]` evaluates to null rather than true,
    # so the [None] default may filter out every row -- confirm intent.
    group_ids = group_ids if group_ids is not None else [None]

    # The four endpoint combinations differ only in whether each Entity
    # pattern carries a uuid filter, so build that one clause instead of
    # repeating the whole query. Only fixed literal fragments are spliced in;
    # every value is still passed as a query parameter.
    source_filter = '' if source_node_uuid is None else ' {uuid: $source_uuid}'
    target_filter = '' if target_node_uuid is None else ' {uuid: $target_uuid}'

    # NOTE(review): the relationship pattern is undirected, so when neither
    # endpoint is pinned a relationship can match once per orientation --
    # verify whether callers de-duplicate.
    query = Query(
        f"""
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity{source_filter})-[r {{uuid: rel.uuid}}]-(m:Entity{target_filter})
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC
        """
    )

    # Both uuid parameters are always sent, even when the query text does not
    # reference them, exactly as before.
    records, _, _ = await driver.execute_query(
        query,
        search_vector=search_vector,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
||||||
|
async def node_fulltext_search(
    driver: AsyncDriver,
    query: str,
    group_ids: list[str | None] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Run a fuzzy full-text search over the ``name_and_summary`` index.

    Args:
        driver: Neo4j async driver used to run the query.
        query: Free-text search string. Every character that is neither a
            word character nor whitespace is removed before searching.
        group_ids: Group ids a node may belong to. ``None`` is sent to the
            database as ``[None]``.
        limit: Maximum number of rows returned.

    Returns:
        Matching nodes ordered by descending full-text score, built with
        ``get_entity_node_from_record``. An empty list is returned without
        touching the database when nothing searchable remains in ``query``.
    """
    group_ids = group_ids if group_ids is not None else [None]

    # Remove punctuation, then surrounding whitespace. Without the strip, an
    # empty or punctuation-only query would become a bare "~" and a query
    # ending in " ?" would become "... ~"; neither is a valid fuzzy term.
    sanitized_query = re.sub(r'[^\w\s]', '', query).strip()
    if not sanitized_query:
        return []

    # The trailing "~" requests fuzzy matching for the final term.
    fuzzy_query = sanitized_query + '~'
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query)
        YIELD node AS n, score
        MATCH (n WHERE n.group_id in $group_ids)
        RETURN
            n.uuid AS uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        LIMIT $limit
        """,
        query=fuzzy_query,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
||||||
|
async def node_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Find nodes whose name embedding is most similar to ``search_vector``.

    Requests ``limit`` candidates from the ``name_embedding`` vector index,
    keeps those whose ``group_id`` is in ``group_ids`` (``None`` is sent as
    ``[None]``), and returns them ordered by descending index score as
    objects built with ``get_entity_node_from_record``.
    """
    if group_ids is None:
        group_ids = [None]

    # Vector lookup first, group filter second: $limit bounds the index
    # candidates, not the rows that survive the group_id filter.
    cypher = """
        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
        YIELD node AS n, score
        MATCH (n WHERE n.group_id IN $group_ids)
        RETURN
            n.uuid As uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        """
    rows, _, _ = await driver.execute_query(
        cypher,
        search_vector=search_vector,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_node_from_record(row) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
async def community_fulltext_search(
    driver: AsyncDriver,
    query: str,
    group_ids: list[str | None] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """Run a fuzzy full-text search over the ``community_name`` index.

    Args:
        driver: Neo4j async driver used to run the query.
        query: Free-text search string. Every character that is neither a
            word character nor whitespace is removed before searching.
        group_ids: Group ids a community may belong to. ``None`` is sent to
            the database as ``[None]``.
        limit: Maximum number of rows returned.

    Returns:
        Matching communities ordered by descending full-text score, built
        with ``get_community_node_from_record``. An empty list is returned
        without touching the database when nothing searchable remains in
        ``query``.
    """
    group_ids = group_ids if group_ids is not None else [None]

    # Remove punctuation, then surrounding whitespace. Without the strip, an
    # empty or punctuation-only query would become a bare "~" and a query
    # ending in " ?" would become "... ~"; neither is a valid fuzzy term.
    sanitized_query = re.sub(r'[^\w\s]', '', query).strip()
    if not sanitized_query:
        return []

    # The trailing "~" requests fuzzy matching for the final term.
    fuzzy_query = sanitized_query + '~'
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("community_name", $query)
        YIELD node AS comm, score
        MATCH (comm WHERE comm.group_id in $group_ids)
        RETURN
            comm.uuid AS uuid,
            comm.group_id AS group_id,
            comm.name AS name,
            comm.name_embedding AS name_embedding,
            comm.created_at AS created_at,
            comm.summary AS summary
        ORDER BY score DESC
        LIMIT $limit
        """,
        query=fuzzy_query,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
||||||
|
async def community_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """Find communities whose name embedding is most similar to a vector.

    Requests ``limit`` candidates from the ``community_name_embedding``
    vector index, keeps those whose ``group_id`` is in ``group_ids``
    (``None`` is sent as ``[None]``), and returns them ordered by descending
    index score as objects built with ``get_community_node_from_record``.
    """
    if group_ids is None:
        group_ids = [None]

    # Vector similarity search over community names. $limit bounds the index
    # candidates, not the rows that survive the group_id filter.
    cypher = """
        CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
        YIELD node AS comm, score
        MATCH (comm WHERE comm.group_id IN $group_ids)
        RETURN
            comm.uuid As uuid,
            comm.group_id AS group_id,
            comm.name AS name,
            comm.name_embedding AS name_embedding,
            comm.created_at AS created_at,
            comm.summary AS summary
        ORDER BY score DESC
        """
    rows, _, _ = await driver.execute_query(
        cypher,
        search_vector=search_vector,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_community_node_from_record(row) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_node_search(
|
async def hybrid_node_search(
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
|
|
@ -371,8 +459,8 @@ async def hybrid_node_search(
|
||||||
|
|
||||||
results: list[list[EntityNode]] = list(
|
results: list[list[EntityNode]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
|
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
||||||
*[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
|
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -490,24 +578,23 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||||
|
|
||||||
|
|
||||||
async def node_distance_reranker(
|
async def node_distance_reranker(
|
||||||
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
|
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# use rrf as a preliminary ranker
|
# use rrf as a preliminary ranker
|
||||||
sorted_uuids = rrf(results)
|
sorted_uuids = rrf(node_uuids)
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
query = Query("""
|
query = Query("""
|
||||||
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
|
||||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
|
RETURN length(p) AS score
|
||||||
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
|
||||||
""")
|
""")
|
||||||
|
|
||||||
path_results = await asyncio.gather(
|
path_results = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
driver.execute_query(
|
driver.execute_query(
|
||||||
query,
|
query,
|
||||||
edge_uuid=uuid,
|
node_uuid=uuid,
|
||||||
center_uuid=center_node_uuid,
|
center_uuid=center_node_uuid,
|
||||||
)
|
)
|
||||||
for uuid in sorted_uuids
|
for uuid in sorted_uuids
|
||||||
|
|
@ -518,15 +605,8 @@ async def node_distance_reranker(
|
||||||
records = result[0]
|
records = result[0]
|
||||||
record = records[0] if len(records) > 0 else None
|
record = records[0] if len(records) > 0 else None
|
||||||
distance: float = record['score'] if record is not None else float('inf')
|
distance: float = record['score'] if record is not None else float('inf')
|
||||||
if record is not None and (
|
distance = 0 if uuid == center_node_uuid else distance
|
||||||
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
|
scores[uuid] = distance
|
||||||
):
|
|
||||||
distance = 0
|
|
||||||
|
|
||||||
if uuid in scores:
|
|
||||||
scores[uuid] = min(distance, scores[uuid])
|
|
||||||
else:
|
|
||||||
scores[uuid] = distance
|
|
||||||
|
|
||||||
# rerank on shortest distance
|
# rerank on shortest distance
|
||||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from dotenv import load_dotenv
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.graphiti import Graphiti
|
from graphiti_core.graphiti import Graphiti
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||||
|
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
@ -81,6 +82,17 @@ async def test_graphiti_init():
|
||||||
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
|
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
|
||||||
|
|
||||||
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
||||||
|
|
||||||
|
results = await graphiti._search(
|
||||||
|
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1']
|
||||||
|
)
|
||||||
|
pretty_results = {
|
||||||
|
'edges': [edge.fact for edge in results.edges],
|
||||||
|
'nodes': [node.name for node in results.nodes],
|
||||||
|
'communities': [community.name for community in results.communities],
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(pretty_results)
|
||||||
graphiti.close()
|
graphiti.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,11 +11,11 @@ async def test_hybrid_node_search_deduplication():
|
||||||
# Mock the database driver
|
# Mock the database driver
|
||||||
mock_driver = AsyncMock()
|
mock_driver = AsyncMock()
|
||||||
|
|
||||||
# Mock the entity_fulltext_search and entity_similarity_search functions
|
# Mock the node_fulltext_search and node_similarity_search functions
|
||||||
with patch(
|
with patch(
|
||||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||||
) as mock_fulltext_search, patch(
|
) as mock_fulltext_search, patch(
|
||||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
'graphiti_core.search.search_utils.node_similarity_search'
|
||||||
) as mock_similarity_search:
|
) as mock_similarity_search:
|
||||||
# Set up mock return values
|
# Set up mock return values
|
||||||
mock_fulltext_search.side_effect = [
|
mock_fulltext_search.side_effect = [
|
||||||
|
|
@ -47,9 +47,9 @@ async def test_hybrid_node_search_empty_results():
|
||||||
mock_driver = AsyncMock()
|
mock_driver = AsyncMock()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||||
) as mock_fulltext_search, patch(
|
) as mock_fulltext_search, patch(
|
||||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
'graphiti_core.search.search_utils.node_similarity_search'
|
||||||
) as mock_similarity_search:
|
) as mock_similarity_search:
|
||||||
mock_fulltext_search.return_value = []
|
mock_fulltext_search.return_value = []
|
||||||
mock_similarity_search.return_value = []
|
mock_similarity_search.return_value = []
|
||||||
|
|
@ -66,9 +66,9 @@ async def test_hybrid_node_search_only_fulltext():
|
||||||
mock_driver = AsyncMock()
|
mock_driver = AsyncMock()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||||
) as mock_fulltext_search, patch(
|
) as mock_fulltext_search, patch(
|
||||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
'graphiti_core.search.search_utils.node_similarity_search'
|
||||||
) as mock_similarity_search:
|
) as mock_similarity_search:
|
||||||
mock_fulltext_search.return_value = [
|
mock_fulltext_search.return_value = [
|
||||||
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
|
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
|
||||||
|
|
@ -90,9 +90,9 @@ async def test_hybrid_node_search_with_limit():
|
||||||
mock_driver = AsyncMock()
|
mock_driver = AsyncMock()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||||
) as mock_fulltext_search, patch(
|
) as mock_fulltext_search, patch(
|
||||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
'graphiti_core.search.search_utils.node_similarity_search'
|
||||||
) as mock_similarity_search:
|
) as mock_similarity_search:
|
||||||
mock_fulltext_search.return_value = [
|
mock_fulltext_search.return_value = [
|
||||||
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
|
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
|
||||||
|
|
@ -120,8 +120,8 @@ async def test_hybrid_node_search_with_limit():
|
||||||
assert mock_fulltext_search.call_count == 1
|
assert mock_fulltext_search.call_count == 1
|
||||||
assert mock_similarity_search.call_count == 1
|
assert mock_similarity_search.call_count == 1
|
||||||
# Verify that the limit was passed to the search functions
|
# Verify that the limit was passed to the search functions
|
||||||
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2)
|
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2)
|
||||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2)
|
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -129,9 +129,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
|
||||||
mock_driver = AsyncMock()
|
mock_driver = AsyncMock()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
'graphiti_core.search.search_utils.entity_fulltext_search'
|
'graphiti_core.search.search_utils.node_fulltext_search'
|
||||||
) as mock_fulltext_search, patch(
|
) as mock_fulltext_search, patch(
|
||||||
'graphiti_core.search.search_utils.entity_similarity_search'
|
'graphiti_core.search.search_utils.node_similarity_search'
|
||||||
) as mock_similarity_search:
|
) as mock_similarity_search:
|
||||||
mock_fulltext_search.return_value = [
|
mock_fulltext_search.return_value = [
|
||||||
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
|
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
|
||||||
|
|
@ -155,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
|
||||||
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
|
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
|
||||||
assert mock_fulltext_search.call_count == 1
|
assert mock_fulltext_search.call_count == 1
|
||||||
assert mock_similarity_search.call_count == 1
|
assert mock_similarity_search.call_count == 1
|
||||||
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4)
|
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4)
|
||||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4)
|
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue