Merge branch 'main' into graphid-isolation
This commit is contained in:
commit
21057a16e3
7 changed files with 121 additions and 53 deletions
|
|
@ -36,6 +36,9 @@ Join our [Discord server](https://discord.com/invite/W8Kw6bsgXQ) community and p
|
||||||
|
|
||||||
## What happens next?
|
## What happens next?
|
||||||
|
|
||||||
|
### Notes for Large Changes
|
||||||
|
> Please keep the changes as concise as possible. For major architectural changes (>500 LOC), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.
|
||||||
|
|
||||||
Once you've found an issue tagged with "good first issue" or "help wanted," or prepared an example to share, here's how to turn that into a contribution:
|
Once you've found an issue tagged with "good first issue" or "help wanted," or prepared an example to share, here's how to turn that into a contribution:
|
||||||
|
|
||||||
1. Share your approach in the issue discussion or [Discord](https://discord.com/invite/W8Kw6bsgXQ) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework.
|
1. Share your approach in the issue discussion or [Discord](https://discord.com/invite/W8Kw6bsgXQ) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework.
|
||||||
|
|
|
||||||
|
|
@ -157,7 +157,7 @@ class Graphiti:
|
||||||
If not set, the Graphiti default is used.
|
If not set, the Graphiti default is used.
|
||||||
ensure_ascii : bool, optional
|
ensure_ascii : bool, optional
|
||||||
Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False.
|
Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False.
|
||||||
Set to False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their
|
Set as False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their
|
||||||
original form, making them readable in LLM logs and improving model understanding.
|
original form, making them readable in LLM logs and improving model understanding.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
|
|
||||||
|
|
@ -178,31 +178,42 @@ async def edge_search(
|
||||||
) -> tuple[list[EntityEdge], list[float]]:
|
) -> tuple[list[EntityEdge], list[float]]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return [], []
|
return [], []
|
||||||
search_results: list[list[EntityEdge]] = list(
|
|
||||||
await semaphore_gather(
|
# Build search tasks based on configured search methods
|
||||||
*[
|
search_tasks = []
|
||||||
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
if EdgeSearchMethod.bm25 in config.search_methods:
|
||||||
edge_similarity_search(
|
search_tasks.append(
|
||||||
driver,
|
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
||||||
query_vector,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
search_filter,
|
|
||||||
group_ids,
|
|
||||||
2 * limit,
|
|
||||||
config.sim_min_score,
|
|
||||||
),
|
|
||||||
edge_bfs_search(
|
|
||||||
driver,
|
|
||||||
bfs_origin_node_uuids,
|
|
||||||
config.bfs_max_depth,
|
|
||||||
search_filter,
|
|
||||||
group_ids,
|
|
||||||
2 * limit,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
)
|
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
|
search_tasks.append(
|
||||||
|
edge_similarity_search(
|
||||||
|
driver,
|
||||||
|
query_vector,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
search_filter,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
|
config.sim_min_score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if EdgeSearchMethod.bfs in config.search_methods:
|
||||||
|
search_tasks.append(
|
||||||
|
edge_bfs_search(
|
||||||
|
driver,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
|
config.bfs_max_depth,
|
||||||
|
search_filter,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute only the configured search methods
|
||||||
|
search_results: list[list[EntityEdge]] = []
|
||||||
|
if search_tasks:
|
||||||
|
search_results = list(await semaphore_gather(*search_tasks))
|
||||||
|
|
||||||
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
||||||
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
||||||
|
|
@ -290,24 +301,35 @@ async def node_search(
|
||||||
) -> tuple[list[EntityNode], list[float]]:
|
) -> tuple[list[EntityNode], list[float]]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return [], []
|
return [], []
|
||||||
search_results: list[list[EntityNode]] = list(
|
|
||||||
await semaphore_gather(
|
# Build search tasks based on configured search methods
|
||||||
*[
|
search_tasks = []
|
||||||
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
if NodeSearchMethod.bm25 in config.search_methods:
|
||||||
node_similarity_search(
|
search_tasks.append(
|
||||||
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
||||||
),
|
|
||||||
node_bfs_search(
|
|
||||||
driver,
|
|
||||||
bfs_origin_node_uuids,
|
|
||||||
search_filter,
|
|
||||||
config.bfs_max_depth,
|
|
||||||
group_ids,
|
|
||||||
2 * limit,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
)
|
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
|
search_tasks.append(
|
||||||
|
node_similarity_search(
|
||||||
|
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if NodeSearchMethod.bfs in config.search_methods:
|
||||||
|
search_tasks.append(
|
||||||
|
node_bfs_search(
|
||||||
|
driver,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
|
search_filter,
|
||||||
|
config.bfs_max_depth,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute only the configured search methods
|
||||||
|
search_results: list[list[EntityNode]] = []
|
||||||
|
if search_tasks:
|
||||||
|
search_results = list(await semaphore_gather(*search_tasks))
|
||||||
|
|
||||||
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
||||||
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,8 @@ class ComparisonOperator(Enum):
|
||||||
less_than = '<'
|
less_than = '<'
|
||||||
greater_than_equal = '>='
|
greater_than_equal = '>='
|
||||||
less_than_equal = '<='
|
less_than_equal = '<='
|
||||||
|
is_null = 'IS NULL'
|
||||||
|
is_not_null = 'IS NOT NULL'
|
||||||
|
|
||||||
|
|
||||||
class DateFilter(BaseModel):
|
class DateFilter(BaseModel):
|
||||||
|
|
@ -64,6 +66,19 @@ def node_search_filter_query_constructor(
|
||||||
return filter_query, filter_params
|
return filter_query, filter_params
|
||||||
|
|
||||||
|
|
||||||
|
def date_filter_query_constructor(
|
||||||
|
value_name: str, param_name: str, operator: ComparisonOperator
|
||||||
|
) -> str:
|
||||||
|
query = '(' + value_name + ' '
|
||||||
|
|
||||||
|
if operator == ComparisonOperator.is_null or operator == ComparisonOperator.is_not_null:
|
||||||
|
query += operator.value + ')'
|
||||||
|
else:
|
||||||
|
query += operator.value + ' ' + param_name + ')'
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
def edge_search_filter_query_constructor(
|
def edge_search_filter_query_constructor(
|
||||||
filters: SearchFilters,
|
filters: SearchFilters,
|
||||||
) -> tuple[str, dict[str, Any]]:
|
) -> tuple[str, dict[str, Any]]:
|
||||||
|
|
@ -85,10 +100,16 @@ def edge_search_filter_query_constructor(
|
||||||
valid_at_filter = '\nAND ('
|
valid_at_filter = '\nAND ('
|
||||||
for i, or_list in enumerate(filters.valid_at):
|
for i, or_list in enumerate(filters.valid_at):
|
||||||
for j, date_filter in enumerate(or_list):
|
for j, date_filter in enumerate(or_list):
|
||||||
filter_params['valid_at_' + str(j)] = date_filter.date
|
if date_filter.comparison_operator not in [
|
||||||
|
ComparisonOperator.is_null,
|
||||||
|
ComparisonOperator.is_not_null,
|
||||||
|
]:
|
||||||
|
filter_params['valid_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
|
date_filter_query_constructor(
|
||||||
|
'e.valid_at', f'$valid_at_{j}', date_filter.comparison_operator
|
||||||
|
)
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
@ -110,10 +131,16 @@ def edge_search_filter_query_constructor(
|
||||||
invalid_at_filter = ' AND ('
|
invalid_at_filter = ' AND ('
|
||||||
for i, or_list in enumerate(filters.invalid_at):
|
for i, or_list in enumerate(filters.invalid_at):
|
||||||
for j, date_filter in enumerate(or_list):
|
for j, date_filter in enumerate(or_list):
|
||||||
filter_params['invalid_at_' + str(j)] = date_filter.date
|
if date_filter.comparison_operator not in [
|
||||||
|
ComparisonOperator.is_null,
|
||||||
|
ComparisonOperator.is_not_null,
|
||||||
|
]:
|
||||||
|
filter_params['invalid_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
|
date_filter_query_constructor(
|
||||||
|
'e.invalid_at', f'$invalid_at_{j}', date_filter.comparison_operator
|
||||||
|
)
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
@ -135,10 +162,16 @@ def edge_search_filter_query_constructor(
|
||||||
created_at_filter = ' AND ('
|
created_at_filter = ' AND ('
|
||||||
for i, or_list in enumerate(filters.created_at):
|
for i, or_list in enumerate(filters.created_at):
|
||||||
for j, date_filter in enumerate(or_list):
|
for j, date_filter in enumerate(or_list):
|
||||||
filter_params['created_at_' + str(j)] = date_filter.date
|
if date_filter.comparison_operator not in [
|
||||||
|
ComparisonOperator.is_null,
|
||||||
|
ComparisonOperator.is_not_null,
|
||||||
|
]:
|
||||||
|
filter_params['created_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
|
date_filter_query_constructor(
|
||||||
|
'e.created_at', f'$created_at_{j}', date_filter.comparison_operator
|
||||||
|
)
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
@ -160,10 +193,16 @@ def edge_search_filter_query_constructor(
|
||||||
expired_at_filter = ' AND ('
|
expired_at_filter = ' AND ('
|
||||||
for i, or_list in enumerate(filters.expired_at):
|
for i, or_list in enumerate(filters.expired_at):
|
||||||
for j, date_filter in enumerate(or_list):
|
for j, date_filter in enumerate(or_list):
|
||||||
filter_params['expired_at_' + str(j)] = date_filter.date
|
if date_filter.comparison_operator not in [
|
||||||
|
ComparisonOperator.is_null,
|
||||||
|
ComparisonOperator.is_not_null,
|
||||||
|
]:
|
||||||
|
filter_params['expired_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
|
date_filter_query_constructor(
|
||||||
|
'e.expired_at', f'$expired_at_{j}', date_filter.comparison_operator
|
||||||
|
)
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.18.5"
|
version = "0.18.6"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,11 @@ async def test_graphiti_init(driver):
|
||||||
|
|
||||||
search_filter = SearchFilters(
|
search_filter = SearchFilters(
|
||||||
node_labels=['Person', 'City'],
|
node_labels=['Person', 'City'],
|
||||||
created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]],
|
created_at=[
|
||||||
|
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
|
||||||
|
[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)],
|
||||||
|
[DateFilter(date=None, comparison_operator=ComparisonOperator.is_not_null)],
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await graphiti.search_(
|
results = await graphiti.search_(
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -746,7 +746,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.18.5"
|
version = "0.18.6"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue