Adds experimental support for Apache TinkerPop Gremlin as an alternative
query language for AWS Neptune Database, alongside the existing openCypher
support. This enables users to choose their preferred query language and
opens the door for future support of other Gremlin-compatible databases.
- QueryLanguage enum (CYPHER, GREMLIN) for explicit language selection
- Dual-mode NeptuneDriver supporting both Cypher and Gremlin
- Gremlin query generation functions for common graph operations
- Graceful degradation when gremlinpython is not installed
- 100% backward compatible (defaults to CYPHER)
- graphiti_core/driver/driver.py: Added QueryLanguage enum
- graphiti_core/driver/neptune_driver.py: Dual client initialization
and query routing based on language selection
- graphiti_core/graph_queries.py: 9 new Gremlin query generation functions
- graphiti_core/utils/maintenance/graph_data_operations.py: Updated
clear_data() to support both query languages
- tests/test_neptune_gremlin_int.py: Comprehensive integration tests
- examples/quickstart/quickstart_neptune_gremlin.py: Usage example
- examples/quickstart/README.md: Updated with Gremlin instructions
- GREMLIN_FEATURE.md: Complete feature documentation
- pyproject.toml: Added gremlinpython>=3.7.0 to neptune extras
```python
from graphiti_core.driver.driver import QueryLanguage
from graphiti_core.driver.neptune_driver import NeptuneDriver
driver = NeptuneDriver(
host='neptune-db://cluster.amazonaws.com',
aoss_host='aoss-cluster.amazonaws.com',
query_language=QueryLanguage.GREMLIN
)
```
- Only Neptune Database supports Gremlin (not Neptune Analytics)
- Fulltext and vector search still use OpenSearch (AOSS) integration
- Complete search_utils.py Gremlin implementation pending (future work)
- ✅ All existing unit tests pass (103/103)
- ✅ New integration tests for Gremlin operations
- ✅ Type checking passes
- ✅ Linting passes
None. Fully backward compatible.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
343 lines
13 KiB
Python
343 lines
13 KiB
Python
"""
|
|
Database query utilities for different graph database backends.
|
|
|
|
This module provides database-agnostic query generation for Neo4j, FalkorDB, Kuzu, and Neptune,
|
|
supporting index creation, fulltext search, bulk operations, and Gremlin queries.
|
|
"""
|
|
|
|
from typing_extensions import LiteralString
|
|
|
|
from graphiti_core.driver.driver import GraphProvider
|
|
|
|
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
|
NEO4J_TO_FALKORDB_MAPPING = {
|
|
'node_name_and_summary': 'Entity',
|
|
'community_name': 'Community',
|
|
'episode_content': 'Episodic',
|
|
'edge_name_and_fact': 'RELATES_TO',
|
|
}
|
|
# Mapping from fulltext index names to Kuzu node labels
|
|
INDEX_TO_LABEL_KUZU_MAPPING = {
|
|
'node_name_and_summary': 'Entity',
|
|
'community_name': 'Community',
|
|
'episode_content': 'Episodic',
|
|
'edge_name_and_fact': 'RelatesToNode_',
|
|
}
|
|
|
|
|
|
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
if provider == GraphProvider.FALKORDB:
|
|
return [
|
|
# Entity node
|
|
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
|
|
# Episodic node
|
|
'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
|
|
# Community node
|
|
'CREATE INDEX FOR (n:Community) ON (n.uuid)',
|
|
# RELATES_TO edge
|
|
'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
|
|
# MENTIONS edge
|
|
'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
|
|
# HAS_MEMBER edge
|
|
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
]
|
|
|
|
if provider == GraphProvider.KUZU:
|
|
return []
|
|
|
|
return [
|
|
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
|
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
|
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
|
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
|
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
|
'CREATE INDEX community_group_id IF NOT EXISTS FOR (n:Community) ON (n.group_id)',
|
|
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
|
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
|
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
|
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
|
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
|
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
|
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
|
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
|
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
|
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
|
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
|
]
|
|
|
|
|
|
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
if provider == GraphProvider.FALKORDB:
|
|
from typing import cast
|
|
|
|
from graphiti_core.driver.falkordb_driver import STOPWORDS
|
|
|
|
# Convert to string representation for embedding in queries
|
|
stopwords_str = str(STOPWORDS)
|
|
|
|
# Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth
|
|
return cast(
|
|
list[LiteralString],
|
|
[
|
|
f"""CALL db.idx.fulltext.createNodeIndex(
|
|
{{
|
|
label: 'Episodic',
|
|
stopwords: {stopwords_str}
|
|
}},
|
|
'content', 'source', 'source_description', 'group_id'
|
|
)""",
|
|
f"""CALL db.idx.fulltext.createNodeIndex(
|
|
{{
|
|
label: 'Entity',
|
|
stopwords: {stopwords_str}
|
|
}},
|
|
'name', 'summary', 'group_id'
|
|
)""",
|
|
f"""CALL db.idx.fulltext.createNodeIndex(
|
|
{{
|
|
label: 'Community',
|
|
stopwords: {stopwords_str}
|
|
}},
|
|
'name', 'group_id'
|
|
)""",
|
|
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
|
],
|
|
)
|
|
|
|
if provider == GraphProvider.KUZU:
|
|
return [
|
|
"CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
|
|
"CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
|
|
"CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
|
|
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
|
|
]
|
|
|
|
return [
|
|
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
|
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
|
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
|
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
|
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
|
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
|
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
|
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
|
]
|
|
|
|
|
|
def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
|
|
if provider == GraphProvider.FALKORDB:
|
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
|
|
|
if provider == GraphProvider.KUZU:
|
|
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
|
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
|
|
|
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
|
|
|
|
|
def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
|
if provider == GraphProvider.FALKORDB:
|
|
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
|
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
|
|
|
if provider == GraphProvider.KUZU:
|
|
return f'array_cosine_similarity({vec1}, {vec2})'
|
|
|
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
|
|
|
|
|
def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
|
|
if provider == GraphProvider.FALKORDB:
|
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
|
|
|
if provider == GraphProvider.KUZU:
|
|
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
|
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
|
|
|
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
|
|
|
|
|
# Gremlin Query Generation Functions
|
|
|
|
|
|
def gremlin_match_node_by_property(
|
|
label: str, property_name: str, property_value_param: str
|
|
) -> str:
|
|
"""
|
|
Generate a Gremlin query to match a node by label and property.
|
|
|
|
Args:
|
|
label: Node label (e.g., 'Entity', 'Episodic')
|
|
property_name: Property name to match on
|
|
property_value_param: Parameter name for the property value
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
return f"g.V().hasLabel('{label}').has('{property_name}', {property_value_param})"
|
|
|
|
|
|
def gremlin_match_nodes_by_uuids(label: str, uuids_param: str = 'uuids') -> str:
|
|
"""
|
|
Generate a Gremlin query to match multiple nodes by UUIDs.
|
|
|
|
Args:
|
|
label: Node label (e.g., 'Entity', 'Episodic')
|
|
uuids_param: Parameter name containing list of UUIDs
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
return f"g.V().hasLabel('{label}').has('uuid', within({uuids_param}))"
|
|
|
|
|
|
def gremlin_match_edge_by_property(
|
|
edge_label: str, property_name: str, property_value_param: str
|
|
) -> str:
|
|
"""
|
|
Generate a Gremlin query to match an edge by label and property.
|
|
|
|
Args:
|
|
edge_label: Edge label (e.g., 'RELATES_TO', 'MENTIONS')
|
|
property_name: Property name to match on
|
|
property_value_param: Parameter name for the property value
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
return f"g.E().hasLabel('{edge_label}').has('{property_name}', {property_value_param})"
|
|
|
|
|
|
def gremlin_get_outgoing_edges(
|
|
source_label: str,
|
|
edge_label: str,
|
|
target_label: str,
|
|
source_uuid_param: str = 'source_uuid',
|
|
) -> str:
|
|
"""
|
|
Generate a Gremlin query to get outgoing edges from a node.
|
|
|
|
Args:
|
|
source_label: Source node label
|
|
edge_label: Edge label
|
|
target_label: Target node label
|
|
source_uuid_param: Parameter name for source UUID
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
return (
|
|
f"g.V().hasLabel('{source_label}').has('uuid', {source_uuid_param})"
|
|
f".outE('{edge_label}').as('e')"
|
|
f".inV().hasLabel('{target_label}').as('target')"
|
|
f".select('e', 'target')"
|
|
)
|
|
|
|
|
|
def gremlin_bfs_traversal(
|
|
start_label: str,
|
|
edge_labels: list[str],
|
|
max_depth: int,
|
|
start_uuids_param: str = 'start_uuids',
|
|
) -> str:
|
|
"""
|
|
Generate a Gremlin query for breadth-first search traversal.
|
|
|
|
Args:
|
|
start_label: Starting node label
|
|
edge_labels: List of edge labels to traverse
|
|
max_depth: Maximum traversal depth
|
|
start_uuids_param: Parameter name for starting UUIDs
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
edge_labels_str = "', '".join(edge_labels)
|
|
return (
|
|
f"g.V().hasLabel('{start_label}').has('uuid', within({start_uuids_param}))"
|
|
f".repeat(bothE('{edge_labels_str}').otherV()).times({max_depth})"
|
|
f'.dedup()'
|
|
)
|
|
|
|
|
|
def gremlin_delete_all_nodes() -> str:
|
|
"""
|
|
Generate a Gremlin query to delete all nodes and edges.
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
return 'g.V().drop()'
|
|
|
|
|
|
def gremlin_delete_nodes_by_group_id(label: str, group_ids_param: str = 'group_ids') -> str:
|
|
"""
|
|
Generate a Gremlin query to delete nodes by group_id.
|
|
|
|
Args:
|
|
label: Node label
|
|
group_ids_param: Parameter name for group IDs list
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
return f"g.V().hasLabel('{label}').has('group_id', within({group_ids_param})).drop()"
|
|
|
|
|
|
def gremlin_cosine_similarity_filter(
|
|
embedding_property: str, search_vector_param: str, min_score: float
|
|
) -> str:
|
|
"""
|
|
Generate a Gremlin query fragment for cosine similarity filtering.
|
|
Note: This is a placeholder as Neptune Gremlin doesn't have built-in vector similarity.
|
|
Vector similarity should be handled via OpenSearch integration.
|
|
|
|
Args:
|
|
embedding_property: Property name containing the embedding
|
|
search_vector_param: Parameter name for search vector
|
|
min_score: Minimum similarity score
|
|
|
|
Returns:
|
|
Gremlin query fragment (warning comment)
|
|
"""
|
|
# Neptune Gremlin doesn't support vector similarity natively
|
|
# This should be handled via OpenSearch AOSS integration
|
|
return f"// Vector similarity for '{embedding_property}' must be handled via OpenSearch"
|
|
|
|
|
|
def gremlin_retrieve_episodes(
|
|
reference_time_param: str = 'reference_time',
|
|
group_ids_param: str = 'group_ids',
|
|
limit_param: str = 'num_episodes',
|
|
source_param: str | None = None,
|
|
) -> str:
|
|
"""
|
|
Generate a Gremlin query to retrieve episodes filtered by time and optionally by group_id and source.
|
|
|
|
Args:
|
|
reference_time_param: Parameter name for reference timestamp
|
|
group_ids_param: Parameter name for group IDs list
|
|
limit_param: Parameter name for result limit
|
|
source_param: Optional parameter name for source filter
|
|
|
|
Returns:
|
|
Gremlin traversal string
|
|
"""
|
|
query = f"g.V().hasLabel('Episodic').has('valid_at', lte({reference_time_param}))"
|
|
|
|
# Add group_id filter if specified
|
|
query += f".has('group_id', within({group_ids_param}))"
|
|
|
|
# Add source filter if specified
|
|
if source_param:
|
|
query += f".has('source', {source_param})"
|
|
|
|
# Order by valid_at descending and limit
|
|
query += f".order().by('valid_at', desc).limit({limit_param}).valueMap(true)"
|
|
|
|
return query
|