update migration (#870)
* update migration * bump version * close driver
This commit is contained in:
parent
d62c203147
commit
309159bccb
5 changed files with 38 additions and 7 deletions
|
|
@ -250,4 +250,3 @@ async def main():
|
|||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
@ -1,4 +1,8 @@
|
|||
import asyncio
|
||||
import os
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
from graphiti_core.helpers import validate_group_id
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
|
||||
|
||||
|
|
@ -11,21 +15,21 @@ async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size
|
|||
MATCH (n:Episodic {group_id: $group_id})
|
||||
CALL {
|
||||
WITH n
|
||||
SET n:$group_label
|
||||
SET n:$($group_label)
|
||||
} IN TRANSACTIONS OF $batch_size ROWS"""
|
||||
|
||||
entity_query = """
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
CALL {
|
||||
WITH n
|
||||
SET n:$group_label
|
||||
SET n:$($group_label)
|
||||
} IN TRANSACTIONS OF $batch_size ROWS"""
|
||||
|
||||
community_query = """
|
||||
MATCH (n:Community {group_id: $group_id})
|
||||
CALL {
|
||||
WITH n
|
||||
SET n:$group_label
|
||||
SET n:$($group_label)
|
||||
} IN TRANSACTIONS OF $batch_size ROWS"""
|
||||
|
||||
async with driver.session() as session:
|
||||
|
|
@ -51,3 +55,31 @@ async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size
|
|||
group_label='Community_' + group_id.replace('-', ''),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
|
||||
async def neo4j_node_label_migration(driver: GraphDriver):
|
||||
query = """MATCH (n:Episodic)
|
||||
RETURN DISTINCT n.group_id AS group_id"""
|
||||
|
||||
results, _, _ = await driver.execute_query(query)
|
||||
for result in results:
|
||||
group_id = result['group_id']
|
||||
await neo4j_node_group_labels(driver, group_id)
|
||||
|
||||
|
||||
async def main():
|
||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||
|
||||
driver = Neo4jDriver(
|
||||
uri=neo4j_uri,
|
||||
user=neo4j_user,
|
||||
password=neo4j_password,
|
||||
)
|
||||
await neo4j_node_label_migration(driver)
|
||||
await driver.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
||||
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
||||
return
|
||||
if delete_existing:
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.19.0pre2"
|
||||
version = "0.19.0pre3"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.19.0rc2"
|
||||
version = "0.19.0rc3"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue