Merge branch 'main' into graphid-isolation

This commit is contained in:
Gal Shubeli 2025-08-14 15:57:02 +03:00 committed by GitHub
commit 21057a16e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 121 additions and 53 deletions

View file

@ -36,6 +36,9 @@ Join our [Discord server](https://discord.com/invite/W8Kw6bsgXQ) community and p
## What happens next? ## What happens next?
### Notes for Large Changes
> Please keep the changes as concise as possible. For major architectural changes (>500 LOC), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.
Once you've found an issue tagged with "good first issue" or "help wanted," or prepared an example to share, here's how to turn that into a contribution: Once you've found an issue tagged with "good first issue" or "help wanted," or prepared an example to share, here's how to turn that into a contribution:
1. Share your approach in the issue discussion or [Discord](https://discord.com/invite/W8Kw6bsgXQ) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework. 1. Share your approach in the issue discussion or [Discord](https://discord.com/invite/W8Kw6bsgXQ) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework.

View file

@ -157,7 +157,7 @@ class Graphiti:
If not set, the Graphiti default is used. If not set, the Graphiti default is used.
ensure_ascii : bool, optional ensure_ascii : bool, optional
Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False. Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False.
Set to False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their Set as False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their
original form, making them readable in LLM logs and improving model understanding. original form, making them readable in LLM logs and improving model understanding.
Returns Returns

View file

@ -178,31 +178,42 @@ async def edge_search(
) -> tuple[list[EntityEdge], list[float]]: ) -> tuple[list[EntityEdge], list[float]]:
if config is None: if config is None:
return [], [] return [], []
search_results: list[list[EntityEdge]] = list(
await semaphore_gather( # Build search tasks based on configured search methods
*[ search_tasks = []
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), if EdgeSearchMethod.bm25 in config.search_methods:
edge_similarity_search( search_tasks.append(
driver, edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
query_vector,
None,
None,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
),
edge_bfs_search(
driver,
bfs_origin_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
),
]
) )
) if EdgeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
edge_similarity_search(
driver,
query_vector,
None,
None,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
)
)
if EdgeSearchMethod.bfs in config.search_methods:
search_tasks.append(
edge_bfs_search(
driver,
bfs_origin_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
)
)
# Execute only the configured search methods
search_results: list[list[EntityEdge]] = []
if search_tasks:
search_results = list(await semaphore_gather(*search_tasks))
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result] source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
@ -290,24 +301,35 @@ async def node_search(
) -> tuple[list[EntityNode], list[float]]: ) -> tuple[list[EntityNode], list[float]]:
if config is None: if config is None:
return [], [] return [], []
search_results: list[list[EntityNode]] = list(
await semaphore_gather( # Build search tasks based on configured search methods
*[ search_tasks = []
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), if NodeSearchMethod.bm25 in config.search_methods:
node_similarity_search( search_tasks.append(
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
),
node_bfs_search(
driver,
bfs_origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
),
]
) )
) if NodeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
node_similarity_search(
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
)
)
if NodeSearchMethod.bfs in config.search_methods:
search_tasks.append(
node_bfs_search(
driver,
bfs_origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
)
)
# Execute only the configured search methods
search_results: list[list[EntityNode]] = []
if search_tasks:
search_results = list(await semaphore_gather(*search_tasks))
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
origin_node_uuids = [node.uuid for result in search_results for node in result] origin_node_uuids = [node.uuid for result in search_results for node in result]

View file

@ -28,6 +28,8 @@ class ComparisonOperator(Enum):
less_than = '<' less_than = '<'
greater_than_equal = '>=' greater_than_equal = '>='
less_than_equal = '<=' less_than_equal = '<='
is_null = 'IS NULL'
is_not_null = 'IS NOT NULL'
class DateFilter(BaseModel): class DateFilter(BaseModel):
@ -64,6 +66,19 @@ def node_search_filter_query_constructor(
return filter_query, filter_params return filter_query, filter_params
def date_filter_query_constructor(
value_name: str, param_name: str, operator: ComparisonOperator
) -> str:
query = '(' + value_name + ' '
if operator == ComparisonOperator.is_null or operator == ComparisonOperator.is_not_null:
query += operator.value + ')'
else:
query += operator.value + ' ' + param_name + ')'
return query
def edge_search_filter_query_constructor( def edge_search_filter_query_constructor(
filters: SearchFilters, filters: SearchFilters,
) -> tuple[str, dict[str, Any]]: ) -> tuple[str, dict[str, Any]]:
@ -85,10 +100,16 @@ def edge_search_filter_query_constructor(
valid_at_filter = '\nAND (' valid_at_filter = '\nAND ('
for i, or_list in enumerate(filters.valid_at): for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list): for j, date_filter in enumerate(or_list):
filter_params['valid_at_' + str(j)] = date_filter.date if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['valid_at_' + str(j)] = date_filter.date
and_filters = [ and_filters = [
'(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})' date_filter_query_constructor(
'e.valid_at', f'$valid_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list) for j, date_filter in enumerate(or_list)
] ]
and_filter_query = '' and_filter_query = ''
@ -110,10 +131,16 @@ def edge_search_filter_query_constructor(
invalid_at_filter = ' AND (' invalid_at_filter = ' AND ('
for i, or_list in enumerate(filters.invalid_at): for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list): for j, date_filter in enumerate(or_list):
filter_params['invalid_at_' + str(j)] = date_filter.date if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['invalid_at_' + str(j)] = date_filter.date
and_filters = [ and_filters = [
'(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})' date_filter_query_constructor(
'e.invalid_at', f'$invalid_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list) for j, date_filter in enumerate(or_list)
] ]
and_filter_query = '' and_filter_query = ''
@ -135,10 +162,16 @@ def edge_search_filter_query_constructor(
created_at_filter = ' AND (' created_at_filter = ' AND ('
for i, or_list in enumerate(filters.created_at): for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list): for j, date_filter in enumerate(or_list):
filter_params['created_at_' + str(j)] = date_filter.date if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['created_at_' + str(j)] = date_filter.date
and_filters = [ and_filters = [
'(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})' date_filter_query_constructor(
'e.created_at', f'$created_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list) for j, date_filter in enumerate(or_list)
] ]
and_filter_query = '' and_filter_query = ''
@ -160,10 +193,16 @@ def edge_search_filter_query_constructor(
expired_at_filter = ' AND (' expired_at_filter = ' AND ('
for i, or_list in enumerate(filters.expired_at): for i, or_list in enumerate(filters.expired_at):
for j, date_filter in enumerate(or_list): for j, date_filter in enumerate(or_list):
filter_params['expired_at_' + str(j)] = date_filter.date if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['expired_at_' + str(j)] = date_filter.date
and_filters = [ and_filters = [
'(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})' date_filter_query_constructor(
'e.expired_at', f'$expired_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list) for j, date_filter in enumerate(or_list)
] ]
and_filter_query = '' and_filter_query = ''

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.18.5" version = "0.18.6"
authors = [ authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" },

View file

@ -65,7 +65,11 @@ async def test_graphiti_init(driver):
search_filter = SearchFilters( search_filter = SearchFilters(
node_labels=['Person', 'City'], node_labels=['Person', 'City'],
created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]], created_at=[
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)],
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_not_null)],
],
) )
results = await graphiti.search_( results = await graphiti.search_(

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.18.5" version = "0.18.6"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },