node label filters (#265)

* node label filters

* update

* add search filters

* updates

* bump versions

* update tests

* test update
This commit is contained in:
Preston Rasmussen 2025-02-21 12:38:01 -05:00 committed by GitHub
parent 29a071b2b8
commit 088029a80c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 120 additions and 47 deletions

View file

@ -351,7 +351,10 @@ class Graphiti:
# Find relevant nodes already in the graph # Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list( existing_nodes_lists: list[list[EntityNode]] = list(
await semaphore_gather( await semaphore_gather(
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes] *[
get_relevant_nodes(self.driver, SearchFilters(), [node])
for node in extracted_nodes
]
) )
) )
@ -732,8 +735,8 @@ class Graphiti:
self.llm_client, self.llm_client,
[source_node, target_node], [source_node, target_node],
[ [
await get_relevant_nodes(self.driver, [source_node]), await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
await get_relevant_nodes(self.driver, [target_node]), await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
], ],
) )

View file

@ -100,6 +100,7 @@ async def search(
query_vector, query_vector,
group_ids, group_ids,
config.node_config, config.node_config,
search_filter,
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids, bfs_origin_node_uuids,
config.limit, config.limit,
@ -233,6 +234,7 @@ async def node_search(
query_vector: list[float], query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: NodeSearchConfig | None, config: NodeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
@ -243,11 +245,13 @@ async def node_search(
search_results: list[list[EntityNode]] = list( search_results: list[list[EntityNode]] = list(
await semaphore_gather( await semaphore_gather(
*[ *[
node_fulltext_search(driver, query, group_ids, 2 * limit), node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
node_similarity_search( node_similarity_search(
driver, query_vector, group_ids, 2 * limit, config.sim_min_score driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
),
node_bfs_search(
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
), ),
node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
] ]
) )
) )
@ -255,7 +259,9 @@ async def node_search(
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
origin_node_uuids = [node.uuid for result in search_results for node in result] origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append( search_results.append(
await node_bfs_search(driver, origin_node_uuids, config.bfs_max_depth, 2 * limit) await node_bfs_search(
driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
)
) )
search_result_uuids = [[node.uuid for node in result] for result in search_results] search_result_uuids = [[node.uuid for node in result] for result in search_results]

View file

@ -39,18 +39,37 @@ class DateFilter(BaseModel):
class SearchFilters(BaseModel): class SearchFilters(BaseModel):
node_labels: list[str] | None = Field(
default=None, description='List of node labels to filter on'
)
valid_at: list[list[DateFilter]] | None = Field(default=None) valid_at: list[list[DateFilter]] | None = Field(default=None)
invalid_at: list[list[DateFilter]] | None = Field(default=None) invalid_at: list[list[DateFilter]] | None = Field(default=None)
created_at: list[list[DateFilter]] | None = Field(default=None) created_at: list[list[DateFilter]] | None = Field(default=None)
expired_at: list[list[DateFilter]] | None = Field(default=None) expired_at: list[list[DateFilter]] | None = Field(default=None)
def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]: def node_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = ''
filter_params: dict[str, Any] = {}
if filters.node_labels is not None:
node_labels = ':'.join(filters.node_labels)
node_label_filter = ' AND n:' + node_labels
filter_query += node_label_filter
return filter_query, filter_params
def edge_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = '' filter_query: LiteralString = ''
filter_params: dict[str, Any] = {} filter_params: dict[str, Any] = {}
if filters.valid_at is not None: if filters.valid_at is not None:
valid_at_filter = 'AND (' valid_at_filter = ' AND ('
for i, or_list in enumerate(filters.valid_at): for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list): for j, date_filter in enumerate(or_list):
filter_params['valid_at_' + str(j)] = date_filter.date filter_params['valid_at_' + str(j)] = date_filter.date
@ -75,7 +94,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
filter_query += valid_at_filter filter_query += valid_at_filter
if filters.invalid_at is not None: if filters.invalid_at is not None:
invalid_at_filter = 'AND (' invalid_at_filter = ' AND ('
for i, or_list in enumerate(filters.invalid_at): for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list): for j, date_filter in enumerate(or_list):
filter_params['invalid_at_' + str(j)] = date_filter.date filter_params['invalid_at_' + str(j)] = date_filter.date
@ -100,7 +119,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
filter_query += invalid_at_filter filter_query += invalid_at_filter
if filters.created_at is not None: if filters.created_at is not None:
created_at_filter = 'AND (' created_at_filter = ' AND ('
for i, or_list in enumerate(filters.created_at): for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list): for j, date_filter in enumerate(or_list):
filter_params['created_at_' + str(j)] = date_filter.date filter_params['created_at_' + str(j)] = date_filter.date

View file

@ -38,7 +38,11 @@ from graphiti_core.nodes import (
get_community_node_from_record, get_community_node_from_record,
get_entity_node_from_record, get_entity_node_from_record,
) )
from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor from graphiti_core.search.search_filters import (
SearchFilters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -148,7 +152,7 @@ async def edge_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
filter_query, filter_params = search_filter_query_constructor(search_filter) filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
cypher_query = Query( cypher_query = Query(
""" """
@ -207,7 +211,7 @@ async def edge_similarity_search(
query_params: dict[str, Any] = {} query_params: dict[str, Any] = {}
filter_query, filter_params = search_filter_query_constructor(search_filter) filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query_params.update(filter_params) query_params.update(filter_params)
group_filter_query: LiteralString = '' group_filter_query: LiteralString = ''
@ -225,8 +229,8 @@ async def edge_similarity_search(
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
""" """
+ group_filter_query + group_filter_query
+ filter_query + filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@ -278,7 +282,7 @@ async def edge_bfs_search(
if bfs_origin_node_uuids is None: if bfs_origin_node_uuids is None:
return [] return []
filter_query, filter_params = search_filter_query_constructor(search_filter) filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query = Query( query = Query(
""" """
@ -325,6 +329,7 @@ async def edge_bfs_search(
async def node_fulltext_search( async def node_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
@ -333,10 +338,17 @@ async def node_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score YIELD node AS node, score
MATCH (n:Entity)
WHERE n.uuid = node.uuid
"""
+ filter_query
+ """
RETURN RETURN
n.uuid AS uuid, n.uuid AS uuid,
n.group_id AS group_id, n.group_id AS group_id,
@ -349,6 +361,7 @@ async def node_fulltext_search(
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""", """,
filter_params,
query=fuzzy_query, query=fuzzy_query,
group_ids=group_ids, group_ids=group_ids,
limit=limit, limit=limit,
@ -363,6 +376,7 @@ async def node_fulltext_search(
async def node_similarity_search( async def node_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
search_filter: SearchFilters,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE, min_score: float = DEFAULT_MIN_SCORE,
@ -379,12 +393,16 @@ async def node_similarity_search(
group_filter_query += 'WHERE n.group_id IN $group_ids' group_filter_query += 'WHERE n.group_id IN $group_ids'
query_params['group_ids'] = group_ids query_params['group_ids'] = group_ids
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
query_params.update(filter_params)
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
runtime_query runtime_query
+ """ + """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ group_filter_query + group_filter_query
+ filter_query
+ """ + """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score WHERE score > $min_score
@ -416,6 +434,7 @@ async def node_similarity_search(
async def node_bfs_search( async def node_bfs_search(
driver: AsyncDriver, driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None, bfs_origin_node_uuids: list[str] | None,
search_filter: SearchFilters,
bfs_max_depth: int, bfs_max_depth: int,
limit: int, limit: int,
) -> list[EntityNode]: ) -> list[EntityNode]:
@ -423,21 +442,28 @@ async def node_bfs_search(
if bfs_origin_node_uuids is None: if bfs_origin_node_uuids is None:
return [] return []
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
RETURN DISTINCT WHERE n.group_id = origin.group_id
n.uuid As uuid, """
n.group_id AS group_id, + filter_query
n.name AS name, + """
n.name_embedding AS name_embedding, RETURN DISTINCT
n.created_at AS created_at, n.uuid As uuid,
n.summary AS summary, n.group_id AS group_id,
labels(n) AS labels, n.name AS name,
properties(n) AS attributes n.name_embedding AS name_embedding,
LIMIT $limit n.created_at AS created_at,
""", n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids, bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth, depth=bfs_max_depth,
limit=limit, limit=limit,
@ -539,6 +565,7 @@ async def hybrid_node_search(
queries: list[str], queries: list[str],
embeddings: list[list[float]], embeddings: list[list[float]],
driver: AsyncDriver, driver: AsyncDriver,
search_filter: SearchFilters,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
@ -583,8 +610,14 @@ async def hybrid_node_search(
start = time() start = time()
results: list[list[EntityNode]] = list( results: list[list[EntityNode]] = list(
await semaphore_gather( await semaphore_gather(
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries], *[
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings], node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
for q in queries
],
*[
node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
for e in embeddings
],
) )
) )
@ -604,6 +637,7 @@ async def hybrid_node_search(
async def get_relevant_nodes( async def get_relevant_nodes(
driver: AsyncDriver, driver: AsyncDriver,
search_filter: SearchFilters,
nodes: list[EntityNode], nodes: list[EntityNode],
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """
@ -635,6 +669,7 @@ async def get_relevant_nodes(
[node.name for node in nodes], [node.name for node in nodes],
[node.name_embedding for node in nodes if node.name_embedding is not None], [node.name_embedding for node in nodes if node.name_embedding is not None],
driver, driver,
search_filter,
[node.group_id for node in nodes], [node.group_id for node in nodes],
) )

View file

@ -37,6 +37,7 @@ from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_SAVE_BULK, EPISODIC_NODE_SAVE_BULK,
) )
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.edge_operations import ( from graphiti_core.utils.maintenance.edge_operations import (
@ -188,7 +189,7 @@ async def dedupe_nodes_bulk(
existing_nodes_chunks: list[list[EntityNode]] = list( existing_nodes_chunks: list[list[EntityNode]] = list(
await semaphore_gather( await semaphore_gather(
*[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks] *[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
) )
) )

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "graphiti-core" name = "graphiti-core"
version = "0.7.0" version = "0.7.1"
description = "A temporal graph building library" description = "A temporal graph building library"
authors = [ authors = [
"Paul Paliychuk <paul@getzep.com>", "Paul Paliychuk <paul@getzep.com>",

View file

@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from graphiti_core.nodes import EntityNode from graphiti_core.nodes import EntityNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import hybrid_node_search from graphiti_core.search.search_utils import hybrid_node_search
@ -13,7 +14,7 @@ async def test_hybrid_node_search_deduplication():
# Mock the node_fulltext_search and entity_similarity_search functions # Mock the node_fulltext_search and entity_similarity_search functions
with patch( with patch(
'graphiti_core.search.search_utils.node_fulltext_search' 'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch( ) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search' 'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
@ -30,7 +31,7 @@ async def test_hybrid_node_search_deduplication():
# Call the function with test data # Call the function with test data
queries = ['Alice', 'Bob'] queries = ['Alice', 'Bob']
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
results = await hybrid_node_search(queries, embeddings, mock_driver) results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
# Assertions # Assertions
assert len(results) == 3 assert len(results) == 3
@ -47,7 +48,7 @@ async def test_hybrid_node_search_empty_results():
mock_driver = AsyncMock() mock_driver = AsyncMock()
with patch( with patch(
'graphiti_core.search.search_utils.node_fulltext_search' 'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch( ) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search' 'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
@ -56,7 +57,7 @@ async def test_hybrid_node_search_empty_results():
queries = ['NonExistent'] queries = ['NonExistent']
embeddings = [[0.1, 0.2, 0.3]] embeddings = [[0.1, 0.2, 0.3]]
results = await hybrid_node_search(queries, embeddings, mock_driver) results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
assert len(results) == 0 assert len(results) == 0
@ -66,7 +67,7 @@ async def test_hybrid_node_search_only_fulltext():
mock_driver = AsyncMock() mock_driver = AsyncMock()
with patch( with patch(
'graphiti_core.search.search_utils.node_fulltext_search' 'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch( ) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search' 'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
@ -77,7 +78,7 @@ async def test_hybrid_node_search_only_fulltext():
queries = ['Alice'] queries = ['Alice']
embeddings = [] embeddings = []
results = await hybrid_node_search(queries, embeddings, mock_driver) results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
assert len(results) == 1 assert len(results) == 1
assert results[0].name == 'Alice' assert results[0].name == 'Alice'
@ -90,7 +91,7 @@ async def test_hybrid_node_search_with_limit():
mock_driver = AsyncMock() mock_driver = AsyncMock()
with patch( with patch(
'graphiti_core.search.search_utils.node_fulltext_search' 'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch( ) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search' 'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
@ -111,7 +112,9 @@ async def test_hybrid_node_search_with_limit():
queries = ['Test'] queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]] embeddings = [[0.1, 0.2, 0.3]]
limit = 1 limit = 1
results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) results = await hybrid_node_search(
queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
)
# We expect 4 results because the limit is applied per search method # We expect 4 results because the limit is applied per search method
# before deduplication, and we're not actually limiting the results # before deduplication, and we're not actually limiting the results
@ -120,8 +123,10 @@ async def test_hybrid_node_search_with_limit():
assert mock_fulltext_search.call_count == 1 assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1 assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions # Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2) mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 2)
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2) mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 2
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -129,7 +134,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
mock_driver = AsyncMock() mock_driver = AsyncMock()
with patch( with patch(
'graphiti_core.search.search_utils.node_fulltext_search' 'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch( ) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search' 'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
@ -145,7 +150,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
queries = ['Test'] queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]] embeddings = [[0.1, 0.2, 0.3]]
limit = 2 limit = 2
results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) results = await hybrid_node_search(
queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
)
# We expect 3 results because: # We expect 3 results because:
# 1. The limit of 2 is applied to each search method # 1. The limit of 2 is applied to each search method
@ -155,5 +162,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1 assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1 assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4) mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 4)
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4) mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
)