search labels

This commit is contained in:
Aidan Petti 2025-11-24 22:31:18 -07:00
parent 97b0bbc7a8
commit cbc1946346
2 changed files with 32 additions and 19 deletions

View file

@ -200,30 +200,22 @@ class NeptuneDriver(GraphDriver):
else:
for k, v in params.items():
if isinstance(v, datetime.datetime):
# Convert datetime to ISO format string
params[k] = v.isoformat()
elif isinstance(v, list):
# Handle lists that might contain datetime objects
# Check if list contains actual datetime objects (not just strings with 'T')
has_datetime = any(isinstance(item, datetime.datetime) for item in v)
if has_datetime:
# Convert datetime objects to ISO strings
for i, item in enumerate(v):
if isinstance(item, datetime.datetime):
v[i] = item.isoformat()
query = str(query).replace(f'${k}', f'datetime(${k})')
# Handle nested dictionaries
for i, item in enumerate(v):
if isinstance(item, dict):
query = self._sanitize_parameters(query, v[i])
# If the list contains datetime objects, we need to wrap each element with datetime()
if any(isinstance(item, str) and 'T' in item for item in v):
# Create a new list expression with datetime() wrapped around each element
datetime_list = (
'['
+ ', '.join(
f'datetime("{item}")'
if isinstance(item, str) and 'T' in item
else repr(item)
for item in v
)
+ ']'
)
query = str(query).replace(f'${k}', datetime_list)
elif isinstance(v, dict):
query = self._sanitize_parameters(query, v)
return query
@ -232,6 +224,13 @@ class NeptuneDriver(GraphDriver):
self, cypher_query_, **kwargs: Any
) -> tuple[dict[str, Any], None, None]:
params = dict(kwargs)
# Flatten nested 'params' dict if present (for compatibility with Neo4j driver interface)
if 'params' in params and isinstance(params['params'], dict):
nested_params = params.pop('params')
# Merge nested params into the top level, nested params take precedence
params = {**params, **nested_params}
if isinstance(cypher_query_, list):
for q in cypher_query_:
result, _, _ = self._run_query(q[0], q[1])

View file

@ -76,6 +76,11 @@ def node_search_filter_query_constructor(
if provider == GraphProvider.KUZU:
node_label_filter = 'list_has_all(n.labels, $labels)'
filter_params['labels'] = filters.node_labels
elif provider == GraphProvider.NEPTUNE:
# Neptune doesn't support pipe operator in WHERE clause
# Use OR with separate label checks instead
label_conditions = [f'n:{label}' for label in filters.node_labels]
node_label_filter = '(' + ' OR '.join(label_conditions) + ')'
else:
node_labels = '|'.join(filters.node_labels)
node_label_filter = 'n:' + node_labels
@ -119,6 +124,15 @@ def edge_search_filter_query_constructor(
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
)
filter_params['labels'] = filters.node_labels
elif provider == GraphProvider.NEPTUNE:
# Neptune doesn't support pipe operator in WHERE clause
# Use OR with separate label checks instead
n_label_conditions = [f'n:{label}' for label in filters.node_labels]
m_label_conditions = [f'm:{label}' for label in filters.node_labels]
node_label_filter = (
'(' + ' OR '.join(n_label_conditions) + ') AND ('
+ ' OR '.join(m_label_conditions) + ')'
)
else:
node_labels = '|'.join(filters.node_labels)
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels