node label filters (#265)

* node label filters
* update
* add search filters
* updates
* bump versions
* update tests
* test update

parent 29a071b2b8
commit 088029a80c

7 changed files with 120 additions and 47 deletions
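The change in one line: SearchFilters grows a node_labels field, and every node-search path now threads a SearchFilters instance down to the Cypher query constructors. A minimal sketch of the new filter object; the labels are hypothetical, for illustration only:

from graphiti_core.search.search_filters import SearchFilters

# Keep only nodes carrying these labels. node_labels defaults to None,
# so a bare SearchFilters() applies no label filtering at all.
filters = SearchFilters(node_labels=['Person', 'Company'])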
@@ -351,7 +351,10 @@ class Graphiti:
 
         # Find relevant nodes already in the graph
         existing_nodes_lists: list[list[EntityNode]] = list(
             await semaphore_gather(
-                *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
+                *[
+                    get_relevant_nodes(self.driver, SearchFilters(), [node])
+                    for node in extracted_nodes
+                ]
             )
         )

@@ -732,8 +735,8 @@ class Graphiti:
             self.llm_client,
             [source_node, target_node],
             [
-                await get_relevant_nodes(self.driver, [source_node]),
-                await get_relevant_nodes(self.driver, [target_node]),
+                await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
+                await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
             ],
         )
 
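Internal call sites like the two above have no caller-supplied filter, so they pass an empty SearchFilters(). That this is a no-op follows from the constructor introduced further down: every field defaults to None and no clause is emitted. A quick check, assuming graphiti-core 0.7.1 is installed:

from graphiti_core.search.search_filters import (
    SearchFilters,
    node_search_filter_query_constructor,
)

# An empty filter contributes no WHERE fragment and no parameters.
filter_query, filter_params = node_search_filter_query_constructor(SearchFilters())
assert filter_query == ''
assert filter_params == {}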
@@ -100,6 +100,7 @@ async def search(
             query_vector,
             group_ids,
             config.node_config,
+            search_filter,
             center_node_uuid,
             bfs_origin_node_uuids,
             config.limit,

@@ -233,6 +234,7 @@ async def node_search(
     query_vector: list[float],
     group_ids: list[str] | None,
     config: NodeSearchConfig | None,
+    search_filter: SearchFilters,
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,

@@ -243,11 +245,13 @@ async def node_search(
     search_results: list[list[EntityNode]] = list(
         await semaphore_gather(
             *[
-                node_fulltext_search(driver, query, group_ids, 2 * limit),
+                node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
                 node_similarity_search(
-                    driver, query_vector, group_ids, 2 * limit, config.sim_min_score
+                    driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
                 ),
-                node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
+                node_bfs_search(
+                    driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
+                ),
             ]
         )
     )

@@ -255,7 +259,9 @@ async def node_search(
     if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
         origin_node_uuids = [node.uuid for result in search_results for node in result]
         search_results.append(
-            await node_bfs_search(driver, origin_node_uuids, config.bfs_max_depth, 2 * limit)
+            await node_bfs_search(
+                driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
+            )
         )
 
     search_result_uuids = [[node.uuid for node in result] for result in search_results]
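search_filter is now a required positional parameter of node_search, sitting between config and center_node_uuid, and it fans out to all three retrieval methods (fulltext, similarity, BFS). A hedged call sketch, assuming the two parameters preceding query_vector are the driver and the query string, and that driver, query_vector, and config already exist in an async context:

# Sketch only, not from this diff; 'Person' is a hypothetical label.
results = await node_search(
    driver,
    'acme contract',
    query_vector,
    None,  # group_ids: no group restriction
    config,
    SearchFilters(node_labels=['Person']),
)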
@@ -39,18 +39,37 @@ class DateFilter(BaseModel):
 
 
 class SearchFilters(BaseModel):
+    node_labels: list[str] | None = Field(
+        default=None, description='List of node labels to filter on'
+    )
     valid_at: list[list[DateFilter]] | None = Field(default=None)
     invalid_at: list[list[DateFilter]] | None = Field(default=None)
     created_at: list[list[DateFilter]] | None = Field(default=None)
     expired_at: list[list[DateFilter]] | None = Field(default=None)
 
 
-def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
+def node_search_filter_query_constructor(
+    filters: SearchFilters,
+) -> tuple[LiteralString, dict[str, Any]]:
+    filter_query: LiteralString = ''
+    filter_params: dict[str, Any] = {}
+
+    if filters.node_labels is not None:
+        node_labels = ':'.join(filters.node_labels)
+        node_label_filter = ' AND n:' + node_labels
+        filter_query += node_label_filter
+
+    return filter_query, filter_params
+
+
+def edge_search_filter_query_constructor(
+    filters: SearchFilters,
+) -> tuple[LiteralString, dict[str, Any]]:
     filter_query: LiteralString = ''
     filter_params: dict[str, Any] = {}
 
     if filters.valid_at is not None:
-        valid_at_filter = 'AND ('
+        valid_at_filter = ' AND ('
         for i, or_list in enumerate(filters.valid_at):
             for j, date_filter in enumerate(or_list):
                 filter_params['valid_at_' + str(j)] = date_filter.date

@@ -75,7 +94,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
         filter_query += valid_at_filter
 
     if filters.invalid_at is not None:
-        invalid_at_filter = 'AND ('
+        invalid_at_filter = ' AND ('
         for i, or_list in enumerate(filters.invalid_at):
             for j, date_filter in enumerate(or_list):
                 filter_params['invalid_at_' + str(j)] = date_filter.date

@@ -100,7 +119,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
         filter_query += invalid_at_filter
 
     if filters.created_at is not None:
-        created_at_filter = 'AND ('
+        created_at_filter = ' AND ('
         for i, or_list in enumerate(filters.created_at):
             for j, date_filter in enumerate(or_list):
                 filter_params['created_at_' + str(j)] = date_filter.date
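node_search_filter_query_constructor joins the labels with ':' and appends a single predicate. Cypher's n:A:B syntax matches nodes carrying all of the listed labels, so the list is a conjunction, not an alternative. A worked example with hypothetical labels:

from graphiti_core.search.search_filters import (
    SearchFilters,
    node_search_filter_query_constructor,
)

filters = SearchFilters(node_labels=['Person', 'Company'])
filter_query, filter_params = node_search_filter_query_constructor(filters)
print(filter_query)   # ' AND n:Person:Company'
print(filter_params)  # {} (the label filter needs no query parameters)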
@@ -38,7 +38,11 @@ from graphiti_core.nodes import (
     get_community_node_from_record,
     get_entity_node_from_record,
 )
-from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
+from graphiti_core.search.search_filters import (
+    SearchFilters,
+    edge_search_filter_query_constructor,
+    node_search_filter_query_constructor,
+)
 
 logger = logging.getLogger(__name__)
 

@@ -148,7 +152,7 @@ async def edge_fulltext_search(
     if fuzzy_query == '':
         return []
 
-    filter_query, filter_params = search_filter_query_constructor(search_filter)
+    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
 
     cypher_query = Query(
         """

@@ -207,7 +211,7 @@ async def edge_similarity_search(
 
     query_params: dict[str, Any] = {}
 
-    filter_query, filter_params = search_filter_query_constructor(search_filter)
+    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
     query_params.update(filter_params)
 
     group_filter_query: LiteralString = ''

@@ -225,8 +229,8 @@ async def edge_similarity_search(
 
     query: LiteralString = (
         """
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        """
+            MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+            """
         + group_filter_query
         + filter_query
         + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score

@@ -278,7 +282,7 @@ async def edge_bfs_search(
     if bfs_origin_node_uuids is None:
         return []
 
-    filter_query, filter_params = search_filter_query_constructor(search_filter)
+    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
 
     query = Query(
         """

@@ -325,6 +329,7 @@ async def edge_bfs_search(
 async def node_fulltext_search(
     driver: AsyncDriver,
     query: str,
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:

@@ -333,10 +338,17 @@ async def node_fulltext_search(
     if fuzzy_query == '':
         return []
 
+    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+
     records, _, _ = await driver.execute_query(
         """
         CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-        YIELD node AS n, score
+        YIELD node AS node, score
+        MATCH (n:Entity)
+        WHERE n.uuid = node.uuid
+        """
+        + filter_query
+        + """
         RETURN
             n.uuid AS uuid,
             n.group_id AS group_id,
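Note why the fulltext query is reshaped: the index hit is re-bound through MATCH (n:Entity) WHERE n.uuid = node.uuid so that the appended filter_query, which always refers to n, has something to attach to. With the example filter from above, the effective predicate becomes:

WHERE n.uuid = node.uuid AND n:Person:Company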
@@ -349,6 +361,7 @@ async def node_fulltext_search(
         ORDER BY score DESC
         LIMIT $limit
         """,
+        filter_params,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,

@@ -363,6 +376,7 @@ async def node_fulltext_search(
 async def node_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
     min_score: float = DEFAULT_MIN_SCORE,

@@ -379,12 +393,16 @@ async def node_similarity_search(
         group_filter_query += 'WHERE n.group_id IN $group_ids'
         query_params['group_ids'] = group_ids
 
+    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+    query_params.update(filter_params)
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (n:Entity)
         """
         + group_filter_query
+        + filter_query
         + """
         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
         WHERE score > $min_score

@@ -416,6 +434,7 @@ async def node_similarity_search(
 async def node_bfs_search(
     driver: AsyncDriver,
     bfs_origin_node_uuids: list[str] | None,
+    search_filter: SearchFilters,
     bfs_max_depth: int,
     limit: int,
 ) -> list[EntityNode]:

@@ -423,21 +442,28 @@ async def node_bfs_search(
     if bfs_origin_node_uuids is None:
         return []
 
+    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+
     records, _, _ = await driver.execute_query(
         """
         UNWIND $bfs_origin_node_uuids AS origin_uuid
         MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+        WHERE n.group_id = origin.group_id
+        """
+        + filter_query
+        + """
         RETURN DISTINCT
             n.uuid As uuid,
             n.group_id AS group_id,
             n.name AS name,
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
             properties(n) AS attributes
         LIMIT $limit
         """,
+        filter_params,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
         limit=limit,

@@ -539,6 +565,7 @@ async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:

@@ -583,8 +610,14 @@ async def hybrid_node_search(
     start = time()
     results: list[list[EntityNode]] = list(
         await semaphore_gather(
-            *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
-            *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
+            *[
+                node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
+                for q in queries
+            ],
+            *[
+                node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
+                for e in embeddings
+            ],
         )
     )
 

@@ -604,6 +637,7 @@ async def hybrid_node_search(
 
 async def get_relevant_nodes(
     driver: AsyncDriver,
+    search_filter: SearchFilters,
     nodes: list[EntityNode],
 ) -> list[EntityNode]:
     """

@@ -635,6 +669,7 @@ async def get_relevant_nodes(
         [node.name for node in nodes],
         [node.name_embedding for node in nodes if node.name_embedding is not None],
         driver,
+        search_filter,
         [node.group_id for node in nodes],
     )
 
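get_relevant_nodes now takes the filter as its second argument. A hedged usage sketch for an async context, where driver is an AsyncDriver and nodes is a list[EntityNode] already in scope; maintenance paths with no user-supplied filter pass an empty SearchFilters() to keep the old unfiltered behaviour, as the bulk dedupe hunk below shows:

# Empty SearchFilters() preserves the pre-0.7.1 behaviour (no filtering).
relevant = await get_relevant_nodes(driver, SearchFilters(), nodes)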
@@ -37,6 +37,7 @@ from graphiti_core.models.nodes.node_db_queries import (
     EPISODIC_NODE_SAVE_BULK,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
 from graphiti_core.utils.datetime_utils import utc_now
 from graphiti_core.utils.maintenance.edge_operations import (

@@ -188,7 +189,7 @@ async def dedupe_nodes_bulk(
 
     existing_nodes_chunks: list[list[EntityNode]] = list(
         await semaphore_gather(
-            *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
+            *[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
        )
     )
 
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.7.0"
+version = "0.7.1"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",
@@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
 import pytest
 
 from graphiti_core.nodes import EntityNode
+from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import hybrid_node_search
 
 

@@ -13,7 +14,7 @@ async def test_hybrid_node_search_deduplication():
 
     # Mock the node_fulltext_search and entity_similarity_search functions
     with patch(
-        'graphiti_core.search.search_utils.node_fulltext_search'
+        'graphiti_core.search.search_utils.node_fulltext_search'
     ) as mock_fulltext_search, patch(
         'graphiti_core.search.search_utils.node_similarity_search'
     ) as mock_similarity_search:

@@ -30,7 +31,7 @@ async def test_hybrid_node_search_deduplication():
     # Call the function with test data
     queries = ['Alice', 'Bob']
     embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
-    results = await hybrid_node_search(queries, embeddings, mock_driver)
+    results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
 
     # Assertions
     assert len(results) == 3

@@ -47,7 +48,7 @@ async def test_hybrid_node_search_empty_results():
     mock_driver = AsyncMock()
 
     with patch(
-        'graphiti_core.search.search_utils.node_fulltext_search'
+        'graphiti_core.search.search_utils.node_fulltext_search'
     ) as mock_fulltext_search, patch(
         'graphiti_core.search.search_utils.node_similarity_search'
     ) as mock_similarity_search:

@@ -56,7 +57,7 @@ async def test_hybrid_node_search_empty_results():
 
     queries = ['NonExistent']
     embeddings = [[0.1, 0.2, 0.3]]
-    results = await hybrid_node_search(queries, embeddings, mock_driver)
+    results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
 
     assert len(results) == 0
 

@@ -66,7 +67,7 @@ async def test_hybrid_node_search_only_fulltext():
     mock_driver = AsyncMock()
 
     with patch(
-        'graphiti_core.search.search_utils.node_fulltext_search'
+        'graphiti_core.search.search_utils.node_fulltext_search'
     ) as mock_fulltext_search, patch(
         'graphiti_core.search.search_utils.node_similarity_search'
     ) as mock_similarity_search:

@@ -77,7 +78,7 @@ async def test_hybrid_node_search_only_fulltext():
 
     queries = ['Alice']
     embeddings = []
-    results = await hybrid_node_search(queries, embeddings, mock_driver)
+    results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
 
     assert len(results) == 1
     assert results[0].name == 'Alice'
@@ -90,7 +91,7 @@ async def test_hybrid_node_search_with_limit():
     mock_driver = AsyncMock()
 
     with patch(
-        'graphiti_core.search.search_utils.node_fulltext_search'
+        'graphiti_core.search.search_utils.node_fulltext_search'
     ) as mock_fulltext_search, patch(
         'graphiti_core.search.search_utils.node_similarity_search'
     ) as mock_similarity_search:

@@ -111,7 +112,9 @@ async def test_hybrid_node_search_with_limit():
     queries = ['Test']
     embeddings = [[0.1, 0.2, 0.3]]
     limit = 1
-    results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
+    results = await hybrid_node_search(
+        queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
+    )
 
     # We expect 4 results because the limit is applied per search method
     # before deduplication, and we're not actually limiting the results

@@ -120,8 +123,10 @@ async def test_hybrid_node_search_with_limit():
     assert mock_fulltext_search.call_count == 1
     assert mock_similarity_search.call_count == 1
     # Verify that the limit was passed to the search functions
-    mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2)
-    mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2)
+    mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 2)
+    mock_similarity_search.assert_called_with(
+        mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 2
+    )
 
 
 @pytest.mark.asyncio

@@ -129,7 +134,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
     mock_driver = AsyncMock()
 
     with patch(
-        'graphiti_core.search.search_utils.node_fulltext_search'
+        'graphiti_core.search.search_utils.node_fulltext_search'
     ) as mock_fulltext_search, patch(
         'graphiti_core.search.search_utils.node_similarity_search'
     ) as mock_similarity_search:

@@ -145,7 +150,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
     queries = ['Test']
     embeddings = [[0.1, 0.2, 0.3]]
     limit = 2
-    results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
+    results = await hybrid_node_search(
+        queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
+    )
 
     # We expect 3 results because:
     # 1. The limit of 2 is applied to each search method

@@ -155,5 +162,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
     assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
     assert mock_fulltext_search.call_count == 1
     assert mock_similarity_search.call_count == 1
-    mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4)
-    mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4)
+    mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 4)
+    mock_similarity_search.assert_called_with(
+        mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
+    )