diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a32d07c3..88044227 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,6 +36,9 @@ Join our [Discord server](https://discord.com/invite/W8Kw6bsgXQ) community and p ## What happens next? +### Notes for Large Changes +> Please keep the changes as concise as possible. For major architectural changes (>500 LOC), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with `rfc-required`, and the PR may not be reviewed until an RFC is provided. + Once you've found an issue tagged with "good first issue" or "help wanted," or prepared an example to share, here's how to turn that into a contribution: 1. Share your approach in the issue discussion or [Discord](https://discord.com/invite/W8Kw6bsgXQ) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework. diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 3e875073..bbe80d5b 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -157,7 +157,7 @@ class Graphiti: If not set, the Graphiti default is used. ensure_ascii : bool, optional Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False. - Set to False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their + Set to False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their original form, making them readable in LLM logs and improving model understanding. 
Returns diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index a4e3b248..387c6183 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -178,31 +178,42 @@ async def edge_search( ) -> tuple[list[EntityEdge], list[float]]: if config is None: return [], [] - search_results: list[list[EntityEdge]] = list( - await semaphore_gather( - *[ - edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), - edge_similarity_search( - driver, - query_vector, - None, - None, - search_filter, - group_ids, - 2 * limit, - config.sim_min_score, - ), - edge_bfs_search( - driver, - bfs_origin_node_uuids, - config.bfs_max_depth, - search_filter, - group_ids, - 2 * limit, - ), - ] + + # Build search tasks based on configured search methods + search_tasks = [] + if EdgeSearchMethod.bm25 in config.search_methods: + search_tasks.append( + edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit) ) - ) + if EdgeSearchMethod.cosine_similarity in config.search_methods: + search_tasks.append( + edge_similarity_search( + driver, + query_vector, + None, + None, + search_filter, + group_ids, + 2 * limit, + config.sim_min_score, + ) + ) + if EdgeSearchMethod.bfs in config.search_methods: + search_tasks.append( + edge_bfs_search( + driver, + bfs_origin_node_uuids, + config.bfs_max_depth, + search_filter, + group_ids, + 2 * limit, + ) + ) + + # Execute only the configured search methods + search_results: list[list[EntityEdge]] = [] + if search_tasks: + search_results = list(await semaphore_gather(*search_tasks)) if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result] @@ -290,24 +301,35 @@ async def node_search( ) -> tuple[list[EntityNode], list[float]]: if config is None: return [], [] - search_results: list[list[EntityNode]] = list( - await semaphore_gather( - *[ - node_fulltext_search(driver, 
query, search_filter, group_ids, 2 * limit), - node_similarity_search( - driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score - ), - node_bfs_search( - driver, - bfs_origin_node_uuids, - search_filter, - config.bfs_max_depth, - group_ids, - 2 * limit, - ), - ] + + # Build search tasks based on configured search methods + search_tasks = [] + if NodeSearchMethod.bm25 in config.search_methods: + search_tasks.append( + node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit) ) - ) + if NodeSearchMethod.cosine_similarity in config.search_methods: + search_tasks.append( + node_similarity_search( + driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score + ) + ) + if NodeSearchMethod.bfs in config.search_methods: + search_tasks.append( + node_bfs_search( + driver, + bfs_origin_node_uuids, + search_filter, + config.bfs_max_depth, + group_ids, + 2 * limit, + ) + ) + + # Execute only the configured search methods + search_results: list[list[EntityNode]] = [] + if search_tasks: + search_results = list(await semaphore_gather(*search_tasks)) if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: origin_node_uuids = [node.uuid for result in search_results for node in result] diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 3a5d2f21..2213688b 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -28,6 +28,8 @@ class ComparisonOperator(Enum): less_than = '<' greater_than_equal = '>=' less_than_equal = '<=' + is_null = 'IS NULL' + is_not_null = 'IS NOT NULL' class DateFilter(BaseModel): @@ -64,6 +66,19 @@ def node_search_filter_query_constructor( return filter_query, filter_params +def date_filter_query_constructor( + value_name: str, param_name: str, operator: ComparisonOperator +) -> str: + query = '(' + value_name + ' ' + + if operator == ComparisonOperator.is_null or operator == 
ComparisonOperator.is_not_null: + query += operator.value + ')' + else: + query += operator.value + ' ' + param_name + ')' + + return query + + def edge_search_filter_query_constructor( filters: SearchFilters, ) -> tuple[str, dict[str, Any]]: @@ -85,10 +100,16 @@ def edge_search_filter_query_constructor( valid_at_filter = '\nAND (' for i, or_list in enumerate(filters.valid_at): for j, date_filter in enumerate(or_list): - filter_params['valid_at_' + str(j)] = date_filter.date + if date_filter.comparison_operator not in [ + ComparisonOperator.is_null, + ComparisonOperator.is_not_null, + ]: + filter_params['valid_at_' + str(j)] = date_filter.date and_filters = [ - '(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})' + date_filter_query_constructor( + 'e.valid_at', f'$valid_at_{j}', date_filter.comparison_operator + ) for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -110,10 +131,16 @@ def edge_search_filter_query_constructor( invalid_at_filter = ' AND (' for i, or_list in enumerate(filters.invalid_at): for j, date_filter in enumerate(or_list): - filter_params['invalid_at_' + str(j)] = date_filter.date + if date_filter.comparison_operator not in [ + ComparisonOperator.is_null, + ComparisonOperator.is_not_null, + ]: + filter_params['invalid_at_' + str(j)] = date_filter.date and_filters = [ - '(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})' + date_filter_query_constructor( + 'e.invalid_at', f'$invalid_at_{j}', date_filter.comparison_operator + ) for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -135,10 +162,16 @@ def edge_search_filter_query_constructor( created_at_filter = ' AND (' for i, or_list in enumerate(filters.created_at): for j, date_filter in enumerate(or_list): - filter_params['created_at_' + str(j)] = date_filter.date + if date_filter.comparison_operator not in [ + ComparisonOperator.is_null, + ComparisonOperator.is_not_null, + ]: + filter_params['created_at_' + str(j)] 
= date_filter.date and_filters = [ - '(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})' + date_filter_query_constructor( + 'e.created_at', f'$created_at_{j}', date_filter.comparison_operator + ) for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -160,10 +193,16 @@ def edge_search_filter_query_constructor( expired_at_filter = ' AND (' for i, or_list in enumerate(filters.expired_at): for j, date_filter in enumerate(or_list): - filter_params['expired_at_' + str(j)] = date_filter.date + if date_filter.comparison_operator not in [ + ComparisonOperator.is_null, + ComparisonOperator.is_not_null, + ]: + filter_params['expired_at_' + str(j)] = date_filter.date and_filters = [ - '(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})' + date_filter_query_constructor( + 'e.expired_at', f'$expired_at_{j}', date_filter.comparison_operator + ) for j, date_filter in enumerate(or_list) ] and_filter_query = '' diff --git a/pyproject.toml b/pyproject.toml index d0b864ae..cf476ef6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.18.5" +version = "0.18.6" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 165464ce..276191d0 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -65,7 +65,11 @@ async def test_graphiti_init(driver): search_filter = SearchFilters( node_labels=['Person', 'City'], - created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]], + created_at=[ + [DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)], + [DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)], + [DateFilter(date=None, comparison_operator=ComparisonOperator.is_not_null)], + 
], ) results = await graphiti.search_( diff --git a/uv.lock b/uv.lock index 0f736407..00f70e1b 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.18.5" +version = "0.18.6" source = { editable = "." } dependencies = [ { name = "diskcache" },