From da6f3336bbcd6fa8c39e06ef42c4544415025a06 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Sun, 31 Aug 2025 13:19:29 -0400 Subject: [PATCH] update-tests (#872) * update-tests * unit test update * update tests * update tests * update kuzu query * update * update query * update args * fix bulk episode add * make handling better --- .../migrations/neo4j_node_group_labels.py | 37 ++++++++++++++++-- graphiti_core/nodes.py | 28 ++++++++------ graphiti_core/search/search_utils.py | 4 +- graphiti_core/utils/bulk_utils.py | 3 +- pyproject.toml | 2 +- tests/test_graphiti_mock.py | 38 ++++++++++++------- tests/test_node_int.py | 6 ++- uv.lock | 2 +- 8 files changed, 84 insertions(+), 36 deletions(-) diff --git a/graphiti_core/migrations/neo4j_node_group_labels.py b/graphiti_core/migrations/neo4j_node_group_labels.py index f075cef5..b9724980 100644 --- a/graphiti_core/migrations/neo4j_node_group_labels.py +++ b/graphiti_core/migrations/neo4j_node_group_labels.py @@ -1,4 +1,5 @@ import asyncio +import csv import os from graphiti_core.driver.driver import GraphDriver @@ -57,14 +58,41 @@ async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size ) -async def neo4j_node_label_migration(driver: GraphDriver): +def pop_last_n_group_ids(csv_file: str = 'group_ids.csv', count: int = 10): + with open(csv_file) as file: + reader = csv.reader(file) + group_ids = [row[0] for row in reader] + + total_count = len(group_ids) + popped = group_ids[-count:] + remaining = group_ids[:-count] + + with open(csv_file, 'w', newline='') as file: + writer = csv.writer(file) + for gid in remaining: + writer.writerow([gid]) + + return popped, total_count + + +async def get_group_ids(driver: GraphDriver): query = """MATCH (n:Episodic) RETURN DISTINCT n.group_id AS group_id""" results, _, _ = await driver.execute_query(query) - for result in results: - group_id = result['group_id'] - await neo4j_node_group_labels(driver, group_id) + group_ids = [result['group_id'] for result in results] + + with open('group_ids.csv', 'w', newline='') as file: + writer = csv.writer(file) + for gid in group_ids: + writer.writerow([gid]) + + +async def neo4j_node_label_migration(driver: GraphDriver, batch_size: int = 10): + group_ids, total = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size) + while len(group_ids) > 0: + await asyncio.gather(*[neo4j_node_group_labels(driver, group_id) for group_id in group_ids]) + group_ids, _ = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size) async def main(): @@ -77,6 +105,7 @@ async def main(): user=neo4j_user, password=neo4j_password, ) + await get_group_ids(driver) await neo4j_node_label_migration(driver) await driver.close() diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 4080fcc6..9558e711 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -286,18 +286,24 @@ class EpisodicNode(Node): } ], ) + + episode_args = { + 'uuid': self.uuid, + 'name': self.name, + 'group_id': self.group_id, + 'source_description': self.source_description, + 'content': self.content, + 'entity_edges': self.entity_edges, + 'created_at': self.created_at, + 'valid_at': self.valid_at, + 'source': self.source.value, + } + + if driver.provider == GraphProvider.NEO4J: + episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '') + result = await driver.execute_query( - get_episode_node_save_query(driver.provider), - uuid=self.uuid, - name=self.name, - group_id=self.group_id, - group_label='Episodic_' + self.group_id.replace('-', ''), - source_description=self.source_description, - content=self.content, - entity_edges=self.entity_edges, - created_at=self.created_at, - valid_at=self.valid_at, - source=self.source.value, + get_episode_node_save_query(driver.provider), **episode_args ) logger.debug(f'Saved Node to Graph: {self.uuid}') diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 6c61ab24..b13c3188 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -605,9 +605,7 @@ async def node_fulltext_search( + (group_ids[0].replace('-', '') if group_ids is not None else '') ) query = ( - get_nodes_query( - index_name, '$query', limit=limit, provider=driver.provider - ) + get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider) + yield_query + filter_query + """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 14be80a2..426cbe90 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -119,7 +119,8 @@ async def add_nodes_and_edges_bulk_tx( for episode in episodes: episode['source'] = str(episode['source'].value) episode.pop('labels', None) - episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '') + if driver.provider == GraphProvider.NEO4J: + episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '') nodes = [] diff --git a/pyproject.toml b/pyproject.toml index 4dd8db74..029b05a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.19.0pre3" +version = "0.19.0pre4" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/tests/test_graphiti_mock.py b/tests/test_graphiti_mock.py index 9426dc9f..9260c747 100644 --- a/tests/test_graphiti_mock.py +++ b/tests/test_graphiti_mock.py @@ -116,9 +116,7 @@ def mock_cross_encoder_client(): @pytest.mark.asyncio -async def test_add_bulk( - graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client -): +async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client): if graph_driver.provider == GraphProvider.FALKORDB: pytest.skip('Skipping as test fails on FalkorDB') @@ -143,7 +141,7 @@ async def test_add_bulk( source_description='conversation message', content='Alice likes Bob', valid_at=now, - entity_edges=[], # Filled in later + entity_edges=[], # Filled in later ) episode_node_2 = EpisodicNode( name='test_episode_2', @@ -154,14 +152,14 @@ async def test_add_bulk( source_description='conversation message', content='Bob adores Alice', valid_at=now, - entity_edges=[], # Filled in later + entity_edges=[], # Filled in later ) # Create entity nodes entity_node_1 = EntityNode( name='test_entity_1', group_id=group_id, - labels=['Entity', 'Person'], + labels=['Entity', 'Entity_graphiti_test_group', 'Person'], created_at=now, summary='test_entity_1 summary', attributes={'age': 30, 'location': 'New York'}, @@ -171,7 +169,7 @@ async def test_add_bulk( entity_node_2 = EntityNode( name='test_entity_2', group_id=group_id, - labels=['Entity', 'Person2'], + labels=['Entity', 'Entity_graphiti_test_group', 'Person2'], created_at=now, summary='test_entity_2 summary', attributes={'age': 25, 'location': 'Los Angeles'}, @@ -181,7 +179,7 @@ async def test_add_bulk( entity_node_3 = EntityNode( name='test_entity_3', group_id=group_id, - labels=['Entity', 'City', 'Location'], + labels=['Entity', 'City', 'Entity_graphiti_test_group', 'Location'], created_at=now, summary='test_entity_3 summary', attributes={'age': 25, 'location': 'Los Angeles'}, @@ -191,7 +189,7 @@ async def test_add_bulk( entity_node_4 = EntityNode( name='test_entity_4', group_id=group_id, - labels=['Entity'], + labels=['Entity', 'Entity_graphiti_test_group'], created_at=now, summary='test_entity_4 summary', attributes={'age': 25, 'location': 'Los Angeles'}, @@ -269,8 +267,22 @@ async def test_add_bulk( mock_embedder, ) - node_ids = [episode_node_1.uuid, episode_node_2.uuid, entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid] - edge_ids = [episodic_edge_1.uuid, episodic_edge_2.uuid, episodic_edge_3.uuid, episodic_edge_4.uuid, entity_edge_1.uuid, entity_edge_2.uuid] + node_ids = [ + episode_node_1.uuid, + episode_node_2.uuid, + entity_node_1.uuid, + entity_node_2.uuid, + entity_node_3.uuid, + entity_node_4.uuid, + ] + edge_ids = [ + episodic_edge_1.uuid, + episodic_edge_2.uuid, + episodic_edge_3.uuid, + episodic_edge_4.uuid, + entity_edge_1.uuid, + entity_edge_2.uuid, + ] node_count = await get_node_count(graph_driver, node_ids) assert node_count == len(node_ids) edge_count = await get_edge_count(graph_driver, edge_ids) @@ -290,7 +302,6 @@ async def test_add_bulk( retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid) await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2) - retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid) await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3) @@ -317,6 +328,7 @@ async def test_add_bulk( retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid) await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2) + @pytest.mark.asyncio async def test_remove_episode( graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client @@ -342,7 +354,7 @@ async def test_remove_episode( source_description='conversation message', content='Alice likes Bob', valid_at=now, - entity_edges=[], # Filled in later + entity_edges=[], # Filled in later ) # Create entity nodes diff --git a/tests/test_node_int.py b/tests/test_node_int.py index edaa017a..34226803 100644 --- a/tests/test_node_int.py +++ b/tests/test_node_int.py @@ -45,7 +45,7 @@ def sample_entity_node(): uuid=str(uuid4()), name='Test Entity', group_id=group_id, - labels=['Entity', 'Person'], + labels=['Entity', 'Entity_graphiti_test_group', 'Person'], created_at=created_at, name_embedding=[0.5] * 1024, summary='Entity Summary', @@ -103,7 +103,9 @@ async def test_entity_node(sample_entity_node, graph_driver): await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node) # Get node by group ids - retrieved = await EntityNode.get_by_group_ids(graph_driver, [group_id], limit=2, with_embeddings=True) + retrieved = await EntityNode.get_by_group_ids( + graph_driver, [group_id], limit=2, with_embeddings=True + ) assert len(retrieved) == 1 await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node) diff --git a/uv.lock b/uv.lock index 6a731d48..9feb7cf6 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.19.0rc3" +version = "0.19.0rc4" source = { editable = "." } dependencies = [ { name = "diskcache" },