node label filters (#265)

* node label filters

* update

* add search filters

* updates

* bump versions

* update tests

* test update
This commit is contained in:
Preston Rasmussen 2025-02-21 12:38:01 -05:00 committed by GitHub
parent 29a071b2b8
commit 088029a80c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 120 additions and 47 deletions

View file

@ -351,7 +351,10 @@ class Graphiti:
# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
*[
get_relevant_nodes(self.driver, SearchFilters(), [node])
for node in extracted_nodes
]
)
)
@ -732,8 +735,8 @@ class Graphiti:
self.llm_client,
[source_node, target_node],
[
await get_relevant_nodes(self.driver, [source_node]),
await get_relevant_nodes(self.driver, [target_node]),
await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
],
)

View file

@ -100,6 +100,7 @@ async def search(
query_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
@ -233,6 +234,7 @@ async def node_search(
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
@ -243,11 +245,13 @@ async def node_search(
search_results: list[list[EntityNode]] = list(
await semaphore_gather(
*[
node_fulltext_search(driver, query, group_ids, 2 * limit),
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
node_similarity_search(
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
),
node_bfs_search(
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
),
node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
]
)
)
@ -255,7 +259,9 @@ async def node_search(
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append(
await node_bfs_search(driver, origin_node_uuids, config.bfs_max_depth, 2 * limit)
await node_bfs_search(
driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
)
)
search_result_uuids = [[node.uuid for node in result] for result in search_results]

View file

@ -39,18 +39,37 @@ class DateFilter(BaseModel):
class SearchFilters(BaseModel):
    # Optional constraints supplied to graph searches. Every field defaults to None,
    # so a bare SearchFilters() specifies no constraints at all.

    # Node labels to filter on; these are label names, not property values.
    node_labels: list[str] | None = Field(
        default=None, description='List of node labels to filter on'
    )
    # Date constraints, each expressed as a list of DateFilter lists.
    # NOTE(review): how the inner and outer lists combine (AND vs OR) is decided by
    # the query constructor that consumes them -- confirm there before relying on it.
    valid_at: list[list[DateFilter]] | None = Field(default=None)
    invalid_at: list[list[DateFilter]] | None = Field(default=None)
    created_at: list[list[DateFilter]] | None = Field(default=None)
    expired_at: list[list[DateFilter]] | None = Field(default=None)
def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
def node_search_filter_query_constructor(
    filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
    """Build the Cypher predicate fragment that restricts nodes by label.

    The fragment is meant to be appended to an existing ``WHERE`` clause on a
    node bound to the variable ``n``; it therefore starts with ``' AND '``.

    Args:
        filters: Search filters; only ``node_labels`` is read here.

    Returns:
        A ``(filter_query, filter_params)`` tuple. ``filter_query`` is ``''``
        when no labels are requested, otherwise ``' AND n:Label1:Label2...'``
        (the node must carry every listed label). ``filter_params`` is always
        empty because the labels are written into the query text itself.
    """
    filter_query: LiteralString = ''
    filter_params: dict[str, Any] = {}

    # Truthiness check on purpose: an empty list must behave like None.
    # Joining an empty list would emit the dangling fragment ' AND n:',
    # which is not valid Cypher and would make the whole search fail.
    if filters.node_labels:
        # NOTE(review): labels are concatenated into the query string rather
        # than passed as parameters, so callers must only supply trusted,
        # well-formed label names.
        node_labels = ':'.join(filters.node_labels)
        filter_query += ' AND n:' + node_labels

    return filter_query, filter_params
def edge_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = ''
filter_params: dict[str, Any] = {}
if filters.valid_at is not None:
valid_at_filter = 'AND ('
valid_at_filter = ' AND ('
for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list):
filter_params['valid_at_' + str(j)] = date_filter.date
@ -75,7 +94,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
filter_query += valid_at_filter
if filters.invalid_at is not None:
invalid_at_filter = 'AND ('
invalid_at_filter = ' AND ('
for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list):
filter_params['invalid_at_' + str(j)] = date_filter.date
@ -100,7 +119,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
filter_query += invalid_at_filter
if filters.created_at is not None:
created_at_filter = 'AND ('
created_at_filter = ' AND ('
for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list):
filter_params['created_at_' + str(j)] = date_filter.date

View file

@ -38,7 +38,11 @@ from graphiti_core.nodes import (
get_community_node_from_record,
get_entity_node_from_record,
)
from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
from graphiti_core.search.search_filters import (
SearchFilters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)
logger = logging.getLogger(__name__)
@ -148,7 +152,7 @@ async def edge_fulltext_search(
if fuzzy_query == '':
return []
filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
cypher_query = Query(
"""
@ -207,7 +211,7 @@ async def edge_similarity_search(
query_params: dict[str, Any] = {}
filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query_params.update(filter_params)
group_filter_query: LiteralString = ''
@ -225,8 +229,8 @@ async def edge_similarity_search(
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@ -278,7 +282,7 @@ async def edge_bfs_search(
if bfs_origin_node_uuids is None:
return []
filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query = Query(
"""
@ -325,6 +329,7 @@ async def edge_bfs_search(
async def node_fulltext_search(
driver: AsyncDriver,
query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
@ -333,10 +338,17 @@ async def node_fulltext_search(
if fuzzy_query == '':
return []
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
YIELD node AS node, score
MATCH (n:Entity)
WHERE n.uuid = node.uuid
"""
+ filter_query
+ """
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
@ -349,6 +361,7 @@ async def node_fulltext_search(
ORDER BY score DESC
LIMIT $limit
""",
filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
@ -363,6 +376,7 @@ async def node_fulltext_search(
async def node_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
@ -379,12 +393,16 @@ async def node_similarity_search(
group_filter_query += 'WHERE n.group_id IN $group_ids'
query_params['group_ids'] = group_ids
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
query_params.update(filter_params)
records, _, _ = await driver.execute_query(
runtime_query
+ """
MATCH (n:Entity)
"""
+ group_filter_query
+ filter_query
+ """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
@ -416,6 +434,7 @@ async def node_similarity_search(
async def node_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
search_filter: SearchFilters,
bfs_max_depth: int,
limit: int,
) -> list[EntityNode]:
@ -423,21 +442,28 @@ async def node_bfs_search(
if bfs_origin_node_uuids is None:
return []
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
records, _, _ = await driver.execute_query(
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
WHERE n.group_id = origin.group_id
"""
+ filter_query
+ """
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
@ -539,6 +565,7 @@ async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
@ -583,8 +610,14 @@ async def hybrid_node_search(
start = time()
results: list[list[EntityNode]] = list(
await semaphore_gather(
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
*[
node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
for q in queries
],
*[
node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
for e in embeddings
],
)
)
@ -604,6 +637,7 @@ async def hybrid_node_search(
async def get_relevant_nodes(
driver: AsyncDriver,
search_filter: SearchFilters,
nodes: list[EntityNode],
) -> list[EntityNode]:
"""
@ -635,6 +669,7 @@ async def get_relevant_nodes(
[node.name for node in nodes],
[node.name_embedding for node in nodes if node.name_embedding is not None],
driver,
search_filter,
[node.group_id for node in nodes],
)

View file

@ -37,6 +37,7 @@ from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.edge_operations import (
@ -188,7 +189,7 @@ async def dedupe_nodes_bulk(
existing_nodes_chunks: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
*[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
)
)

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.7.0"
version = "0.7.1"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",

View file

@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from graphiti_core.nodes import EntityNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import hybrid_node_search
@ -13,7 +14,7 @@ async def test_hybrid_node_search_deduplication():
# Mock the node_fulltext_search and node_similarity_search functions
with patch(
'graphiti_core.search.search_utils.node_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
@ -30,7 +31,7 @@ async def test_hybrid_node_search_deduplication():
# Call the function with test data
queries = ['Alice', 'Bob']
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
results = await hybrid_node_search(queries, embeddings, mock_driver)
results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
# Assertions
assert len(results) == 3
@ -47,7 +48,7 @@ async def test_hybrid_node_search_empty_results():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.node_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
@ -56,7 +57,7 @@ async def test_hybrid_node_search_empty_results():
queries = ['NonExistent']
embeddings = [[0.1, 0.2, 0.3]]
results = await hybrid_node_search(queries, embeddings, mock_driver)
results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
assert len(results) == 0
@ -66,7 +67,7 @@ async def test_hybrid_node_search_only_fulltext():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.node_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
@ -77,7 +78,7 @@ async def test_hybrid_node_search_only_fulltext():
queries = ['Alice']
embeddings = []
results = await hybrid_node_search(queries, embeddings, mock_driver)
results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
assert len(results) == 1
assert results[0].name == 'Alice'
@ -90,7 +91,7 @@ async def test_hybrid_node_search_with_limit():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.node_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
@ -111,7 +112,9 @@ async def test_hybrid_node_search_with_limit():
queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]]
limit = 1
results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
results = await hybrid_node_search(
queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
)
# We expect 4 results because the limit is applied per search method
# before deduplication, and we're not actually limiting the results
@ -120,8 +123,10 @@ async def test_hybrid_node_search_with_limit():
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2)
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2)
mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 2)
mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 2
)
@pytest.mark.asyncio
@ -129,7 +134,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.node_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
@ -145,7 +150,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]]
limit = 2
results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
results = await hybrid_node_search(
queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
)
# We expect 3 results because:
# 1. The limit of 2 is applied to each search method
@ -155,5 +162,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4)
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4)
mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 4)
mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
)