update-tests (#872)

* update-tests

* unit test update

* update tests

* update tests

* update kuzu query

* update

* update query

* update args

* fix bulk episode add

* make handling better
This commit is contained in:
Preston Rasmussen 2025-08-31 13:19:29 -04:00 committed by GitHub
parent 119a43b8e4
commit da6f3336bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 84 additions and 36 deletions

View file

@ -1,4 +1,5 @@
import asyncio
import csv
import os
from graphiti_core.driver.driver import GraphDriver
@ -57,14 +58,41 @@ async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size
)
async def neo4j_node_label_migration(driver: GraphDriver):
def pop_last_n_group_ids(csv_file: str = 'group_ids.csv', count: int = 10):
    """Remove up to ``count`` group ids from the end of a CSV file and return them.

    The file is read as one group id per row (first column). The popped ids
    are returned in file order, and the file is rewritten so that it holds
    only the ids that were not popped.

    Args:
        csv_file: Path of the CSV file to read and rewrite.
        count: Maximum number of trailing ids to pop. ``0`` pops nothing and
            leaves the file untouched.

    Returns:
        A ``(popped, total_count)`` tuple: the list of popped ids and the
        number of ids the file held before popping.

    Raises:
        ValueError: If ``count`` is negative.
    """
    if count < 0:
        raise ValueError(f'count must be non-negative, got {count}')

    # The csv module asks for newline='' on files it reads as well as writes.
    with open(csv_file, newline='') as file:
        # Blank lines come back as empty rows; skip them rather than fail on row[0].
        group_ids = [row[0] for row in csv.reader(file) if row]

    total_count = len(group_ids)

    # Split at an explicit index: group_ids[-0:] selects the whole list (and
    # group_ids[:-0] nothing), which would pop everything when count == 0.
    split_at = max(total_count - count, 0)
    popped = group_ids[split_at:]
    remaining = group_ids[:split_at]

    # Only rewrite the file when something was actually removed from it.
    if popped:
        with open(csv_file, 'w', newline='') as file:
            csv.writer(file).writerows([gid] for gid in remaining)

    return popped, total_count
# Queries the distinct group_id values of Episodic nodes through the driver and
# writes them, one per row, to 'group_ids.csv' in the current working directory.
async def get_group_ids(driver: GraphDriver):
query = """MATCH (n:Episodic)
RETURN DISTINCT n.group_id AS group_id"""
# Only the first element of the driver's 3-tuple result (the records) is used.
results, _, _ = await driver.execute_query(query)
# NOTE(review): the loop below awaits neo4j_node_group_labels once per result,
# sequentially, before the CSV is written. In this listing it may be a leftover
# of the previous implementation rather than part of this function — confirm.
for result in results:
group_id = result['group_id']
await neo4j_node_group_labels(driver, group_id)
group_ids = [result['group_id'] for result in results]
# Mode 'w' overwrites any existing file; newline='' leaves line endings to csv.
with open('group_ids.csv', 'w', newline='') as file:
writer = csv.writer(file)
for gid in group_ids:
writer.writerow([gid])
async def neo4j_node_label_migration(driver: GraphDriver, batch_size: int = 10):
    """Apply group labels for every group id queued in 'group_ids.csv'.

    Group ids are popped from the end of the CSV file ``batch_size`` at a
    time, and ``neo4j_node_group_labels`` is awaited concurrently for all ids
    of a batch. The loop stops once a pop returns no ids.

    Args:
        driver: Graph driver handed to ``neo4j_node_group_labels``.
        batch_size: Number of group ids popped and processed per batch.
    """
    while True:
        batch, _ = pop_last_n_group_ids(csv_file='group_ids.csv', count=batch_size)
        if not batch:
            break
        # One coroutine per group id; the next batch is popped only after all finish.
        await asyncio.gather(*(neo4j_node_group_labels(driver, gid) for gid in batch))
async def main():
@ -77,6 +105,7 @@ async def main():
user=neo4j_user,
password=neo4j_password,
)
await get_group_ids(driver)
await neo4j_node_label_migration(driver)
await driver.close()

View file

@ -286,18 +286,24 @@ class EpisodicNode(Node):
}
],
)
episode_args = {
'uuid': self.uuid,
'name': self.name,
'group_id': self.group_id,
'source_description': self.source_description,
'content': self.content,
'entity_edges': self.entity_edges,
'created_at': self.created_at,
'valid_at': self.valid_at,
'source': self.source.value,
}
if driver.provider == GraphProvider.NEO4J:
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
result = await driver.execute_query(
get_episode_node_save_query(driver.provider),
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
group_label='Episodic_' + self.group_id.replace('-', ''),
source_description=self.source_description,
content=self.content,
entity_edges=self.entity_edges,
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source.value,
get_episode_node_save_query(driver.provider), **episode_args
)
logger.debug(f'Saved Node to Graph: {self.uuid}')

View file

@ -605,9 +605,7 @@ async def node_fulltext_search(
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
query = (
get_nodes_query(
index_name, '$query', limit=limit, provider=driver.provider
)
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
+ yield_query
+ filter_query
+ """

View file

@ -119,7 +119,8 @@ async def add_nodes_and_edges_bulk_tx(
for episode in episodes:
episode['source'] = str(episode['source'].value)
episode.pop('labels', None)
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
if driver.provider == GraphProvider.NEO4J:
episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
nodes = []

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.19.0pre3"
version = "0.19.0pre4"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

View file

@ -116,9 +116,7 @@ def mock_cross_encoder_client():
@pytest.mark.asyncio
async def test_add_bulk(
graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client):
if graph_driver.provider == GraphProvider.FALKORDB:
pytest.skip('Skipping as test fails on FalkorDB')
@ -143,7 +141,7 @@ async def test_add_bulk(
source_description='conversation message',
content='Alice likes Bob',
valid_at=now,
entity_edges=[], # Filled in later
entity_edges=[], # Filled in later
)
episode_node_2 = EpisodicNode(
name='test_episode_2',
@ -154,14 +152,14 @@ async def test_add_bulk(
source_description='conversation message',
content='Bob adores Alice',
valid_at=now,
entity_edges=[], # Filled in later
entity_edges=[], # Filled in later
)
# Create entity nodes
entity_node_1 = EntityNode(
name='test_entity_1',
group_id=group_id,
labels=['Entity', 'Person'],
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
created_at=now,
summary='test_entity_1 summary',
attributes={'age': 30, 'location': 'New York'},
@ -171,7 +169,7 @@ async def test_add_bulk(
entity_node_2 = EntityNode(
name='test_entity_2',
group_id=group_id,
labels=['Entity', 'Person2'],
labels=['Entity', 'Entity_graphiti_test_group', 'Person2'],
created_at=now,
summary='test_entity_2 summary',
attributes={'age': 25, 'location': 'Los Angeles'},
@ -181,7 +179,7 @@ async def test_add_bulk(
entity_node_3 = EntityNode(
name='test_entity_3',
group_id=group_id,
labels=['Entity', 'City', 'Location'],
labels=['Entity', 'City', 'Entity_graphiti_test_group', 'Location'],
created_at=now,
summary='test_entity_3 summary',
attributes={'age': 25, 'location': 'Los Angeles'},
@ -191,7 +189,7 @@ async def test_add_bulk(
entity_node_4 = EntityNode(
name='test_entity_4',
group_id=group_id,
labels=['Entity'],
labels=['Entity', 'Entity_graphiti_test_group'],
created_at=now,
summary='test_entity_4 summary',
attributes={'age': 25, 'location': 'Los Angeles'},
@ -269,8 +267,22 @@ async def test_add_bulk(
mock_embedder,
)
node_ids = [episode_node_1.uuid, episode_node_2.uuid, entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
edge_ids = [episodic_edge_1.uuid, episodic_edge_2.uuid, episodic_edge_3.uuid, episodic_edge_4.uuid, entity_edge_1.uuid, entity_edge_2.uuid]
node_ids = [
episode_node_1.uuid,
episode_node_2.uuid,
entity_node_1.uuid,
entity_node_2.uuid,
entity_node_3.uuid,
entity_node_4.uuid,
]
edge_ids = [
episodic_edge_1.uuid,
episodic_edge_2.uuid,
episodic_edge_3.uuid,
episodic_edge_4.uuid,
entity_edge_1.uuid,
entity_edge_2.uuid,
]
node_count = await get_node_count(graph_driver, node_ids)
assert node_count == len(node_ids)
edge_count = await get_edge_count(graph_driver, edge_ids)
@ -290,7 +302,6 @@ async def test_add_bulk(
retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid)
await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2)
retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid)
await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3)
@ -317,6 +328,7 @@ async def test_add_bulk(
retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid)
await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2)
@pytest.mark.asyncio
async def test_remove_episode(
graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
@ -342,7 +354,7 @@ async def test_remove_episode(
source_description='conversation message',
content='Alice likes Bob',
valid_at=now,
entity_edges=[], # Filled in later
entity_edges=[], # Filled in later
)
# Create entity nodes

View file

@ -45,7 +45,7 @@ def sample_entity_node():
uuid=str(uuid4()),
name='Test Entity',
group_id=group_id,
labels=['Entity', 'Person'],
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Entity Summary',
@ -103,7 +103,9 @@ async def test_entity_node(sample_entity_node, graph_driver):
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
# Get node by group ids
retrieved = await EntityNode.get_by_group_ids(graph_driver, [group_id], limit=2, with_embeddings=True)
retrieved = await EntityNode.get_by_group_ids(
graph_driver, [group_id], limit=2, with_embeddings=True
)
assert len(retrieved) == 1
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.19.0rc3"
version = "0.19.0rc4"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },