Date filters (#240)

* add search filters

* add search filters

* mypy

* mypy

* update filtering

* date-filters

* update

* update filter queries

* update dictionary
This commit is contained in:
Preston Rasmussen 2025-01-28 11:52:53 -05:00 committed by GitHub
parent d3b2cecbe5
commit 6ef2f5e097
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 225 additions and 27 deletions

View file

@ -35,6 +35,7 @@ from graphiti_core.search.search_config_recipes import (
EDGE_HYBRID_SEARCH_NODE_DISTANCE, EDGE_HYBRID_SEARCH_NODE_DISTANCE,
EDGE_HYBRID_SEARCH_RRF, EDGE_HYBRID_SEARCH_RRF,
) )
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import ( from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT, RELEVANT_SCHEMA_LIMIT,
get_communities_by_nodes, get_communities_by_nodes,
@ -625,6 +626,7 @@ class Graphiti:
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT, num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
""" """
Perform a hybrid search on the knowledge graph. Perform a hybrid search on the knowledge graph.
@ -670,6 +672,7 @@ class Graphiti:
query, query,
group_ids, group_ids,
search_config, search_config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid, center_node_uuid,
) )
).edges ).edges
@ -683,6 +686,7 @@ class Graphiti:
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None,
) -> SearchResults: ) -> SearchResults:
return await search( return await search(
self.driver, self.driver,
@ -691,6 +695,7 @@ class Graphiti:
query, query,
group_ids, group_ids,
config, config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids, bfs_origin_node_uuids,
) )

View file

@ -39,6 +39,7 @@ from graphiti_core.search.search_config import (
SearchConfig, SearchConfig,
SearchResults, SearchResults,
) )
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import ( from graphiti_core.search.search_utils import (
community_fulltext_search, community_fulltext_search,
community_similarity_search, community_similarity_search,
@ -64,6 +65,7 @@ async def search(
query: str, query: str,
group_ids: list[str] | None, group_ids: list[str] | None,
config: SearchConfig, config: SearchConfig,
search_filter: SearchFilters,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults: ) -> SearchResults:
@ -86,6 +88,7 @@ async def search(
query_vector, query_vector,
group_ids, group_ids,
config.edge_config, config.edge_config,
search_filter,
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids, bfs_origin_node_uuids,
config.limit, config.limit,
@ -133,6 +136,7 @@ async def edge_search(
query_vector: list[float], query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: EdgeSearchConfig | None, config: EdgeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
@ -143,11 +147,20 @@ async def edge_search(
search_results: list[list[EntityEdge]] = list( search_results: list[list[EntityEdge]] = list(
await semaphore_gather( await semaphore_gather(
*[ *[
edge_fulltext_search(driver, query, group_ids, 2 * limit), edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
edge_similarity_search( edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score driver,
query_vector,
None,
None,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
),
edge_bfs_search(
driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
), ),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
] ]
) )
) )
@ -155,7 +168,9 @@ async def edge_search(
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result] source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
search_results.append( search_results.append(
await edge_bfs_search(driver, source_node_uuids, config.bfs_max_depth, 2 * limit) await edge_bfs_search(
driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
)
) )
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result} edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

View file

@ -0,0 +1,152 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
class ComparisonOperator(Enum):
    """Comparison operators selectable for a date filter.

    Each member's value is the operator text that
    ``search_filter_query_constructor`` splices into the filter query.
    """

    # Equality / inequality
    equals = '='
    not_equals = '<>'

    # Strict ordering
    greater_than = '>'
    less_than = '<'

    # Inclusive ordering
    greater_than_equal = '>='
    less_than_equal = '<='
class DateFilter(BaseModel):
    # One date comparison: the datetime operand plus the operator used to
    # compare an edge's date property against it.

    # Right-hand operand of the comparison.
    date: datetime = Field(description='A datetime to filter on')

    # Operator placed between the edge property and the operand.
    comparison_operator: ComparisonOperator = Field(description='Comparison operator for date filter')
class SearchFilters(BaseModel):
    # Optional date filters, one field per edge date property. Each field is a
    # list of groups of DateFilter; None (the default) means "no filter on
    # this property".

    valid_at: list[list[DateFilter]] | None = None
    invalid_at: list[list[DateFilter]] | None = None
    created_at: list[list[DateFilter]] | None = None
    expired_at: list[list[DateFilter]] | None = None
def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = ''
filter_params: dict[str, Any] = {}
if filters.valid_at is not None:
valid_at_filter = 'AND ('
for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list):
filter_params['valid_at_' + str(j)] = date_filter.date
and_filters = [
'(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
valid_at_filter += and_filter_query
if i == len(or_list) - 1:
valid_at_filter += ')'
else:
valid_at_filter += ' OR '
filter_query += valid_at_filter
if filters.invalid_at is not None:
invalid_at_filter = 'AND ('
for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list):
filter_params['invalid_at_' + str(j)] = date_filter.date
and_filters = [
'(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
invalid_at_filter += and_filter_query
if i == len(or_list) - 1:
invalid_at_filter += ')'
else:
invalid_at_filter += ' OR '
filter_query += invalid_at_filter
if filters.created_at is not None:
created_at_filter = 'AND ('
for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list):
filter_params['created_at_' + str(j)] = date_filter.date
and_filters = [
'(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
created_at_filter += and_filter_query
if i == len(or_list) - 1:
created_at_filter += ')'
else:
created_at_filter += ' OR '
filter_query += created_at_filter
if filters.expired_at is not None:
expired_at_filter = 'AND ('
for i, or_list in enumerate(filters.expired_at):
for j, date_filter in enumerate(or_list):
filter_params['expired_at_' + str(j)] = date_filter.date
and_filters = [
'(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filter_query) - 1:
and_filter_query += ' AND '
expired_at_filter += and_filter_query
if i == len(or_list) - 1:
expired_at_filter += ')'
else:
expired_at_filter += ' OR '
filter_query += expired_at_filter
return filter_query, filter_params

View file

@ -38,6 +38,7 @@ from graphiti_core.nodes import (
get_community_node_from_record, get_community_node_from_record,
get_entity_node_from_record, get_entity_node_from_record,
) )
from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -136,6 +137,7 @@ async def get_communities_by_nodes(
async def edge_fulltext_search( async def edge_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
@ -144,28 +146,36 @@ async def edge_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
cypher_query = Query(""" filter_query, filter_params = search_filter_query_constructor(search_filter)
cypher_query = Query(
"""
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit}) CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
YIELD relationship AS r, score YIELD relationship AS rel, score
WITH r, score, startNode(r) AS n, endNode(r) AS m MATCH (:ENTITY)-[r:RELATES_TO]->(:ENTITY)
RETURN WHERE r.group_id IN $group_ids"""
r.uuid AS uuid, + filter_query
r.group_id AS group_id, + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
n.uuid AS source_node_uuid, RETURN
m.uuid AS target_node_uuid, r.uuid AS uuid,
r.created_at AS created_at, r.group_id AS group_id,
r.name AS name, n.uuid AS source_node_uuid,
r.fact AS fact, m.uuid AS target_node_uuid,
r.fact_embedding AS fact_embedding, r.created_at AS created_at,
r.episodes AS episodes, r.name AS name,
r.expired_at AS expired_at, r.fact AS fact,
r.valid_at AS valid_at, r.fact_embedding AS fact_embedding,
r.invalid_at AS invalid_at r.episodes AS episodes,
ORDER BY score DESC LIMIT $limit r.expired_at AS expired_at,
""") r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
"""
)
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
cypher_query, cypher_query,
filter_params,
query=fuzzy_query, query=fuzzy_query,
group_ids=group_ids, group_ids=group_ids,
limit=limit, limit=limit,
@ -183,6 +193,7 @@ async def edge_similarity_search(
search_vector: list[float], search_vector: list[float],
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
search_filter: SearchFilters,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE, min_score: float = DEFAULT_MIN_SCORE,
@ -194,6 +205,9 @@ async def edge_similarity_search(
query_params: dict[str, Any] = {} query_params: dict[str, Any] = {}
filter_query, filter_params = search_filter_query_constructor(search_filter)
query_params.update(filter_params)
group_filter_query: LiteralString = '' group_filter_query: LiteralString = ''
if group_ids is not None: if group_ids is not None:
group_filter_query += 'WHERE r.group_id IN $group_ids' group_filter_query += 'WHERE r.group_id IN $group_ids'
@ -209,9 +223,10 @@ async def edge_similarity_search(
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
""" """
+ group_filter_query + group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score WHERE score > $min_score
RETURN RETURN
@ -254,17 +269,25 @@ async def edge_bfs_search(
driver: AsyncDriver, driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None, bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int, bfs_max_depth: int,
search_filter: SearchFilters,
limit: int, limit: int,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# vector similarity search over embedded facts # vector similarity search over embedded facts
if bfs_origin_node_uuids is None: if bfs_origin_node_uuids is None:
return [] return []
query = Query(""" filter_query, filter_params = search_filter_query_constructor(search_filter)
query = Query(
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel UNWIND relationships(path) AS rel
MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-() MATCH ()-[r:RELATES_TO]-()
WHERE r.uuid = rel.uuid
"""
+ filter_query
+ """
RETURN DISTINCT RETURN DISTINCT
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -279,10 +302,12 @@ async def edge_bfs_search(
r.valid_at AS valid_at, r.valid_at AS valid_at,
r.invalid_at AS invalid_at r.invalid_at AS invalid_at
LIMIT $limit LIMIT $limit
""") """
)
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
query, query,
filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids, bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth, depth=bfs_max_depth,
limit=limit, limit=limit,
@ -626,6 +651,7 @@ async def get_relevant_edges(
edge.fact_embedding, edge.fact_embedding,
source_node_uuid, source_node_uuid,
target_node_uuid, target_node_uuid,
SearchFilters(),
[edge.group_id], [edge.group_id],
limit, limit,
) )