diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index dc77d8f8..7fa3f169 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -121,10 +121,24 @@ class NeptuneDriver(GraphDriver): if not host: raise ValueError('You must provide an endpoint to create a NeptuneDriver') + # Define Graphiti schema to avoid expensive statistics API calls + graphiti_schema = """ +Node labels: Episodic, Entity, Community +Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER +Node properties: + Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list} + Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list} + Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string} +Relationship properties: + MENTIONS {created_at: datetime} + RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime} + HAS_MEMBER {uuid: string, created_at: datetime} +""" + if host.startswith('neptune-db://'): # This is a Neptune Database Cluster endpoint = host.replace('neptune-db://', '') - self.client = NeptuneGraph(endpoint, port) + self.client = NeptuneGraph(endpoint, port, schema=graphiti_schema) logger.debug('Creating Neptune Database session for %s', host) elif host.startswith('neptune-graph://'): # This is a Neptune Analytics Graph