From cbc194634638de1f389434aaacf7f8cc301cf5ff Mon Sep 17 00:00:00 2001 From: Aidan Petti Date: Mon, 24 Nov 2025 22:31:18 -0700 Subject: [PATCH] search labels --- graphiti_core/driver/neptune_driver.py | 37 +++++++++++++------------- graphiti_core/search/search_filters.py | 14 ++++++++++ 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index bb802729..690dfb0d 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -200,30 +200,22 @@ class NeptuneDriver(GraphDriver): else: for k, v in params.items(): if isinstance(v, datetime.datetime): + # Convert datetime to ISO format string params[k] = v.isoformat() elif isinstance(v, list): - # Handle lists that might contain datetime objects + # Check if list contains actual datetime objects (not just strings with 'T') + has_datetime = any(isinstance(item, datetime.datetime) for item in v) + + if has_datetime: + # Convert datetime objects to ISO strings + for i, item in enumerate(v): + if isinstance(item, datetime.datetime): + v[i] = item.isoformat() + + # Handle nested dictionaries for i, item in enumerate(v): - if isinstance(item, datetime.datetime): - v[i] = item.isoformat() - query = str(query).replace(f'${k}', f'datetime(${k})') if isinstance(item, dict): query = self._sanitize_parameters(query, v[i]) - - # If the list contains datetime objects, we need to wrap each element with datetime() - if any(isinstance(item, str) and 'T' in item for item in v): - # Create a new list expression with datetime() wrapped around each element - datetime_list = ( - '[' - + ', '.join( - f'datetime("{item}")' - if isinstance(item, str) and 'T' in item - else repr(item) - for item in v - ) - + ']' - ) - query = str(query).replace(f'${k}', datetime_list) elif isinstance(v, dict): query = self._sanitize_parameters(query, v) return query @@ -232,6 +224,13 @@ class NeptuneDriver(GraphDriver): self, cypher_query_, **kwargs: Any ) -> tuple[dict[str, Any], None, None]: params = dict(kwargs) + + # Flatten nested 'params' dict if present (for compatibility with Neo4j driver interface) + if 'params' in params and isinstance(params['params'], dict): + nested_params = params.pop('params') + # Merge nested params into the top level, nested params take precedence + params = {**params, **nested_params} + if isinstance(cypher_query_, list): for q in cypher_query_: result, _, _ = self._run_query(q[0], q[1]) diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 1534b926..922b853b 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -76,6 +76,11 @@ def node_search_filter_query_constructor( if provider == GraphProvider.KUZU: node_label_filter = 'list_has_all(n.labels, $labels)' filter_params['labels'] = filters.node_labels + elif provider == GraphProvider.NEPTUNE: + # Neptune doesn't support pipe operator in WHERE clause + # Use OR with separate label checks instead + label_conditions = [f'n:{label}' for label in filters.node_labels] + node_label_filter = '(' + ' OR '.join(label_conditions) + ')' else: node_labels = '|'.join(filters.node_labels) node_label_filter = 'n:' + node_labels @@ -119,6 +124,15 @@ def edge_search_filter_query_constructor( 'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)' ) filter_params['labels'] = filters.node_labels + elif provider == GraphProvider.NEPTUNE: + # Neptune doesn't support pipe operator in WHERE clause + # Use OR with separate label checks instead + n_label_conditions = [f'n:{label}' for label in filters.node_labels] + m_label_conditions = [f'm:{label}' for label in filters.node_labels] + node_label_filter = ( + '(' + ' OR '.join(n_label_conditions) + ') AND (' + + ' OR '.join(m_label_conditions) + ')' + ) else: node_labels = '|'.join(filters.node_labels) node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels