diff --git a/examples/quickstart/quickstart_neptune.py b/examples/quickstart/quickstart_neptune.py index 76a494e3..3f2c95cf 100644 --- a/examples/quickstart/quickstart_neptune.py +++ b/examples/quickstart/quickstart_neptune.py @@ -250,4 +250,3 @@ async def main(): if __name__ == '__main__': asyncio.run(main()) - \ No newline at end of file diff --git a/graphiti_core/migrations/neo4j_node_group_labels.py b/graphiti_core/migrations/neo4j_node_group_labels.py index f4cdb467..f075cef5 100644 --- a/graphiti_core/migrations/neo4j_node_group_labels.py +++ b/graphiti_core/migrations/neo4j_node_group_labels.py @@ -1,4 +1,8 @@ +import asyncio +import os + from graphiti_core.driver.driver import GraphDriver +from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.helpers import validate_group_id from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes @@ -11,21 +15,21 @@ async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size MATCH (n:Episodic {group_id: $group_id}) CALL { WITH n - SET n:$group_label + SET n:$($group_label) } IN TRANSACTIONS OF $batch_size ROWS""" entity_query = """ MATCH (n:Entity {group_id: $group_id}) CALL { WITH n - SET n:$group_label + SET n:$($group_label) } IN TRANSACTIONS OF $batch_size ROWS""" community_query = """ MATCH (n:Community {group_id: $group_id}) CALL { WITH n - SET n:$group_label + SET n:$($group_label) } IN TRANSACTIONS OF $batch_size ROWS""" async with driver.session() as session: @@ -51,3 +55,31 @@ async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size group_label='Community_' + group_id.replace('-', ''), batch_size=batch_size, ) + + +async def neo4j_node_label_migration(driver: GraphDriver): + query = """MATCH (n:Episodic) + RETURN DISTINCT n.group_id AS group_id""" + + results, _, _ = await driver.execute_query(query) + for result in results: + group_id = result['group_id'] + await neo4j_node_group_labels(driver, group_id) + + +async def main(): + neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' + neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' + neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' + + driver = Neo4jDriver( + uri=neo4j_uri, + user=neo4j_user, + password=neo4j_password, + ) + await neo4j_node_label_migration(driver) + await driver.close() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 8de54d8e..4cf97784 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False): if driver.provider == GraphProvider.NEPTUNE: - await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue] + await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue] return if delete_existing: records, _, _ = await driver.execute_query( diff --git a/pyproject.toml b/pyproject.toml index f286d446..0301f20c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.19.0pre2" +version = "0.19.0pre3" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index b1b92aa0..fecce3bc 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.19.0rc2" +version = "0.19.0rc3" source = { editable = "." } dependencies = [ { name = "diskcache" },