search labels
This commit is contained in:
parent
97b0bbc7a8
commit
cbc1946346
2 changed files with 32 additions and 19 deletions
|
|
@ -200,30 +200,22 @@ class NeptuneDriver(GraphDriver):
|
||||||
else:
|
else:
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
if isinstance(v, datetime.datetime):
|
if isinstance(v, datetime.datetime):
|
||||||
|
# Convert datetime to ISO format string
|
||||||
params[k] = v.isoformat()
|
params[k] = v.isoformat()
|
||||||
elif isinstance(v, list):
|
elif isinstance(v, list):
|
||||||
# Handle lists that might contain datetime objects
|
# Check if list contains actual datetime objects (not just strings with 'T')
|
||||||
|
has_datetime = any(isinstance(item, datetime.datetime) for item in v)
|
||||||
|
|
||||||
|
if has_datetime:
|
||||||
|
# Convert datetime objects to ISO strings
|
||||||
|
for i, item in enumerate(v):
|
||||||
|
if isinstance(item, datetime.datetime):
|
||||||
|
v[i] = item.isoformat()
|
||||||
|
|
||||||
|
# Handle nested dictionaries
|
||||||
for i, item in enumerate(v):
|
for i, item in enumerate(v):
|
||||||
if isinstance(item, datetime.datetime):
|
|
||||||
v[i] = item.isoformat()
|
|
||||||
query = str(query).replace(f'${k}', f'datetime(${k})')
|
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
query = self._sanitize_parameters(query, v[i])
|
query = self._sanitize_parameters(query, v[i])
|
||||||
|
|
||||||
# If the list contains datetime objects, we need to wrap each element with datetime()
|
|
||||||
if any(isinstance(item, str) and 'T' in item for item in v):
|
|
||||||
# Create a new list expression with datetime() wrapped around each element
|
|
||||||
datetime_list = (
|
|
||||||
'['
|
|
||||||
+ ', '.join(
|
|
||||||
f'datetime("{item}")'
|
|
||||||
if isinstance(item, str) and 'T' in item
|
|
||||||
else repr(item)
|
|
||||||
for item in v
|
|
||||||
)
|
|
||||||
+ ']'
|
|
||||||
)
|
|
||||||
query = str(query).replace(f'${k}', datetime_list)
|
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, dict):
|
||||||
query = self._sanitize_parameters(query, v)
|
query = self._sanitize_parameters(query, v)
|
||||||
return query
|
return query
|
||||||
|
|
@ -232,6 +224,13 @@ class NeptuneDriver(GraphDriver):
|
||||||
self, cypher_query_, **kwargs: Any
|
self, cypher_query_, **kwargs: Any
|
||||||
) -> tuple[dict[str, Any], None, None]:
|
) -> tuple[dict[str, Any], None, None]:
|
||||||
params = dict(kwargs)
|
params = dict(kwargs)
|
||||||
|
|
||||||
|
# Flatten nested 'params' dict if present (for compatibility with Neo4j driver interface)
|
||||||
|
if 'params' in params and isinstance(params['params'], dict):
|
||||||
|
nested_params = params.pop('params')
|
||||||
|
# Merge nested params into the top level, nested params take precedence
|
||||||
|
params = {**params, **nested_params}
|
||||||
|
|
||||||
if isinstance(cypher_query_, list):
|
if isinstance(cypher_query_, list):
|
||||||
for q in cypher_query_:
|
for q in cypher_query_:
|
||||||
result, _, _ = self._run_query(q[0], q[1])
|
result, _, _ = self._run_query(q[0], q[1])
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,11 @@ def node_search_filter_query_constructor(
|
||||||
if provider == GraphProvider.KUZU:
|
if provider == GraphProvider.KUZU:
|
||||||
node_label_filter = 'list_has_all(n.labels, $labels)'
|
node_label_filter = 'list_has_all(n.labels, $labels)'
|
||||||
filter_params['labels'] = filters.node_labels
|
filter_params['labels'] = filters.node_labels
|
||||||
|
elif provider == GraphProvider.NEPTUNE:
|
||||||
|
# Neptune doesn't support pipe operator in WHERE clause
|
||||||
|
# Use OR with separate label checks instead
|
||||||
|
label_conditions = [f'n:{label}' for label in filters.node_labels]
|
||||||
|
node_label_filter = '(' + ' OR '.join(label_conditions) + ')'
|
||||||
else:
|
else:
|
||||||
node_labels = '|'.join(filters.node_labels)
|
node_labels = '|'.join(filters.node_labels)
|
||||||
node_label_filter = 'n:' + node_labels
|
node_label_filter = 'n:' + node_labels
|
||||||
|
|
@ -119,6 +124,15 @@ def edge_search_filter_query_constructor(
|
||||||
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
|
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
|
||||||
)
|
)
|
||||||
filter_params['labels'] = filters.node_labels
|
filter_params['labels'] = filters.node_labels
|
||||||
|
elif provider == GraphProvider.NEPTUNE:
|
||||||
|
# Neptune doesn't support pipe operator in WHERE clause
|
||||||
|
# Use OR with separate label checks instead
|
||||||
|
n_label_conditions = [f'n:{label}' for label in filters.node_labels]
|
||||||
|
m_label_conditions = [f'm:{label}' for label in filters.node_labels]
|
||||||
|
node_label_filter = (
|
||||||
|
'(' + ' OR '.join(n_label_conditions) + ') AND ('
|
||||||
|
+ ' OR '.join(m_label_conditions) + ')'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
node_labels = '|'.join(filters.node_labels)
|
node_labels = '|'.join(filters.node_labels)
|
||||||
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
|
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue