graphiti/graphiti_core/graph_queries.py
supmo668 a944871942 feat: Add Gremlin query language support for Neptune Database
Adds experimental support for Apache TinkerPop Gremlin as an alternative
query language for AWS Neptune Database, alongside the existing openCypher
support. This enables users to choose their preferred query language and
opens the door for future support of other Gremlin-compatible databases.

- QueryLanguage enum (CYPHER, GREMLIN) for explicit language selection
- Dual-mode NeptuneDriver supporting both Cypher and Gremlin
- Gremlin query generation functions for common graph operations
- Graceful degradation when gremlinpython is not installed
- 100% backward compatible (defaults to CYPHER)

- graphiti_core/driver/driver.py: Added QueryLanguage enum
- graphiti_core/driver/neptune_driver.py: Dual client initialization
  and query routing based on language selection
- graphiti_core/graph_queries.py: 9 new Gremlin query generation functions

- graphiti_core/utils/maintenance/graph_data_operations.py: Updated
  clear_data() to support both query languages

- tests/test_neptune_gremlin_int.py: Comprehensive integration tests
- examples/quickstart/quickstart_neptune_gremlin.py: Usage example
- examples/quickstart/README.md: Updated with Gremlin instructions
- GREMLIN_FEATURE.md: Complete feature documentation

- pyproject.toml: Added gremlinpython>=3.7.0 to neptune extras

```python
from graphiti_core.driver.driver import QueryLanguage
from graphiti_core.driver.neptune_driver import NeptuneDriver

driver = NeptuneDriver(
    host='neptune-db://cluster.amazonaws.com',
    aoss_host='aoss-cluster.amazonaws.com',
    query_language=QueryLanguage.GREMLIN
)
```

- Only Neptune Database supports Gremlin (not Neptune Analytics)
- Fulltext and vector search still use OpenSearch (AOSS) integration
- Complete search_utils.py Gremlin implementation pending (future work)

- All existing unit tests pass (103/103)
- New integration tests for Gremlin operations
- Type checking passes
- Linting passes

Breaking changes: none. Fully backward compatible.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-05 23:45:59 -08:00

343 lines
13 KiB
Python

"""
Database query utilities for different graph database backends.
This module provides database-agnostic query generation for Neo4j, FalkorDB, Kuzu, and Neptune,
supporting index creation, fulltext search, bulk operations, and Gremlin queries.
"""
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphProvider
# Fulltext index name -> FalkorDB label / relationship type. FalkorDB's
# db.idx.fulltext.query* procedures are called with this label rather than
# with the Neo4j index name (see get_nodes_query / get_relationships_query).
NEO4J_TO_FALKORDB_MAPPING: dict[str, str] = {
    'node_name_and_summary': 'Entity',
    'community_name': 'Community',
    'episode_content': 'Episodic',
    'edge_name_and_fact': 'RELATES_TO',
}

# Fulltext index name -> Kuzu table holding that index. Identical to the
# FalkorDB mapping except for the RELATES_TO fulltext index, which Kuzu keeps
# on the 'RelatesToNode_' table. Unpacking first keeps the key order.
INDEX_TO_LABEL_KUZU_MAPPING: dict[str, str] = {
    **NEO4J_TO_FALKORDB_MAPPING,
    'edge_name_and_fact': 'RelatesToNode_',
}
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
    """Return the statements that create range (property) indices for ``provider``.

    - Kuzu: an empty list; this helper creates no range indices for Kuzu.
    - FalkorDB: one composite ``CREATE INDEX`` per node label / edge type.
    - Any other provider: Neo4j-style named single-property indices, each
      guarded with ``IF NOT EXISTS``.
    """
    if provider == GraphProvider.KUZU:
        return []

    if provider == GraphProvider.FALKORDB:
        falkordb_statements: list[LiteralString] = [
            # Entity node
            'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
            # Episodic node
            'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
            # Community node
            'CREATE INDEX FOR (n:Community) ON (n.uuid)',
            # RELATES_TO edge
            'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
            # MENTIONS edge
            'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
            # HAS_MEMBER edge
            'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
        ]
        return falkordb_statements

    neo4j_statements: list[LiteralString] = [
        # uuid lookups on every node label and edge type
        'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
        'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
        'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
        'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
        'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
        'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
        # group_id filters (no group_id index is created for HAS_MEMBER here)
        'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
        'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
        'CREATE INDEX community_group_id IF NOT EXISTS FOR (n:Community) ON (n.group_id)',
        'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
        'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
        # name / timestamp filters on entities, episodes and RELATES_TO edges
        'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
        'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
        'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
        'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
        'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
        'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
        'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
        'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
        'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
    ]
    return neo4j_statements
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
    """Return the statements that create the fulltext indices for ``provider``.

    - FalkorDB: one ``db.idx.fulltext.createNodeIndex`` call per node label
      (Episodic, Entity, Community), each configured with the FalkorDB
      driver's ``STOPWORDS``, plus a ``CREATE FULLTEXT INDEX`` on
      ``RELATES_TO`` edges.
    - Kuzu: ``CREATE_FTS_INDEX`` calls; note that these do not include
      ``group_id`` among the indexed properties.
    - Any other provider: Neo4j-style ``CREATE FULLTEXT INDEX ... IF NOT EXISTS``
      statements named ``episode_content``, ``node_name_and_summary``,
      ``community_name`` and ``edge_name_and_fact``.
    """
    if provider == GraphProvider.FALKORDB:
        # Imported inside the branch, so the FalkorDB driver module is only
        # loaded when this branch actually runs.
        from typing import cast

        from graphiti_core.driver.falkordb_driver import STOPWORDS

        # Embed the stopwords as their Python str() representation.
        # NOTE(review): presumably STOPWORDS is a list of strings, whose repr
        # also parses as a Cypher list literal — confirm in falkordb_driver.
        stopwords_str = str(STOPWORDS)

        # Interpolating stopwords_str makes these f-strings plain ``str``, so
        # typing.cast re-labels the list as list[LiteralString] for the type
        # checker. This keeps STOPWORDS as the single source of truth for the
        # stopword list.
        return cast(
            list[LiteralString],
            [
                f"""CALL db.idx.fulltext.createNodeIndex(
{{
label: 'Episodic',
stopwords: {stopwords_str}
}},
'content', 'source', 'source_description', 'group_id'
)""",
                f"""CALL db.idx.fulltext.createNodeIndex(
{{
label: 'Entity',
stopwords: {stopwords_str}
}},
'name', 'summary', 'group_id'
)""",
                f"""CALL db.idx.fulltext.createNodeIndex(
{{
label: 'Community',
stopwords: {stopwords_str}
}},
'name', 'group_id'
)""",
                """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
            ],
        )

    if provider == GraphProvider.KUZU:
        return [
            "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
            "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
            "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
            "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
        ]

    return [
        """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
        """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
        """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
FOR (n:Community) ON EACH [n.name, n.group_id]""",
        """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
    ]
def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
    """Build the fulltext node-search procedure call for ``provider``.

    Args:
        name: Neo4j-style fulltext index name (a key of the provider mappings
            for FalkorDB and Kuzu; a missing key raises ``KeyError``).
        query: Inserted into the call verbatim, unquoted and unescaped.
            NOTE(review): presumably a parameter reference such as ``$query``.
        limit: Not used here. Kuzu and Neo4j read the row limit from the
            ``$limit`` query parameter, and the FalkorDB call has no limit.
        provider: Target graph backend.

    Returns:
        The procedure-call string.
    """
    if provider == GraphProvider.FALKORDB:
        falkor_label = NEO4J_TO_FALKORDB_MAPPING[name]
        call = f"CALL db.idx.fulltext.queryNodes('{falkor_label}', {query})"
    elif provider == GraphProvider.KUZU:
        kuzu_table = INDEX_TO_LABEL_KUZU_MAPPING[name]
        call = f"CALL QUERY_FTS_INDEX('{kuzu_table}', '{name}', {query}, TOP := $limit)"
    else:
        call = f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
    return call
def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
    """Return a query expression for the cosine similarity of ``vec1`` and ``vec2``.

    Both arguments are interpolated verbatim. NOTE(review): they are presumably
    query expressions (property accesses / parameter references), not Python
    vectors — confirm against callers.

    - Kuzu: ``array_cosine_similarity``.
    - FalkorDB: ``vec.cosineDistance`` rescaled to ``[0, 1]``; ``vec2`` is
      wrapped in ``vecf32``.
    - Any other provider: ``vector.similarity.cosine``.
    """
    if provider == GraphProvider.KUZU:
        expression = f'array_cosine_similarity({vec1}, {vec2})'
    elif provider == GraphProvider.FALKORDB:
        # FalkorDB exposes a cosine *distance* while Neo4j returns a normalized
        # similarity. Assuming distance = 1 - cos, (2 - distance) / 2 equals
        # (1 + cos) / 2, which puts both backends on the same scale.
        expression = f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
    else:
        expression = f'vector.similarity.cosine({vec1}, {vec2})'
    return expression
def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
    """Build the fulltext relationship-search procedure call for ``provider``.

    The search text always comes from the ``$query`` parameter. ``limit`` is
    not used here: Kuzu and Neo4j read ``$limit`` from the query parameters,
    and the FalkorDB call has no limit. For FalkorDB and Kuzu, ``name`` must
    be a key of the provider mapping, otherwise ``KeyError`` is raised.
    """
    if provider == GraphProvider.FALKORDB:
        falkor_label = NEO4J_TO_FALKORDB_MAPPING[name]
        call = f"CALL db.idx.fulltext.queryRelationships('{falkor_label}', $query)"
    elif provider == GraphProvider.KUZU:
        kuzu_table = INDEX_TO_LABEL_KUZU_MAPPING[name]
        call = f"CALL QUERY_FTS_INDEX('{kuzu_table}', '{name}', cast($query AS STRING), TOP := $limit)"
    else:
        call = f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
    return call
# Gremlin Query Generation Functions
# These helpers return Gremlin traversal *strings*. Labels and property names
# are wrapped in single quotes; the *_param arguments are inserted unquoted,
# so they must name a binding/variable (or be a literal expression).
def gremlin_match_node_by_property(
    label: str, property_name: str, property_value_param: str
) -> str:
    """Build a traversal selecting vertices by label and one property value.

    Args:
        label: Vertex label (e.g. 'Entity', 'Episodic').
        property_name: Property to compare.
        property_value_param: Inserted unquoted as the value to match.

    Returns:
        Gremlin traversal string.
    """
    label_step = f"hasLabel('{label}')"
    property_step = f"has('{property_name}', {property_value_param})"
    return f'g.V().{label_step}.{property_step}'
def gremlin_match_nodes_by_uuids(label: str, uuids_param: str = 'uuids') -> str:
    """Build a traversal selecting vertices of ``label`` whose uuid is in a list.

    Args:
        label: Vertex label (e.g. 'Entity', 'Episodic').
        uuids_param: Inserted unquoted inside ``within(...)``; names the list
            of UUIDs.

    Returns:
        Gremlin traversal string.
    """
    filters = [f"hasLabel('{label}')", f"has('uuid', within({uuids_param}))"]
    return 'g.V().' + '.'.join(filters)
def gremlin_match_edge_by_property(
    edge_label: str, property_name: str, property_value_param: str
) -> str:
    """Build a traversal selecting edges by label and one property value.

    Args:
        edge_label: Edge label (e.g. 'RELATES_TO', 'MENTIONS').
        property_name: Property to compare.
        property_value_param: Inserted unquoted as the value to match.

    Returns:
        Gremlin traversal string.
    """
    edge_source = 'g.E()'
    label_step = f"hasLabel('{edge_label}')"
    property_step = f"has('{property_name}', {property_value_param})"
    return f'{edge_source}.{label_step}.{property_step}'
def gremlin_get_outgoing_edges(
    source_label: str,
    edge_label: str,
    target_label: str,
    source_uuid_param: str = 'source_uuid',
) -> str:
    """Build a traversal for the outgoing edges of one vertex and their targets.

    The source vertex is matched by label and ``uuid``. Only outgoing
    ``edge_label`` edges whose head vertex has ``target_label`` are kept. Each
    result is selected as a map keyed ``'e'`` (the edge) and ``'target'``
    (the head vertex).

    Args:
        source_label: Label of the source vertex.
        edge_label: Label of the edges to follow.
        target_label: Required label of the target vertex.
        source_uuid_param: Inserted unquoted as the source vertex's uuid.

    Returns:
        Gremlin traversal string.
    """
    steps = [
        'g.V()',
        f"hasLabel('{source_label}')",
        f"has('uuid', {source_uuid_param})",
        f"outE('{edge_label}')",
        "as('e')",
        'inV()',
        f"hasLabel('{target_label}')",
        "as('target')",
        "select('e', 'target')",
    ]
    return '.'.join(steps)
def gremlin_bfs_traversal(
    start_label: str,
    edge_labels: list[str],
    max_depth: int,
    start_uuids_param: str = 'start_uuids',
) -> str:
    """Generate a Gremlin traversal returning vertices within ``max_depth`` hops.

    The traversal starts from the ``start_label`` vertices whose ``uuid`` is in
    ``start_uuids_param`` and repeatedly crosses matching edges in either
    direction. ``emit()`` outputs the vertex reached after every hop, so the
    result covers hops 1..``max_depth`` (like Cypher ``*1..max_depth``), not
    only the vertices exactly ``max_depth`` hops away. Results are
    de-duplicated.

    The traversal does not use ``simplePath()``, so it may walk back over the
    edge it just crossed. Start vertices can therefore appear in the results.

    Args:
        start_label: Starting vertex label.
        edge_labels: Edge labels to traverse. An empty list follows edges of
            any label (``bothE()``).
        max_depth: Maximum traversal depth (number of hops).
        start_uuids_param: Inserted unquoted inside ``within(...)``; names the
            list of starting UUIDs.

    Returns:
        Gremlin traversal string.
    """
    # bothE('a', 'b') follows edges with any of the given labels, while bothE()
    # follows every edge. Joining an empty list into a quoted string would
    # instead produce bothE(''), which only matches empty-labelled edges.
    quoted_labels = ', '.join(f"'{edge_label}'" for edge_label in edge_labels)
    edge_step = f'bothE({quoted_labels})'
    return (
        f"g.V().hasLabel('{start_label}').has('uuid', within({start_uuids_param}))"
        # Without emit(), repeat().times(n) yields only vertices at exactly n hops.
        f'.repeat({edge_step}.otherV()).emit().times({max_depth})'
        '.dedup()'
    )
def gremlin_delete_all_nodes() -> str:
    """Return a Gremlin traversal that drops every vertex in the graph.

    In Gremlin, dropping a vertex also removes its incident edges, so this
    clears both nodes and edges.
    """
    drop_everything = 'g.V().drop()'
    return drop_everything
def gremlin_delete_nodes_by_group_id(label: str, group_ids_param: str = 'group_ids') -> str:
    """Return a traversal that drops ``label`` vertices whose group_id is in a list.

    Args:
        label: Vertex label to restrict the deletion to.
        group_ids_param: Inserted unquoted inside ``within(...)``; names the
            list of group IDs.

    Returns:
        Gremlin traversal string.
    """
    selection = f"hasLabel('{label}').has('group_id', within({group_ids_param}))"
    return f'g.V().{selection}.drop()'
def gremlin_cosine_similarity_filter(
    embedding_property: str, search_vector_param: str, min_score: float
) -> str:
    """Placeholder for vector-similarity filtering in Gremlin.

    No similarity step is generated. The return value is a ``//`` comment
    naming ``embedding_property``, not an executable Gremlin step, so appending
    it to a traversal filters nothing. Per the original design, vector
    similarity is meant to go through the OpenSearch (AOSS) integration
    instead.

    Args:
        embedding_property: Property holding the embedding; only used in the
            returned message.
        search_vector_param: Unused.
        min_score: Unused.

    Returns:
        A comment string explaining where similarity must be computed.
    """
    notice = f"// Vector similarity for '{embedding_property}' must be handled via OpenSearch"
    return notice
def gremlin_retrieve_episodes(
reference_time_param: str = 'reference_time',
group_ids_param: str = 'group_ids',
limit_param: str = 'num_episodes',
source_param: str | None = None,
) -> str:
"""
Generate a Gremlin query to retrieve episodes filtered by time and optionally by group_id and source.
Args:
reference_time_param: Parameter name for reference timestamp
group_ids_param: Parameter name for group IDs list
limit_param: Parameter name for result limit
source_param: Optional parameter name for source filter
Returns:
Gremlin traversal string
"""
query = f"g.V().hasLabel('Episodic').has('valid_at', lte({reference_time_param}))"
# Add group_id filter if specified
query += f".has('group_id', within({group_ids_param}))"
# Add source filter if specified
if source_param:
query += f".has('source', {source_param})"
# Order by valid_at descending and limit
query += f".order().by('valid_at', desc).limit({limit_param}).valueMap(true)"
return query