search labels
This commit is contained in:
parent
97b0bbc7a8
commit
cbc1946346
2 changed files with 32 additions and 19 deletions
|
|
@ -200,30 +200,22 @@ class NeptuneDriver(GraphDriver):
|
|||
else:
|
||||
for k, v in params.items():
|
||||
if isinstance(v, datetime.datetime):
|
||||
# Convert datetime to ISO format string
|
||||
params[k] = v.isoformat()
|
||||
elif isinstance(v, list):
|
||||
# Handle lists that might contain datetime objects
|
||||
# Check if list contains actual datetime objects (not just strings with 'T')
|
||||
has_datetime = any(isinstance(item, datetime.datetime) for item in v)
|
||||
|
||||
if has_datetime:
|
||||
# Convert datetime objects to ISO strings
|
||||
for i, item in enumerate(v):
|
||||
if isinstance(item, datetime.datetime):
|
||||
v[i] = item.isoformat()
|
||||
|
||||
# Handle nested dictionaries
|
||||
for i, item in enumerate(v):
|
||||
if isinstance(item, datetime.datetime):
|
||||
v[i] = item.isoformat()
|
||||
query = str(query).replace(f'${k}', f'datetime(${k})')
|
||||
if isinstance(item, dict):
|
||||
query = self._sanitize_parameters(query, v[i])
|
||||
|
||||
# If the list contains datetime objects, we need to wrap each element with datetime()
|
||||
if any(isinstance(item, str) and 'T' in item for item in v):
|
||||
# Create a new list expression with datetime() wrapped around each element
|
||||
datetime_list = (
|
||||
'['
|
||||
+ ', '.join(
|
||||
f'datetime("{item}")'
|
||||
if isinstance(item, str) and 'T' in item
|
||||
else repr(item)
|
||||
for item in v
|
||||
)
|
||||
+ ']'
|
||||
)
|
||||
query = str(query).replace(f'${k}', datetime_list)
|
||||
elif isinstance(v, dict):
|
||||
query = self._sanitize_parameters(query, v)
|
||||
return query
|
||||
|
|
@ -232,6 +224,13 @@ class NeptuneDriver(GraphDriver):
|
|||
self, cypher_query_, **kwargs: Any
|
||||
) -> tuple[dict[str, Any], None, None]:
|
||||
params = dict(kwargs)
|
||||
|
||||
# Flatten nested 'params' dict if present (for compatibility with Neo4j driver interface)
|
||||
if 'params' in params and isinstance(params['params'], dict):
|
||||
nested_params = params.pop('params')
|
||||
# Merge nested params into the top level, nested params take precedence
|
||||
params = {**params, **nested_params}
|
||||
|
||||
if isinstance(cypher_query_, list):
|
||||
for q in cypher_query_:
|
||||
result, _, _ = self._run_query(q[0], q[1])
|
||||
|
|
|
|||
|
|
@ -76,6 +76,11 @@ def node_search_filter_query_constructor(
|
|||
if provider == GraphProvider.KUZU:
|
||||
node_label_filter = 'list_has_all(n.labels, $labels)'
|
||||
filter_params['labels'] = filters.node_labels
|
||||
elif provider == GraphProvider.NEPTUNE:
|
||||
# Neptune doesn't support pipe operator in WHERE clause
|
||||
# Use OR with separate label checks instead
|
||||
label_conditions = [f'n:{label}' for label in filters.node_labels]
|
||||
node_label_filter = '(' + ' OR '.join(label_conditions) + ')'
|
||||
else:
|
||||
node_labels = '|'.join(filters.node_labels)
|
||||
node_label_filter = 'n:' + node_labels
|
||||
|
|
@ -119,6 +124,15 @@ def edge_search_filter_query_constructor(
|
|||
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
|
||||
)
|
||||
filter_params['labels'] = filters.node_labels
|
||||
elif provider == GraphProvider.NEPTUNE:
|
||||
# Neptune doesn't support pipe operator in WHERE clause
|
||||
# Use OR with separate label checks instead
|
||||
n_label_conditions = [f'n:{label}' for label in filters.node_labels]
|
||||
m_label_conditions = [f'm:{label}' for label in filters.node_labels]
|
||||
node_label_filter = (
|
||||
'(' + ' OR '.join(n_label_conditions) + ') AND ('
|
||||
+ ' OR '.join(m_label_conditions) + ')'
|
||||
)
|
||||
else:
|
||||
node_labels = '|'.join(filters.node_labels)
|
||||
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue